author Anthony Wang 2022-02-21 15:50:38 -0600
committer Anthony Wang 2022-02-21 15:50:38 -0600
commit f2d33b51b1a682be5a6e661541ddc69626bfaba1 (patch)
tree 24572afe59f2f216c9ad76b4cd712dcbc94bdd22
parent 2f05004e4a058289446de7a451fe85164bac742a (diff)
Delete old test.py code
-rw-r--r-- test.py 186
1 file changed, 0 insertions, 186 deletions
diff --git a/test.py b/test.py
deleted file mode 100644
index ff88000..0000000
--- a/test.py
+++ /dev/null
@@ -1,186 +0,0 @@
-#!/usr/bin/python3
-# https://github.com/ChunML/NLP/blob/master/text_generation/train_pt.py
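-# Word-level LSTM text generator: builds a vocabulary from a plain-text corpus,
-# trains a single-layer LSTM to predict the next word, and samples with top-k.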
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import numpy as np
-from collections import Counter
-import os
-from argparse import Namespace
-
-
-flags = Namespace(
-    train_file='data',
-    seq_size=32,
-    batch_size=256,
-    embedding_size=64,
-    lstm_size=64,
-    gradients_norm=5,
-    initial_words=['i', 'am'],
-    predict_top_k=3,
-    checkpoint_path='checkpoint',
-)
-
-
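-# Read the corpus, build the word<->id vocabulary, and reshape the token
-# stream into (batch_size, -1) input/target arrays.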
-def get_data_from_file(train_file, batch_size, seq_size):
-    with open(train_file, 'r', encoding='utf-8') as f:
-        text = f.read()
-    text = text.split()
-
-    word_counts = Counter(text)
-    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
-    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
-    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
-    n_vocab = len(int_to_vocab)
-
-    print('Vocabulary size', n_vocab)
-
-    int_text = [vocab_to_int[w] for w in text]
-    num_batches = int(len(int_text) / (seq_size * batch_size))
-    in_text = int_text[:num_batches * batch_size * seq_size]
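-    # Targets are the inputs shifted left by one token; the last target wraps
-    # around to the first token.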
-    out_text = np.zeros_like(in_text)
-    out_text[:-1] = in_text[1:]
-    out_text[-1] = in_text[0]
-    in_text = np.reshape(in_text, (batch_size, -1))
-    out_text = np.reshape(out_text, (batch_size, -1))
-    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text
-
-
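-# Yield successive (input, target) windows of seq_size columns from the
-# batched arrays.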
-def get_batches(in_text, out_text, batch_size, seq_size):
-    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
-    for i in range(0, num_batches * seq_size, seq_size):
-        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]
-
-
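-# Embedding -> single-layer LSTM -> linear head over the vocabulary.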
-class RNNModule(nn.Module):
-    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
-        super(RNNModule, self).__init__()
-        self.seq_size = seq_size
-        self.lstm_size = lstm_size
-        self.embedding = nn.Embedding(n_vocab, embedding_size)
-        self.lstm = nn.LSTM(embedding_size,
-                            lstm_size,
-                            batch_first=True)
-        self.dense = nn.Linear(lstm_size, n_vocab)
-
-    def forward(self, x, prev_state):
-        embed = self.embedding(x)
-        output, state = self.lstm(embed, prev_state)
-        logits = self.dense(output)
-
-        return logits, state
-
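-    # Fresh (h, c) state of shape (num_layers=1, batch_size, lstm_size) for
-    # the start of an epoch or a generation run.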
-    def zero_state(self, batch_size):
-        return (torch.zeros(1, batch_size, self.lstm_size),
-                torch.zeros(1, batch_size, self.lstm_size))
-
-
-def get_loss_and_train_op(net, lr=0.001):
-    criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
-
-    return criterion, optimizer
-
-
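-# Prime the LSTM state with the seed words, then generate 100 more words by
-# repeatedly sampling from the top-k logits.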
-def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
-    net.eval()
-    words = list(words)
-
-    state_h, state_c = net.zero_state(1)
-    state_h = state_h.to(device)
-    state_c = state_c.to(device)
-    for w in words:
-        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
-        output, (state_h, state_c) = net(ix, (state_h, state_c))
-
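-    # Sample the next word uniformly from the k highest-scoring candidates.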
-    _, top_ix = torch.topk(output[0], k=top_k)
-    choices = top_ix.tolist()
-    choice = np.random.choice(choices[0])
-
-    words.append(int_to_vocab[choice])
-
-    for _ in range(100):
-        ix = torch.tensor([[choice]]).to(device)
-        output, (state_h, state_c) = net(ix, (state_h, state_c))
-
-        _, top_ix = torch.topk(output[0], k=top_k)
-        choices = top_ix.tolist()
-        choice = np.random.choice(choices[0])
-        words.append(int_to_vocab[choice])
-
-    print(' '.join(words))
-
-
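-# Train for 200 epochs with truncated backpropagation through time,
-# periodically sampling text and saving checkpoints.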
-def main():
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    os.makedirs(flags.checkpoint_path, exist_ok=True)
-    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
-        flags.train_file, flags.batch_size, flags.seq_size)
-
-    net = RNNModule(n_vocab, flags.seq_size,
-                    flags.embedding_size, flags.lstm_size)
-    net = net.to(device)
-
-    criterion, optimizer = get_loss_and_train_op(net, 0.01)
-
-    iteration = 0
-
-    for e in range(200):
-        batches = get_batches(
-            in_text, out_text, flags.batch_size, flags.seq_size)
-        state_h, state_c = net.zero_state(flags.batch_size)
-        state_h = state_h.to(device)
-        state_c = state_c.to(device)
-        for x, y in batches:
-            iteration += 1
-            net.train()
-
-            optimizer.zero_grad()
-
-            x = torch.tensor(x).to(device)
-            y = torch.tensor(y).to(device)
-
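-            # CrossEntropyLoss expects the class dimension second, so transpose
-            # logits from (batch, seq, n_vocab) to (batch, n_vocab, seq).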
-            logits, (state_h, state_c) = net(x, (state_h, state_c))
-            loss = criterion(logits.transpose(1, 2), y)
-
-            loss_value = loss.item()
-
-            loss.backward()
-
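-            # Detach the recurrent state so gradients stop at the batch
-            # boundary (truncated BPTT).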
-            state_h = state_h.detach()
-            state_c = state_c.detach()
-
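-            # Clip the global gradient norm to keep LSTM updates stable.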
-            _ = torch.nn.utils.clip_grad_norm_(
-                net.parameters(), flags.gradients_norm)
-
-            optimizer.step()
-
-            if iteration % 100 == 0:
-                print('Epoch: {}/{}'.format(e, 200),
-                      'Iteration: {}'.format(iteration),
-                      'Loss: {}'.format(loss_value))
-
-            if iteration % 1000 == 0:
-                predict(device, net, flags.initial_words, n_vocab,
-                        vocab_to_int, int_to_vocab, top_k=3)
-                torch.save(net.state_dict(),
-                           os.path.join(flags.checkpoint_path,
-                                        'model-{}.pth'.format(iteration)))
-
-
-if __name__ == '__main__':
-    main()