diff options
author | Anthony Wang | 2022-02-22 16:58:19 -0600 |
---|---|---|
committer | Anthony Wang | 2022-02-22 16:58:19 -0600 |
commit | 6bab795fe83d683cc0bfbbc875ccdef70fb4112c (patch) | |
tree | 2cd587e7f3c463adde4322d9f6051acf2cd67674 | |
parent | edd4708123a9a87669893797a61dd42401de21dd (diff) |
Add new transformers bot script and move old one to bot_lstm.py
-rw-r--r-- | bot.py | 38 | ||||
-rw-r--r-- | bot_lstm.py | 37 |
2 files changed, 53 insertions, 22 deletions
@@ -1,37 +1,31 @@ from argparse import ArgumentParser -import torch from mastodon import Mastodon - -from dataset import Dataset -from model import Model -from predict import predict +from transformers import AutoTokenizer, AutoModelForCausalLM parser = ArgumentParser() parser.add_argument('-t', '--token', help='Mastodon application access token') -parser.add_argument('-i', '--input', default='data', - help='training data input file') -parser.add_argument('-e', '--text', default='i am', - help='initial text for prediction') -parser.add_argument('-d', '--device', default='cpu', - help='device to run the model with') -parser.add_argument('-m', '--model', default='model.pt', +parser.add_argument('-i', '--input', default='i am', + help='initial input text for prediction') +parser.add_argument('-m', '--model', default='model', help='path to load saved model') args = parser.parse_args() -mastodon = Mastodon( - access_token=args.token, - api_base_url='https://social.exozy.me/' -) +tokenizer = AutoTokenizer.from_pretrained('distilgpt2') +model = AutoModelForCausalLM.from_pretrained(args.model) -dataset = Dataset(args.input, 32) -device = torch.device(args.device) -model = torch.load(args.model) +# Run the input through the model +inputs = tokenizer.encode(args.input, return_tensors="pt") +output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0]) +print(output) -text = predict(device, model, args.text) -print(text) -# mastodon.status_post(text) +# Post it to Mastodon +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) +mastodon.status_post(output) diff --git a/bot_lstm.py b/bot_lstm.py new file mode 100644 index 0000000..c2a2530 --- /dev/null +++ b/bot_lstm.py @@ -0,0 +1,37 @@ +from argparse import ArgumentParser + +import torch +from mastodon import Mastodon + +from dataset import Dataset +from model import Model +from predict import predict + + +parser = ArgumentParser() +parser.add_argument('-t', '--token', help='Mastodon application access token') +parser.add_argument('-i', '--input', default='data', + help='training data input file') +parser.add_argument('-e', '--text', default='i am', + help='initial text for prediction') +parser.add_argument('-d', '--device', default='cpu', + help='device to run the model with') +parser.add_argument('-m', '--model', default='model.pt', + help='path to load saved model') +args = parser.parse_args() + + +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) + + +dataset = Dataset(args.input, 32) +device = torch.device(args.device) +model = torch.load(args.model) + + +text = predict(device, model, args.text) +print(text) +# mastodon.status_post(text) |