Beam search

1.简单实现

import torch

# Beam search decoding.
#
# Relies on names supplied by the surrounding code: `decoder_outputs` and
# `decoder_state` (result of a previous decoder step), `decoder`,
# `encoder_output`, `de_vocab`, `max_len` and the output list `pred_sent`.
# NOTE(review): the `[0]` indexing assumes a batch size of 1 -- confirm.
# NOTE(review): scores are presumably log-probabilities that are summed per
# hypothesis without length normalisation -- confirm this is intended.
samples = []
topk = 10
# Loop-invariant: look the end-of-sentence id up once instead of per hypothesis.
eos_idx = de_vocab.item2index['_EOS_']

# Seed the beam with the top-k tokens of the already computed decoder step.
# Each hypothesis is [token id list, accumulated score, decoder state].
log_prob, v_idx = decoder_outputs.detach().topk(topk)
for k in range(topk):
    samples.append([[v_idx[0][k].item()], log_prob[0][k], decoder_state])

for i in range(max_len):
    # Once every hypothesis has emitted _EOS_, another step would only copy and
    # re-sort the (already sorted) beam, so stop early.
    if all(sample[0][-1] == eos_idx for sample in samples):
        break

    new_samples = []
    for sample in samples:
        v_list, score, decoder_state = sample
        if v_list[-1] == eos_idx:
            # Finished hypotheses are carried over unchanged and keep competing.
            new_samples.append([v_list, score, decoder_state])
            continue

        # Feed the last generated token back into the decoder.
        decoder_inputs = torch.LongTensor([v_list[-1]])
        decoder_outputs, new_states = decoder(decoder_inputs, encoder_output, decoder_state)
        # `.detach()` (rather than `.data`) matches the seeding step above.
        log_prob, v_idx = decoder_outputs.detach().topk(topk)

        # Expand this hypothesis with each of its top-k continuations.
        for k in range(topk):
            new_samples.append(
                [v_list + [v_idx[0][k].item()], score + log_prob[0][k], new_states]
            )

    # Keep only the top-k highest scoring hypotheses for the next step.
    new_samples.sort(key=lambda sample: sample[1], reverse=True)
    samples = new_samples[:topk]

# Emit the tokens of the best hypothesis.
v_list, score, states = samples[0]
for token_id in v_list:
    pred_sent.append(de_vocab.index2item[token_id])

2.transformers实现

0