aboutsummaryrefslogtreecommitdiff
path: root/bot.py
blob: d285471ac726f5886a4426934fffb6bd7d52eba3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from argparse import ArgumentParser
from random import randint, choice

from torch import float16
from transformers import AutoTokenizer, AutoModelForCausalLM


# Command-line options. --instance and --token are handed to whichever posting
# backend is selected with --backend, so their help must not be Mastodon-only.
parser = ArgumentParser()
parser.add_argument('-b', '--backend', choices=['mastodon', 'misskey', 'matrix', 'none'], default='mastodon',
                    help="where to post the generated text ('none' only prints it)")
parser.add_argument('-i', '--instance',
                    help='server hosting the bot account '
                         '(Mastodon/Misskey instance or Matrix homeserver)')
parser.add_argument('-t', '--token',
                    help='credential for the bot account (Mastodon/Misskey access token; '
                         'for Matrix, the password of the "ebooks" user)')
parser.add_argument('-n', '--input',
                    help='initial input text (default: chosen at random)')
parser.add_argument('-d', '--data', default='data',
                    help='text file that random prompts are sometimes drawn from '
                         'when --input is not given')
parser.add_argument('-m', '--model', default='model',
                    help='path to load saved model')
args = parser.parse_args()


# Load the tokenizer and the fine-tuned model, then move the model to the GPU.
# NOTE(review): the tokenizer is always the stock 'gpt2-medium' one, which
# assumes the model saved at args.model was trained from gpt2-medium — confirm.
# NOTE(review): `float16` is imported at the top of the file but never used;
# perhaps torch_dtype=float16 was intended here — confirm before changing.
TOKENIZER_NAME = 'gpt2-medium'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME)
model = AutoModelForCausalLM.from_pretrained(
    args.model,
    low_cpu_mem_usage=True,
)
model = model.to('cuda')


def _strip_mention_domain(word):
    """Return *word* cut off just before its second '@'.

    A fully-qualified handle such as '@user@example.social' becomes '@user';
    words containing fewer than two '@' characters are returned unchanged.
    """
    if word.count('@') > 1:
        return '@'.join(word.split('@')[:2])
    return word


if args.input is None:
    # No prompt was supplied: with equal probability, use either a canned
    # opener or the first two words of a random line from the data file.
    if randint(0, 1) == 0:
        args.input = choice([
            'I am',
            'My life is',
            'Computers are',
            'This is',
            'My',
            'I\'ve',
            'No one',
            'I love',
            'I will die of',
            'I',
            'The',
            'Anime',
            'I\'m going to die',
            'Hello',
            '@ta180m@exozy.me',
            'Life',
            'My favorite',
            'I\'m not',
            'I hate',
            'I think',
            'In my opinion',
            'Breaking news:',
            'Have I ever told you that',
            'I read on the news that',
            'I never knew that',
            'My dream is',
            'It\'s terrible that'
        ])
    else:
        # NOTE(review): assumes the data file is UTF-8 encoded — confirm.
        with open(args.data, 'r', encoding='utf-8') as f:
            # Keep only lines with at least two words. Choosing uniformly from
            # this list gives the same distribution as redrawing random lines
            # until one fits, but it cannot spin forever when none qualify.
            candidates = [words for words in (line.split() for line in f)
                          if len(words) >= 2]
        if not candidates:
            raise SystemExit(
                f'{args.data}: no line with at least two words to build a prompt from')
        first, second = choice(candidates)[:2]
        # Strip the server part of any fully-qualified mention before prompting.
        args.input = _strip_mention_domain(first) + ' ' + _strip_mention_domain(second)


# Run the input through the model
print(args.input)
# Put the prompt on whatever device the model was loaded onto, instead of
# hard-coding 'cuda' a second time.
inputs = tokenizer.encode(args.input, return_tensors='pt').to(model.device)
generated = model.generate(
    inputs,
    max_length=150,  # total length in tokens, prompt included
    do_sample=True,  # sample rather than decode greedily ...
    top_p=0.9,       # ... from the smallest token set covering 90% probability
)
# generate() returns a batch; decode its only sequence. skip_special_tokens
# keeps markers such as GPT-2's '<|endoftext|>' out of the text that gets posted.
output = tokenizer.decode(generated[0], skip_special_tokens=True)
print(output)


# Prepare the post.
# Always keep the first generated line. If it is shorter than 200 characters,
# also keep the second line (when there is one). Cap the result at 500 characters.
output = output.split('\n')
line_count = 2 if len(output[0]) < 200 else 1
post = '\n'.join(output[:line_count])[:500]


# Post it!
# Each publisher imports its client library lazily, so only the library for
# the selected backend needs to be installed.


def _publish_mastodon(text):
    """Post *text* as a Mastodon status using --token and --instance."""
    from mastodon import Mastodon

    client = Mastodon(
        access_token=args.token,
        api_base_url=args.instance
    )
    client.status_post(text)


def _publish_misskey(text):
    """Post *text* as a Misskey note using --instance and --token."""
    # NOTE(review): recent Misskey.py releases are imported as `misskey`
    # (lowercase); confirm this module name matches the installed version.
    from Misskey import Misskey

    client = Misskey(args.instance, i=args.token)
    client.notes_create(text)


def _publish_matrix(text):
    """Log in to --instance as 'ebooks' (password --token) and send *text*."""
    import simplematrixbotlib as botlib

    matrix_bot = botlib.Bot(botlib.Creds(args.instance, 'ebooks', args.token))

    # The listener is handed a room_id; presumably it runs once per room the
    # account has joined — confirm against simplematrixbotlib.
    @matrix_bot.listener.on_startup
    async def room_joined(room_id):
        await matrix_bot.api.send_text_message(room_id=room_id, message=text)

    # NOTE(review): run() appears to start a long-running event loop, so the
    # script may not exit after posting on Matrix — confirm.
    matrix_bot.run()


_PUBLISHERS = {
    'mastodon': _publish_mastodon,
    'misskey': _publish_misskey,
    'matrix': _publish_matrix,
}

# The 'none' backend has no publisher: the generated text is only printed.
_publish = _PUBLISHERS.get(args.backend)
if _publish is not None:
    _publish(post)