aboutsummaryrefslogtreecommitdiff
path: root/predict.py
blob: 53447b940d9bf6f449a03ab1b1694050ef9119aa (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import numpy as np
import torch


def predict(device, dataset, model, text, next_words=100, top_k=3):
    """Extend ``text`` word by word using a recurrent ``model``.

    The seed words are fed through the model one at a time to build up its
    recurrent state. Then, ``next_words`` times, the loop does three things:
    it takes the ``top_k`` highest-scoring next-word indices from the latest
    output, picks one of them uniformly at random, and appends that word.
    It then feeds the picked word back in. The scores only select the
    candidate set; they are not used as sampling weights.

    Args:
        device: torch device (or device string) that the state and input
            tensors are placed on.
        dataset: object exposing ``word_to_index`` (word -> index) and
            ``index_to_word`` (index -> word) lookups.
        model: recurrent model exposing ``eval()``, ``zero_state(batch_size)``
            returning ``(h, c)``, and ``model(ix, (h, c)) -> (output, (h, c))``.
        text: whitespace-separated seed words. Each must be a key of
            ``dataset.word_to_index``.
        next_words: number of words to generate.
        top_k: number of highest-scoring candidates to sample from.

    Returns:
        list[str]: the seed words followed by the generated words.

    Raises:
        ValueError: if ``text`` contains no words but ``next_words > 0``.
            Without a seed there is no model output to pick the first word
            from.
        KeyError: if a seed word is not in ``dataset.word_to_index``.
    """
    words = text.split()
    if not words and next_words > 0:
        raise ValueError("text must contain at least one word to seed generation")

    # NOTE: the model is deliberately left in eval mode after this call.
    model.eval()

    # Inference only. Without no_grad, every forward pass (which consumes the
    # previous recurrent state) would keep extending the autograd graph.
    # Memory would then grow with every word processed.
    with torch.no_grad():
        state_h, state_c = model.zero_state(1)
        state_h = state_h.to(device)
        state_c = state_c.to(device)

        # Warm up the recurrent state on the seed words.
        for word in words:
            ix = torch.tensor([[dataset.word_to_index[word]]], device=device)
            output, (state_h, state_c) = model(ix, (state_h, state_c))

        for _ in range(next_words):
            # Assumes output is (seq_len=1, batch=1, vocab). Then output[0] is
            # (1, vocab), topk runs over the vocabulary, and row 0 holds the
            # candidate indices. TODO confirm against the model definition.
            _, top_ix = torch.topk(output[0], k=top_k)
            choice = np.random.choice(top_ix.tolist()[0])

            words.append(dataset.index_to_word[choice])

            ix = torch.tensor([[choice]], device=device)
            output, (state_h, state_c) = model(ix, (state_h, state_c))

    return words