aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAnthony Wang2022-02-21 16:39:58 -0600
committerAnthony Wang2022-02-21 16:39:58 -0600
commitb240317d761f1a85c4acbaec38003846576d95d9 (patch)
tree481465ab3f919093087c03239c3b513ff057c5f9
parentf2d33b51b1a682be5a6e661541ddc69626bfaba1 (diff)
Take top 5 instead of top 3
-rw-r--r--predict.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/predict.py b/predict.py
index 53447b9..9ef453d 100644
--- a/predict.py
+++ b/predict.py
@@ -2,7 +2,7 @@ import numpy as np
import torch
-def predict(device, dataset, model, text, next_words=100, top_k=3):
+def predict(device, dataset, model, text, next_words=100, top_k=5):
model.eval()
words = text.split()