diff options
-rw-r--r-- | bot.py | 103 |
1 file changed, 60 insertions, 43 deletions
@@ -6,15 +6,19 @@ from transformers import AutoTokenizer, AutoModelForCausalLM parser = ArgumentParser() -parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix', 'none'], default='mastodon', - help='fediverse server type') -parser.add_argument('-i', '--instance', help='Mastodon instance hosting the bot') -parser.add_argument('-t', '--token', help='Mastodon application access token') parser.add_argument('-n', '--input', help='initial input text') +parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix'], + action='append', help='fediverse server type') +parser.add_argument('-i', '--instance', action='append', + help='Mastodon instance hosting the bot') +parser.add_argument('-t', '--token', action='append', + help='Mastodon application access token') parser.add_argument('-d', '--data', default='data', help='data for automatic input generation') parser.add_argument('-m', '--model', default='model', help='path to load saved model') +parser.add_argument('-y', '--yes', action='store_true', + help='answer yes to all prompts') args = parser.parse_args() @@ -22,10 +26,10 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2-large') model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda') -if args.input is None: +def generate_input(): # Create random input if randint(0, 1) == 0: - args.input = choice([ + return choice([ 'I am', 'My life is', 'Computers are', @@ -65,58 +69,71 @@ if args.input is None: while len(line) < 2: line = choice(lines).split() - args.input = line[0] + ' ' + line[1] + return line[0] + ' ' + line[1] -# Run the input through the model -print(args.input) -inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda') -output = tokenizer.decode(model.generate( - inputs, max_length=150, do_sample=True, top_p=0.9)[0]) -print(output) +if args.input is None: + args.input = generate_input() -# Prepare the post -output = output.split('\n') -post = output[0] -if len(post) < 200 and len(output) > 1: - 
post = output[0] + '\n' + output[1] -post = post[:500] +# Loop until we're satisfied +while True: + # Run the input through the model + print(args.input) + inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda') + output = tokenizer.decode(model.generate( + inputs, max_length=150, do_sample=True, top_p=0.9)[0]) + print(output) -# Remove mentions -post = sub('(@[^ ]*)@[^ ]*', '\\1', post) + # Prepare the post + output = output.split('\n') + post = output[0] + if len(post) < 200 and len(output) > 1: + post = output[0] + '\n' + output[1] + post = post[:500] + # Remove mentions + post = sub('(@[^ ]*)@[^ ]*', '\\1', post) -# Quit if no instance specified -if args.instance is None: - quit() + # Prompt the user + res = input('Post/Retry/New input/Custom input/Quit: ') + if res not in 'prnPRNcC': + quit() + if res in 'pP': + break + if res in 'nN': + args.input = generate_input() + if res in 'cC': + args.input = input('Enter custom input: ') # Post it! -if args.backend == 'mastodon': - from mastodon import Mastodon +for backend, instance, token in zip(args.backend, args.instance, args.token): + if backend == 'mastodon': + from mastodon import Mastodon - mastodon = Mastodon( - access_token=args.token, - api_base_url=args.instance - ) - mastodon.status_post(post) + mastodon = Mastodon( + access_token=token, + api_base_url=instance + ) + mastodon.status_post(post) -elif args.backend == 'misskey': - from Misskey import Misskey + elif backend == 'misskey': + from Misskey import Misskey - misskey = Misskey(args.instance, i=args.token) - misskey.notes_create(post) + misskey = Misskey(instance, i=token) + misskey.notes_create(post) -elif args.backend == 'matrix': - import simplematrixbotlib as botlib + elif backend == 'matrix': + import simplematrixbotlib as botlib - creds = botlib.Creds(args.instance, 'ebooks', args.token) - bot = botlib.Bot(creds) + creds = botlib.Creds(instance, 'ebooks', token) + bot = botlib.Bot(creds) - @bot.listener.on_startup - async def 
room_joined(room_id): - await bot.api.send_text_message(room_id=room_id, message=post) + @bot.listener.on_startup + async def room_joined(room_id): + await bot.api.send_text_message(room_id=room_id, message=post) + quit() - bot.run() + bot.run() |