| author | Anthony Wang | 2022-02-20 22:12:17 -0600 |
|---|---|---|
| committer | Anthony Wang | 2022-02-20 22:12:17 -0600 |
| commit | a49dc010ed918f508c728020abf08b9328aac021 (patch) | |
| tree | c383287b1bc97ec83445f2cc9344b97dba13f8f6 | |
| parent | 2a0afe7c7a6d2a156ac9684fd7cc46776039ebfa (diff) | |
Clean up and organize code better
-rwxr-xr-x | main.py | 78 |
1 file changed, 49 insertions, 29 deletions
```diff
@@ -4,8 +4,9 @@ import re
 from collections import Counter
 
 import psycopg2
+import numpy as np
 import torch
-from torch import nn, optim
+from torch import nn
 from torch.utils.data import DataLoader
 
 
@@ -22,6 +23,36 @@ text = [re.sub(r'<[^>]*>', '', status[2]) for status in statuses]
 # print(text[0:100])
 
 
+class Dataset(torch.utils.data.Dataset):
+    def __init__(self):
+        self.words = [word for message in text for word in message.split()]
+        self.word_counts = Counter(self.words)
+        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
+
+        self.index_to_word = {index: word for index,
+                              word in enumerate(self.uniq_words)}
+        self.word_to_index = {word: index for index,
+                              word in enumerate(self.uniq_words)}
+
+        self.words_indexes = [self.word_to_index[w] for w in self.words]
+
+    def __len__(self):
+        return len(self.words_indexes) - 4
+
+    def __getitem__(self, index):
+        return (torch.tensor(self.words_indexes[index:index+4]),
+                torch.tensor(self.words_indexes[index+1:index+4+1]))
+
+
+dataset = Dataset()
+dataloader = DataLoader(dataset, batch_size=256)
+print(len(dataloader))
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using {device} device")
+
+
 class Model(nn.Module):
     def __init__(self, dataset):
         super(Model, self).__init__()
@@ -55,46 +86,35 @@ class Model(nn.Module):
                 torch.zeros(self.num_layers, sequence_length, self.lstm_size))
 
 
-class Dataset(torch.utils.data.Dataset):
-    def __init__(self):
-        self.words = [word for message in text for word in message.split()]
-        self.word_counts = Counter(self.words)
-        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
-
-        self.index_to_word = {index: word for index,
-                              word in enumerate(self.uniq_words)}
-        self.word_to_index = {word: index for index,
-                              word in enumerate(self.uniq_words)}
-
-        self.words_indexes = [self.word_to_index[w] for w in self.words]
-
-    def __len__(self):
-        return len(self.words_indexes) - 4
-
-    def __getitem__(self, index):
-        return (torch.tensor(self.words_indexes[index:index+4]),
-                torch.tensor(self.words_indexes[index+1:index+4+1]))
+model = Model(dataset).to(device)
+print(model)
+loss_fn = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
 
 
-dataset = Dataset()
-model = Model(dataset)
-
-
-dataloader = DataLoader(dataset, batch_size=256)
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), lr=0.001)
 for epoch in range(1):
+    model.train()
+
     state_h, state_c = model.init_state(4)
+
+    cnt = 0
     for batch, (x, y) in enumerate(dataloader):
+        cnt += 1
+        if cnt > 50:
+            break
+
         optimizer.zero_grad()
+        # Compute prediction error
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
-        loss = criterion(y_pred.transpose(1, 2), y)
+        loss = loss_fn(y_pred.transpose(1, 2), y)
 
         state_h = state_h.detach()
         state_c = state_c.detach()
 
+        # Backpropagation
+        optimizer.zero_grad()
         loss.backward()
         optimizer.step()
 
@@ -120,4 +140,4 @@ def predict(text, next_words=100):
     return words
 
 
-predict('This is a test')
+print(predict('This is a test'))
```
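A few notes on the reorganized code, since the diff mostly moves pieces around. First, the `Dataset` class (now defined before `Model`) builds next-word training pairs with a four-word sliding window: `__getitem__` returns a window of word indices as the input and the same window shifted by one as the target. A minimal sketch of that indexing, using a made-up corpus (the indices below are hypothetical, not from the repo's data):

```python
import torch

# Hypothetical word indices for a six-word corpus, e.g. "the cat sat on the mat"
words_indexes = [0, 1, 2, 3, 0, 4]

index = 1
x = torch.tensor(words_indexes[index:index + 4])      # tensor([1, 2, 3, 0])
y = torch.tensor(words_indexes[index + 1:index + 5])  # tensor([2, 3, 0, 4])

# y is x shifted left by one position: y[i] is the word that follows x[i].
# This is also why __len__ returns len(words_indexes) - 4: the last valid
# window must leave room for one more word as the final target.
```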
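Second, the renamed loss call, `loss_fn(y_pred.transpose(1, 2), y)`, needs the transpose because the model emits logits shaped `(batch, seq, vocab)` while `nn.CrossEntropyLoss` expects the class dimension second for sequence targets. A self-contained shape check (the vocabulary size is an illustrative assumption; only `batch_size=256` and the window length 4 come from the diff):

```python
import torch
from torch import nn

loss_fn = nn.CrossEntropyLoss()

batch, seq_len, vocab = 256, 4, 1000           # vocab size is an assumption
y_pred = torch.randn(batch, seq_len, vocab)    # LSTM logits: (batch, seq, vocab)
y = torch.randint(vocab, (batch, seq_len))     # next-word targets: (batch, seq)

# CrossEntropyLoss scores K-dimensional input as (batch, classes, d1, ...),
# so the class (vocab) dimension has to move to position 1
loss = loss_fn(y_pred.transpose(1, 2), y)
print(loss.shape)  # torch.Size([]) -- a scalar loss
```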
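Finally, the `state_h.detach()` / `state_c.detach()` lines that the loop keeps are what make carrying the LSTM state across batches work: each `loss.backward()` frees the graph it ran over, so an undetached state would make the next batch's backward pass reach into an already-freed graph and raise a RuntimeError. A minimal standalone reproduction of that failure mode (toy tensors, not the repo's model):

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
state = torch.tensor(0.0)

for step in range(2):
    state = state * w + 1   # carrying state chains the steps' graphs together
    loss = state ** 2
    loss.backward()         # frees the graph built during this step
    state = state.detach()  # without this line, step 2 raises
                            # "Trying to backward through the graph a second time"
```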