aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Wang2022-07-17 23:10:13 -0500
committerAnthony Wang2022-07-17 23:10:13 -0500
commit36baee49255e60610babcd4bd473fb204797eab9 (patch)
tree76daf7f753746d3922ab9cc74db61e8cb91d9d08
parent37093e5282b401b582ec1fcc9af9dc56e0e4ceef (diff)
Rewrite bot.py to rerun generations and allow editing prompt
-rw-r--r--bot.py103
1 files changed, 60 insertions, 43 deletions
diff --git a/bot.py b/bot.py
index 2ee8e60..86e14f1 100644
--- a/bot.py
+++ b/bot.py
@@ -6,15 +6,19 @@ from transformers import AutoTokenizer, AutoModelForCausalLM
parser = ArgumentParser()
-parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix', 'none'], default='mastodon',
- help='fediverse server type')
-parser.add_argument('-i', '--instance', help='Mastodon instance hosting the bot')
-parser.add_argument('-t', '--token', help='Mastodon application access token')
parser.add_argument('-n', '--input', help='initial input text')
+parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix'],
+ action='append', help='fediverse server type')
+parser.add_argument('-i', '--instance', action='append',
+ help='Mastodon instance hosting the bot')
+parser.add_argument('-t', '--token', action='append',
+ help='Mastodon application access token')
parser.add_argument('-d', '--data', default='data',
help='data for automatic input generation')
parser.add_argument('-m', '--model', default='model',
help='path to load saved model')
+parser.add_argument('-y', '--yes', action='store_true',
+ help='answer yes to all prompts')
args = parser.parse_args()
@@ -22,10 +26,10 @@ tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')
-if args.input is None:
+def generate_input():
# Create random input
if randint(0, 1) == 0:
- args.input = choice([
+ return choice([
'I am',
'My life is',
'Computers are',
@@ -65,58 +69,71 @@ if args.input is None:
while len(line) < 2:
line = choice(lines).split()
- args.input = line[0] + ' ' + line[1]
+ return line[0] + ' ' + line[1]
-# Run the input through the model
-print(args.input)
-inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
-output = tokenizer.decode(model.generate(
- inputs, max_length=150, do_sample=True, top_p=0.9)[0])
-print(output)
+if args.input is None:
+ args.input = generate_input()
-# Prepare the post
-output = output.split('\n')
-post = output[0]
-if len(post) < 200 and len(output) > 1:
- post = output[0] + '\n' + output[1]
-post = post[:500]
+# Loop until we're satisfied
+while True:
+ # Run the input through the model
+ print(args.input)
+ inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
+ output = tokenizer.decode(model.generate(
+ inputs, max_length=150, do_sample=True, top_p=0.9)[0])
+ print(output)
-# Remove mentions
-post = sub('(@[^ ]*)@[^ ]*', '\\1', post)
+ # Prepare the post
+ output = output.split('\n')
+ post = output[0]
+ if len(post) < 200 and len(output) > 1:
+ post = output[0] + '\n' + output[1]
+ post = post[:500]
+ # Remove mentions
+ post = sub('(@[^ ]*)@[^ ]*', '\\1', post)
-# Quit if no instance specified
-if args.instance is None:
- quit()
+ # Prompt the user
+ res = input('Post/Retry/New input/Custom input/Quit: ')
+ if res not in 'prnPRNcC':
+ quit()
+ if res in 'pP':
+ break
+ if res in 'nN':
+ args.input = generate_input()
+ if res in 'cC':
+ args.input = input('Enter custom input: ')
# Post it!
-if args.backend == 'mastodon':
- from mastodon import Mastodon
+for backend, instance, token in zip(args.backend, args.instance, args.token):
+ if backend == 'mastodon':
+ from mastodon import Mastodon
- mastodon = Mastodon(
- access_token=args.token,
- api_base_url=args.instance
- )
- mastodon.status_post(post)
+ mastodon = Mastodon(
+ access_token=token,
+ api_base_url=instance
+ )
+ mastodon.status_post(post)
-elif args.backend == 'misskey':
- from Misskey import Misskey
+ elif backend == 'misskey':
+ from Misskey import Misskey
- misskey = Misskey(args.instance, i=args.token)
- misskey.notes_create(post)
+ misskey = Misskey(instance, i=token)
+ misskey.notes_create(post)
-elif args.backend == 'matrix':
- import simplematrixbotlib as botlib
+ elif backend == 'matrix':
+ import simplematrixbotlib as botlib
- creds = botlib.Creds(args.instance, 'ebooks', args.token)
- bot = botlib.Bot(creds)
+ creds = botlib.Creds(instance, 'ebooks', token)
+ bot = botlib.Bot(creds)
- @bot.listener.on_startup
- async def room_joined(room_id):
- await bot.api.send_text_message(room_id=room_id, message=post)
+ @bot.listener.on_startup
+ async def room_joined(room_id):
+ await bot.api.send_text_message(room_id=room_id, message=post)
+ quit()
- bot.run()
+ bot.run()