diff options
Diffstat (limited to 'bot_lstm.py')
-rw-r--r-- | bot_lstm.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/bot_lstm.py b/bot_lstm.py new file mode 100644 index 0000000..c2a2530 --- /dev/null +++ b/bot_lstm.py @@ -0,0 +1,37 @@ +from argparse import ArgumentParser + +import torch +from mastodon import Mastodon + +from dataset import Dataset +from model import Model +from predict import predict + + +parser = ArgumentParser() +parser.add_argument('-t', '--token', help='Mastodon application access token') +parser.add_argument('-i', '--input', default='data', + help='training data input file') +parser.add_argument('-e', '--text', default='i am', + help='initial text for prediction') +parser.add_argument('-d', '--device', default='cpu', + help='device to run the model with') +parser.add_argument('-m', '--model', default='model.pt', + help='path to load saved model') +args = parser.parse_args() + + +mastodon = Mastodon( + access_token=args.token, + api_base_url='https://social.exozy.me/' +) + + +dataset = Dataset(args.input, 32) +device = torch.device(args.device) +model = torch.load(args.model) + + +text = predict(device, model, args.text) +print(text) +# mastodon.status_post(text) |