aboutsummaryrefslogtreecommitdiff
path: root/bot.py
blob: 86e14f122baaecdc0b31eee422dd62c560fee699 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import sys
from argparse import ArgumentParser
from random import randint, choice
from re import sub

from transformers import AutoTokenizer, AutoModelForCausalLM


# Command-line interface. --backend, --instance and --token are repeatable and
# are paired up positionally: the Nth --backend posts to the Nth --instance
# using the Nth --token.
parser = ArgumentParser()
parser.add_argument('-n', '--input', help='initial input text')
parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix'],
                    action='append',
                    help='server type to post to (repeatable)')
parser.add_argument('-i', '--instance', action='append',
                    help='instance/homeserver URL hosting the bot '
                         '(one per --backend)')
parser.add_argument('-t', '--token', action='append',
                    help='access token for the bot account (one per --backend)')
parser.add_argument('-d', '--data', default='data',
                    help='data for automatic input generation')
parser.add_argument('-m', '--model', default='model',
                    help='path to load saved model')
parser.add_argument('--tokenizer', default='gpt2-large',
                    help='name or path of the tokenizer to load')
parser.add_argument('-y', '--yes', action='store_true',
                    help='answer yes to all prompts')
args = parser.parse_args()

# action='append' leaves the attribute as None when the option is never given;
# normalise to empty lists so later code can iterate them safely.
args.backend = args.backend or []
args.instance = args.instance or []
args.token = args.token or []
# zip() would silently drop unmatched entries, so reject mismatched counts now,
# before the (slow) model load and before the user has picked a post.
if not len(args.backend) == len(args.instance) == len(args.token):
    parser.error('--backend, --instance and --token must each be given '
                 'the same number of times')


tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')


def generate_input():
    # Create random input
    if randint(0, 1) == 0:
        return choice([
            'I am',
            'My life is',
            'Computers are',
            'This is',
            'My',
            'I\'ve',
            'No one',
            'I love',
            'I will die of',
            'I',
            'The',
            'Anime',
            'I\'m going to die',
            'Hello',
            '@ta180m@exozy.me',
            'Life',
            'My favorite',
            'I\'m not',
            'I hate',
            'I think',
            'In my opinion',
            'Breaking news:',
            'Have I ever told you that',
            'I read on the news that',
            'I never knew that',
            'My dream is',
            'It\'s terrible that',
            'My new theory:',
            'My conspiracy theory',
            'The worst thing'
        ])
    else:
        with open(args.data, 'r') as f:
            # Get a line with at least two words
            lines = f.readlines()
            line = choice(lines).split()
            while len(line) < 2:
                line = choice(lines).split()
            
            return line[0] + ' ' + line[1]


if args.input is None:
    args.input = generate_input()


# Generate candidate posts until the user publishes one or quits.
while True:
    # Run the input through the model
    print(args.input)
    inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
    generated = model.generate(inputs, max_length=150, do_sample=True,
                               top_p=0.9)
    # skip_special_tokens keeps markers such as GPT-2's '<|endoftext|>' out of
    # the text that ends up in the post.
    output = tokenizer.decode(generated[0], skip_special_tokens=True)
    print(output)

    # Prepare the post: the first line, plus the second line when the first
    # one is shorter than 200 characters, capped at 500 characters.
    lines = output.split('\n')
    post = lines[0]
    if len(post) < 200 and len(lines) > 1:
        post = '\n'.join(lines[:2])
    post = post[:500]
    # Reduce '@user@host' mentions to '@user' by dropping the second '@...'
    # part (presumably so generated text doesn't notify remote users).
    post = sub(r'(@[^ ]*)@[^ ]*', r'\1', post)

    # Prompt the user. --yes answers "Post" without asking. Only the first
    # letter counts, case-insensitively; anything else (including an empty
    # answer, which used to publish the post) quits.
    if args.yes:
        res = 'p'
    else:
        res = input('Post/Retry/New input/Custom input/Quit: ').strip().lower()[:1]
    if res not in ('p', 'r', 'n', 'c'):
        sys.exit()
    if res == 'p':
        break
    if res == 'n':
        args.input = generate_input()
    elif res == 'c':
        args.input = input('Enter custom input: ')
    # 'r' keeps args.input unchanged and simply samples again.


# Post it!
# Each --backend is paired positionally with an --instance and a --token.
backends = args.backend or []
instances = args.instance or []
tokens = args.token or []
if not len(backends) == len(instances) == len(tokens):
    # zip() would silently skip unmatched targets.
    sys.exit('error: --backend, --instance and --token must each be given '
             'the same number of times')
if not backends:
    print('No backend configured; nothing was posted.')

# The matrix handler below calls sys.exit() after sending, which ends the
# script (assuming simplematrixbotlib does not swallow SystemExit). Post to
# every non-matrix target first so none is skipped; sorted() is stable, so
# the command-line order within each group is preserved.
targets = sorted(zip(backends, instances, tokens),
                 key=lambda target: target[0] == 'matrix')

for backend, instance, token in targets:
    if backend == 'mastodon':
        from mastodon import Mastodon

        mastodon = Mastodon(
            access_token=token,
            api_base_url=instance
        )
        mastodon.status_post(post)

    elif backend == 'misskey':
        # NOTE(review): recent Misskey.py releases are imported as
        # `from misskey import Misskey` (lowercase module) - confirm which
        # library version this targets.
        from Misskey import Misskey

        misskey = Misskey(instance, i=token)
        misskey.notes_create(post)

    elif backend == 'matrix':
        import simplematrixbotlib as botlib

        # The account name is fixed to 'ebooks'; the token is passed as the
        # third Creds argument - TODO confirm it is accepted there.
        creds = botlib.Creds(instance, 'ebooks', token)
        bot = botlib.Bot(creds)

        @bot.listener.on_startup
        async def room_joined(room_id):
            await bot.api.send_text_message(room_id=room_id, message=post)
            # NOTE(review): exits after the first room this handler sees, so
            # at most one room, and at most one matrix target, gets the post.
            sys.exit()

        bot.run()