aboutsummaryrefslogtreecommitdiff
path: root/bot.py
diff options
context:
space:
mode:
Diffstat (limited to 'bot.py')
-rw-r--r--bot.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/bot.py b/bot.py
index 5bcac18..d285471 100644
--- a/bot.py
+++ b/bot.py
@@ -18,8 +18,8 @@ parser.add_argument('-m', '--model', default='model',
args = parser.parse_args()
-tokenizer = AutoTokenizer.from_pretrained('gpt2-large')
-model = AutoModelForCausalLM.from_pretrained(args.model).to('cuda')
+tokenizer = AutoTokenizer.from_pretrained('gpt2-medium')
+model = AutoModelForCausalLM.from_pretrained(args.model, low_cpu_mem_usage=True).to('cuda')
if args.input is None:
@@ -74,7 +74,7 @@ if args.input is None:
print(args.input)
inputs = tokenizer.encode(args.input, return_tensors='pt').to('cuda')
output = tokenizer.decode(model.generate(
- inputs, do_sample=True, max_length=150, top_p=0.9)[0])
+ inputs, max_length=150, do_sample=True, top_p=0.9)[0])
print(output)