author	Anthony Wang	2022-02-20 22:12:17 -0600
committer	Anthony Wang	2022-02-20 22:12:17 -0600
commit	a49dc010ed918f508c728020abf08b9328aac021 (patch)
tree	c383287b1bc97ec83445f2cc9344b97dba13f8f6
parent	2a0afe7c7a6d2a156ac9684fd7cc46776039ebfa (diff)
Clean up and organize code better
-rwxr-xr-x	main.py	78
1 file changed, 49 insertions(+), 29 deletions(-)
diff --git a/main.py b/main.py
index 4a6bc1d..edae126 100755
--- a/main.py
+++ b/main.py
@@ -4,8 +4,9 @@ import re
from collections import Counter
import psycopg2
+import numpy as np
import torch
-from torch import nn, optim
+from torch import nn
from torch.utils.data import DataLoader
@@ -22,6 +23,36 @@ text = [re.sub(r'<[^>]*>', '', status[2]) for status in statuses]
# print(text[0:100])
+class Dataset(torch.utils.data.Dataset):
+    def __init__(self):
+        self.words = [word for message in text for word in message.split()]
+        self.word_counts = Counter(self.words)
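+        # Order the vocabulary by word frequency; each word's position in
+        # this list becomes its integer index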
+        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
+
+        self.index_to_word = {index: word for index,
+                              word in enumerate(self.uniq_words)}
+        self.word_to_index = {word: index for index,
+                              word in enumerate(self.uniq_words)}
+
+        self.words_indexes = [self.word_to_index[w] for w in self.words]
+
+    def __len__(self):
+        return len(self.words_indexes) - 4
+
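+    # Each sample pairs a 4-word window of token indices with the same
+    # window shifted one word ahead (next-word targets)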
+    def __getitem__(self, index):
+        return (torch.tensor(self.words_indexes[index:index+4]),
+                torch.tensor(self.words_indexes[index+1:index+4+1]))
+
+
+dataset = Dataset()
+dataloader = DataLoader(dataset, batch_size=256)
+print(len(dataloader))
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using {device} device")
+
+
class Model(nn.Module):
    def __init__(self, dataset):
        super(Model, self).__init__()
@@ -55,46 +86,35 @@ class Model(nn.Module):
                torch.zeros(self.num_layers, sequence_length, self.lstm_size))
-class Dataset(torch.utils.data.Dataset):
-    def __init__(self):
-        self.words = [word for message in text for word in message.split()]
-        self.word_counts = Counter(self.words)
-        self.uniq_words = sorted(self.word_counts, key=self.word_counts.get)
-
-        self.index_to_word = {index: word for index,
-                              word in enumerate(self.uniq_words)}
-        self.word_to_index = {word: index for index,
-                              word in enumerate(self.uniq_words)}
-
-        self.words_indexes = [self.word_to_index[w] for w in self.words]
-
-    def __len__(self):
-        return len(self.words_indexes) - 4
-
-    def __getitem__(self, index):
-        return (torch.tensor(self.words_indexes[index:index+4]),
-                torch.tensor(self.words_indexes[index+1:index+4+1]))
+model = Model(dataset).to(device)
+print(model)
+loss_fn = nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
-dataset = Dataset()
-model = Model(dataset)
-
-
-dataloader = DataLoader(dataset, batch_size=256)
-criterion = nn.CrossEntropyLoss()
-optimizer = optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1):
+    model.train()
+
    state_h, state_c = model.init_state(4)
+
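+    # Cap each epoch at 50 batches to keep test runs short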
+    cnt = 0
    for batch, (x, y) in enumerate(dataloader):
+        cnt += 1
+        if cnt > 50:
+            break
+
        optimizer.zero_grad()
+        # Compute prediction error
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
-        loss = criterion(y_pred.transpose(1, 2), y)
+        loss = loss_fn(y_pred.transpose(1, 2), y)
        state_h = state_h.detach()
        state_c = state_c.detach()
+        # Backpropagation (gradients were already cleared at the top of the
+        # loop, so no second zero_grad() is needed here)
        loss.backward()
        optimizer.step()
@@ -120,4 +140,4 @@ def predict(text, next_words=100):
    return words
-predict('This is a test')
+print(predict('This is a test'))
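
Note: the predict helper touched by the last hunk is not shown in this diff. For context, here is a minimal sketch of what such a word-level sampling function might look like, assuming the Model and Dataset interfaces above; the softmax sampling via np.random.choice is an assumption suggested by the new numpy import, not code taken from this commit:

def predict(text, next_words=100):
    # Sketch only: a plausible reconstruction, not this commit's actual code
    model.eval()
    words = text.split()
    state_h, state_c = model.init_state(len(words))
    for i in range(next_words):
        # Feed the current sequence of word indices through the model
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]])
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
        # Turn the logits for the last position into a probability
        # distribution and sample the next word from it
        last_word_logits = y_pred[0][-1]
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])
    return words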