From 6bab795fe83d683cc0bfbbc875ccdef70fb4112c Mon Sep 17 00:00:00 2001 From: Anthony Wang Date: Tue, 22 Feb 2022 16:58:19 -0600 Subject: Add new transformers bot script and move old one to bot_lstm.py --- bot.py | 38 ++++++++++++++++---------------------- bot_lstm.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 53 insertions(+), 22 deletions(-) create mode 100644 bot_lstm.py diff --git a/bot.py b/bot.py index c2a2530..8840a28 100644 --- a/bot.py +++ b/bot.py @@ -1,37 +1,31 @@ from argparse import ArgumentParser -import torch from mastodon import Mastodon - -from dataset import Dataset -from model import Model -from predict import predict +from transformers import AutoTokenizer, AutoModelForCausalLM parser = ArgumentParser() parser.add_argument('-t', '--token', help='Mastodon application access token') -parser.add_argument('-i', '--input', default='data', - help='training data input file') -parser.add_argument('-e', '--text', default='i am', - help='initial text for prediction') -parser.add_argument('-d', '--device', default='cpu', - help='device to run the model with') -parser.add_argument('-m', '--model', default='model.pt', +parser.add_argument('-i', '--input', default='i am', + help='initial input text for prediction') +parser.add_argument('-m', '--model', default='model', help='path to load saved model') args = parser.parse_args() -mastodon = Mastodon( - access_token=args.token, - api_base_url='https://social.exozy.me/' -) +tokenizer = AutoTokenizer.from_pretrained('distilgpt2') +model = AutoModelForCausalLM.from_pretrained(args.model) -dataset = Dataset(args.input, 32) -device = torch.device(args.device) -model = torch.load(args.model) +# Run the input through the model +inputs = tokenizer.encode(args.input, return_tensors="pt") +output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0]) +print(output) -text = predict(device, model, args.text) -print(text) -# mastodon.status_post(text) +# Post it to Mastodon +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) +mastodon.status_post(output) diff --git a/bot_lstm.py b/bot_lstm.py new file mode 100644 index 0000000..c2a2530 --- /dev/null +++ b/bot_lstm.py @@ -0,0 +1,37 @@ +from argparse import ArgumentParser + +import torch +from mastodon import Mastodon + +from dataset import Dataset +from model import Model +from predict import predict + + +parser = ArgumentParser() +parser.add_argument('-t', '--token', help='Mastodon application access token') +parser.add_argument('-i', '--input', default='data', + help='training data input file') +parser.add_argument('-e', '--text', default='i am', + help='initial text for prediction') +parser.add_argument('-d', '--device', default='cpu', + help='device to run the model with') +parser.add_argument('-m', '--model', default='model.pt', + help='path to load saved model') +args = parser.parse_args() + + +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) + + +dataset = Dataset(args.input, 32) +device = torch.device(args.device) +model = torch.load(args.model) + + +text = predict(device, model, args.text) +print(text) +# mastodon.status_post(text) -- cgit v1.2.3-70-g09d2