diff options
Diffstat (limited to 'bot.py')
-rw-r--r-- | bot.py | 38 |
1 files changed, 16 insertions, 22 deletions
@@ -1,37 +1,31 @@ from argparse import ArgumentParser -import torch from mastodon import Mastodon - -from dataset import Dataset -from model import Model -from predict import predict +from transformers import AutoTokenizer, AutoModelForCausalLM parser = ArgumentParser() parser.add_argument('-t', '--token', help='Mastodon application access token') -parser.add_argument('-i', '--input', default='data', - help='training data input file') -parser.add_argument('-e', '--text', default='i am', - help='initial text for prediction') -parser.add_argument('-d', '--device', default='cpu', - help='device to run the model with') -parser.add_argument('-m', '--model', default='model.pt', +parser.add_argument('-i', '--input', default='i am', + help='initial input text for prediction') +parser.add_argument('-m', '--model', default='model', help='path to load saved model') args = parser.parse_args() -mastodon = Mastodon( - access_token=args.token, - api_base_url='https://social.exozy.me/' -) +tokenizer = AutoTokenizer.from_pretrained('distilgpt2') +model = AutoModelForCausalLM.from_pretrained(args.model) -dataset = Dataset(args.input, 32) -device = torch.device(args.device) -model = torch.load(args.model) +# Run the input through the model +inputs = tokenizer.encode(args.input, return_tensors="pt") +output = tokenizer.decode(model.generate(inputs, do_sample=True, max_length=25, top_p=0.9, temperature=0.8)[0]) +print(output) -text = predict(device, model, args.text) -print(text) -# mastodon.status_post(text) +# Post it to Mastodon +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) +mastodon.status_post(output) |