diff options
author | Anthony Wang | 2022-02-21 14:54:40 -0600 |
---|---|---|
committer | Anthony Wang | 2022-02-21 14:54:40 -0600 |
commit | 8f5bcc04d23da436dcde3eadf8a44198f1871cdb (patch) | |
tree | 4a6fab794a072de92ca26843020ce9b0e060d8f0 /dataset.py | |
parent | 3a92cf87c8ac8ed78d9e089b90d5e9bbbcc09a3c (diff) |
Create a dataset class
Diffstat (limited to 'dataset.py')
-rw-r--r-- | dataset.py | 26 |
1 files changed, 26 insertions, 0 deletions
# dataset.py -- contents of the file added by this patch:
#   diff --git a/dataset.py b/dataset.py
#   new file mode 100644
#   index 0000000..1040b36
#   --- /dev/null
#   +++ b/dataset.py
#   @@ -0,0 +1,26 @@
from collections import Counter

import torch


class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a whitespace-split text file.

    Item ``i`` is a pair ``(inputs, targets)`` of 1-D tensors of word ids,
    each ``seq_size`` long, where ``targets`` is ``inputs`` shifted one word
    to the right in the corpus.

    Attributes set by ``__init__``:
        seq_size: window length, in words.
        words: every word of the file, in order.
        word_counts: ``Counter`` mapping word -> number of occurrences.
        uniq_words: vocabulary sorted by ascending count (ties keep
            first-appearance order).
        index_to_word / word_to_index: id <-> word lookup tables, where the
            id is the word's position in ``uniq_words``.
        words_indexes: ``words`` translated to ids.
    """

    def __init__(self, file, seq_size):
        """Read ``file`` and build the vocabulary and the id sequence.

        Args:
            file: path of the text file to read; it is split on whitespace.
            seq_size: number of words in each input/target window.
        """
        super().__init__()
        self.seq_size = seq_size

        with open(file, 'r') as f:
            self.words = f.read().split()

        self.word_counts = Counter(self.words)
        # Ascending frequency: the rarest words get the smallest ids.  The
        # order is kept as-is so ids stay stable for existing users.
        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)

        self.index_to_word = dict(enumerate(self.uniq_words))
        self.word_to_index = {
            word: index for index, word in enumerate(self.uniq_words)
        }

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def __len__(self):
        """Return the number of full (input, target) windows available.

        Clamped at zero: a corpus shorter than ``seq_size`` has no complete
        window, and a negative length would make ``len()`` raise.
        """
        return max(0, len(self.words_indexes) - self.seq_size)

    def __getitem__(self, index):
        """Return the ``(inputs, targets)`` tensor pair for window ``index``.

        Negative indices count from the end, as for a list.

        Raises:
            IndexError: if ``index`` is outside ``range(-len(self), len(self))``.
                Raising here is also what lets ``for item in dataset``
                terminate, since the class defines no ``__iter__``.
        """
        size = len(self)
        if index < 0:
            index += size
        if not 0 <= index < size:
            raise IndexError("Dataset index out of range")

        stop = index + self.seq_size
        return (torch.tensor(self.words_indexes[index:stop]),
                torch.tensor(self.words_indexes[index + 1:stop + 1]))
\ No newline at end of file |