aboutsummaryrefslogtreecommitdiff
path: root/dataset.py
diff options
context:
space:
mode:
authorAnthony Wang2022-02-21 14:54:40 -0600
committerAnthony Wang2022-02-21 14:54:40 -0600
commit8f5bcc04d23da436dcde3eadf8a44198f1871cdb (patch)
tree4a6fab794a072de92ca26843020ce9b0e060d8f0 /dataset.py
parent3a92cf87c8ac8ed78d9e089b90d5e9bbbcc09a3c (diff)
Create a dataset class
Diffstat (limited to 'dataset.py')
-rw-r--r--dataset.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/dataset.py b/dataset.py
new file mode 100644
index 0000000..1040b36
--- /dev/null
+++ b/dataset.py
@@ -0,0 +1,26 @@
+import torch
+
+
+class Dataset(torch.utils.data.Dataset):
+ def __init__(self, file, seq_size):
+ self.seq_size = seq_size
+
+ with open(file, 'r') as f:
+ self.words = f.read().split()
+
+ self.word_counts = Counter(self.words)
+ self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
+
+ self.index_to_word = {index: word for index,
+ word in enumerate(self.uniq_words)}
+ self.word_to_index = {word: index for index,
+ word in enumerate(self.uniq_words)}
+
+ self.words_indexes = [self.word_to_index[w] for w in self.words]
+
+ def __len__(self):
+ return len(self.words_indexes) - self.seq_size
+
+ def __getitem__(self, index):
+ return (torch.tensor(self.words_indexes[index:index+self.seq_size]),
+ torch.tensor(self.words_indexes[index+1:index+self.seq_size+1])) \ No newline at end of file