1. 简单实现（hand-written beam search，简单实现版本）
import torch

# Beam search decoding: instead of greedily taking the single best token at
# each step, keep the `topk` highest-scoring partial hypotheses and expand
# each of them, pruning back to the beam width after every step.
topk = 10

# Seed the beam from the first decoder output: one hypothesis per top-scoring
# initial token. Each hypothesis is [token_ids, cumulative_log_prob, state].
log_prob, v_idx = decoder_outputs.detach().topk(topk)
samples = [
    [[v_idx[0][k].item()], log_prob[0][k], decoder_state]
    for k in range(topk)
]

for _ in range(max_len):
    expanded = []
    for v_list, score, decoder_state in samples:
        # A hypothesis that already ends in EOS is finished: carry it over
        # unchanged so it keeps competing on its final score.
        if v_list[-1] == de_vocab.item2index['_EOS_']:
            expanded.append([v_list, score, decoder_state])
            continue
        # Feed the hypothesis' last token back into the decoder and branch
        # on the topk most likely next tokens.
        decoder_inputs = torch.LongTensor([v_list[-1]])
        decoder_outputs, new_states = decoder(decoder_inputs, encoder_output, decoder_state)
        log_prob, v_idx = decoder_outputs.data.topk(topk)
        for k in range(topk):
            extended = v_list + [v_idx[0][k].item()]
            expanded.append([extended, score + log_prob[0][k], new_states])
    # Prune: keep only the `topk` candidates with the best cumulative score.
    expanded.sort(key=lambda cand: cand[1], reverse=True)
    samples = expanded[:topk]

# The best surviving hypothesis is the prediction; map its token ids back to
# vocabulary items.
v_list, score, states = samples[0]
for token_id in v_list:
    pred_sent.append(de_vocab.index2item[token_id])
2. transformers 实现（使用 Hugging Face transformers 库的 beam search）
0