author     Anthony Wang    2022-02-21 15:50:38 -0600
committer  Anthony Wang    2022-02-21 15:50:38 -0600
commit     f2d33b51b1a682be5a6e661541ddc69626bfaba1 (patch)
tree       24572afe59f2f216c9ad76b4cd712dcbc94bdd22
parent     2f05004e4a058289446de7a451fe85164bac742a (diff)
Delete old test.py code
-rw-r--r--  test.py  172
1 files changed, 0 insertions, 172 deletions
diff --git a/test.py b/test.py
deleted file mode 100644
index ff88000..0000000
--- a/test.py
+++ /dev/null
@@ -1,172 +0,0 @@
-#!/usr/bin/python3
-# https://github.com/ChunML/NLP/blob/master/text_generation/train_pt.py
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import numpy as np
-from collections import Counter
-import os
-from argparse import Namespace
-
-
-flags = Namespace(
-    train_file='data',
-    seq_size=32,
-    batch_size=256,
-    embedding_size=64,
-    lstm_size=64,
-    gradients_norm=5,
-    initial_words=['i', 'am'],
-    predict_top_k=3,
-    checkpoint_path='checkpoint',
-)
-
-
-def get_data_from_file(train_file, batch_size, seq_size):
-    with open(train_file, 'r', encoding='utf-8') as f:
-        text = f.read()
-    text = text.split()
-
-    word_counts = Counter(text)
-    sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
-    int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
-    vocab_to_int = {w: k for k, w in int_to_vocab.items()}
-    n_vocab = len(int_to_vocab)
-
-    print('Vocabulary size', n_vocab)
-
-    int_text = [vocab_to_int[w] for w in text]
-    num_batches = int(len(int_text) / (seq_size * batch_size))
-    in_text = int_text[:num_batches * batch_size * seq_size]
-    out_text = np.zeros_like(in_text)
-    out_text[:-1] = in_text[1:]
-    out_text[-1] = in_text[0]
-    in_text = np.reshape(in_text, (batch_size, -1))
-    out_text = np.reshape(out_text, (batch_size, -1))
-    return int_to_vocab, vocab_to_int, n_vocab, in_text, out_text
-
-
-def get_batches(in_text, out_text, batch_size, seq_size):
-    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
-    for i in range(0, num_batches * seq_size, seq_size):
-        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]
-
-
-class RNNModule(nn.Module):
-    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
-        super(RNNModule, self).__init__()
-        self.seq_size = seq_size
-        self.lstm_size = lstm_size
-        self.embedding = nn.Embedding(n_vocab, embedding_size)
-        self.lstm = nn.LSTM(embedding_size,
-                            lstm_size,
-                            batch_first=True)
-        self.dense = nn.Linear(lstm_size, n_vocab)
-
-    def forward(self, x, prev_state):
-        embed = self.embedding(x)
-        output, state = self.lstm(embed, prev_state)
-        logits = self.dense(output)
-
-        return logits, state
-
-    def zero_state(self, batch_size):
-        return (torch.zeros(1, batch_size, self.lstm_size),
-                torch.zeros(1, batch_size, self.lstm_size))
-
-
-def get_loss_and_train_op(net, lr=0.001):
-    criterion = nn.CrossEntropyLoss()
-    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
-
-    return criterion, optimizer
-
-
-def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
-    net.eval()
-    words = ['i', 'am']
-
-    state_h, state_c = net.zero_state(1)
-    state_h = state_h.to(device)
-    state_c = state_c.to(device)
-    for w in words:
-        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
-        output, (state_h, state_c) = net(ix, (state_h, state_c))
-
-    _, top_ix = torch.topk(output[0], k=top_k)
-    choices = top_ix.tolist()
-    choice = np.random.choice(choices[0])
-
-    words.append(int_to_vocab[choice])
-
-    for _ in range(100):
-        ix = torch.tensor([[choice]]).to(device)
-        output, (state_h, state_c) = net(ix, (state_h, state_c))
-
-        _, top_ix = torch.topk(output[0], k=top_k)
-        choices = top_ix.tolist()
-        choice = np.random.choice(choices[0])
-        words.append(int_to_vocab[choice])
-
-    print(' '.join(words).encode('utf-8'))
-
-
-def main():
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    int_to_vocab, vocab_to_int, n_vocab, in_text, out_text = get_data_from_file(
-        flags.train_file, flags.batch_size, flags.seq_size)
-
-    net = RNNModule(n_vocab, flags.seq_size,
-                    flags.embedding_size, flags.lstm_size)
-    net = net.to(device)
-
-    criterion, optimizer = get_loss_and_train_op(net, 0.01)
-
-    iteration = 0
-
-    for e in range(200):
-        batches = get_batches(
-            in_text, out_text, flags.batch_size, flags.seq_size)
-        state_h, state_c = net.zero_state(flags.batch_size)
-        state_h = state_h.to(device)
-        state_c = state_c.to(device)
-        for x, y in batches:
-            iteration += 1
-            net.train()
-
-            optimizer.zero_grad()
-
-            x = torch.tensor(x).to(device)
-            y = torch.tensor(y).to(device)
-
-            logits, (state_h, state_c) = net(x, (state_h, state_c))
-            loss = criterion(logits.transpose(1, 2), y)
-
-            loss_value = loss.item()
-
-            loss.backward()
-
-            state_h = state_h.detach()
-            state_c = state_c.detach()
-
-            _ = torch.nn.utils.clip_grad_norm_(
-                net.parameters(), flags.gradients_norm)
-
-            optimizer.step()
-
-            if iteration % 1 == 0:
-                print('Epoch: {}/{}'.format(e, 20),
-                      'Iteration: {}'.format(iteration),
-                      'Loss: {}'.format(loss_value))
-
-            if iteration % 1000 == 0:
-                predict(device, net, flags.initial_words, n_vocab,
-                        vocab_to_int, int_to_vocab, top_k=3)
-                torch.save(net.state_dict(),
-                           'checkpoint/model-{}.pth'.format(iteration))
-
-
-if __name__ == '__main__':
-    main()