Detect Malware Using an LSTM (Syscalls Only)
import torch import torch.nn as nn from torch.utils.data import DataLoader, Dataset from collections import Counter import ast # Dataset preparation with frequencies class SyscallDataset ( Dataset ): def __init__ ( self , sequences , labels , tokenizer , max_len ): self .sequences = sequences self .labels = labels self .tokenizer = tokenizer # A dictionary mapping syscalls to token IDs self .max_len = max_len def __len__ ( self ): return len ( self .sequences) def __getitem__ ( self , index ): sequence = self .sequences[index] label = self .labels[index] # Count frequency of syscalls in the sequence freq = Counter(sequence) # Tokenize the sequence manually input_ids = torch.zeros( self .max_len, dtype =torch.long) for i, syscall in enumerate (sequence[: self .max_len]): input_ids[i] = self .tokenizer.get(syscall, 0 ) # Default to...