aboutsummaryrefslogtreecommitdiff
path: root/bot.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot.py')
-rw-r--r--bot.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/bot.py b/bot.py
index f746e26..c3d5274 100644
--- a/bot.py
+++ b/bot.py
@@ -1,6 +1,7 @@
from argparse import ArgumentParser
from random import randint, choice
+from torch import float16
from transformers import AutoTokenizer, AutoModelForCausalLM
@@ -17,8 +18,8 @@ parser.add_argument('-m', '--model', default='model',
args = parser.parse_args()
-tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
-model = AutoModelForCausalLM.from_pretrained(args.model)
+tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
+model = AutoModelForCausalLM.from_pretrained(args.model, torch_dtype=float16).to('cuda')
if args.input is None:
@@ -71,7 +72,7 @@ if args.input is None:
# Run the input through the model
print(args.input)
-inputs = tokenizer.encode(args.input, return_tensors='pt')
+inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
output = tokenizer.decode(model.generate(
inputs, do_sample=True, max_length=150, top_p=0.9)[0])
print(output)