predictors.py (forked from dreamgonfly/transformer-pytorch)
"""Beam-search decoding wrapper around a trained Transformer model."""
from beam import Beam
from utils.pad import pad_masking
import torch


class Predictor:

    def __init__(self, preprocess, postprocess, model, checkpoint_filepath, max_length=30, beam_size=8):
        self.preprocess = preprocess
        self.postprocess = postprocess
        self.model = model
        self.max_length = max_length
        self.beam_size = beam_size

        self.model.eval()
        checkpoint = torch.load(checkpoint_filepath, map_location='cpu')
        self.model.load_state_dict(checkpoint)

    def predict_one(self, source, num_candidates=5):
        source_preprocessed = self.preprocess(source)
        # unsqueeze adds a batch dimension of 1: the model expects (batch_size, seq_len)
        source_tensor = torch.tensor(source_preprocessed).unsqueeze(0)
        length_tensor = torch.tensor(len(source_preprocessed)).unsqueeze(0)  # currently unused

        sources_mask = pad_masking(source_tensor, source_tensor.size(1))
        memory_mask = pad_masking(source_tensor, 1)
        memory = self.model.encoder(source_tensor, sources_mask)

        decoder_state = self.model.decoder.init_decoder_state()

        # Repeat the encoder memory beam_size times so that every beam
        # hypothesis attends over the same source representation.
        memory_beam = memory.detach().repeat(self.beam_size, 1, 1)  # (beam_size, seq_len, hidden_size)

        beam = Beam(beam_size=self.beam_size, min_length=0, n_top=num_candidates, ranker=None)

        for _ in range(self.max_length):
            new_inputs = beam.get_current_state().unsqueeze(1)  # (beam_size, seq_len=1)
            decoder_outputs, decoder_state = self.model.decoder(new_inputs, memory_beam,
                                                                memory_mask,
                                                                state=decoder_state)
            # decoder_outputs: (beam_size, target_seq_len=1, vocabulary_size)
            # attentions['std']: (target_seq_len=1, beam_size, source_seq_len)
            attention = self.model.decoder.decoder_layers[-1].memory_attention_layer.sublayer.attention
            beam.advance(decoder_outputs.squeeze(1), attention)

            # Reorder the cached decoder state to follow the surviving beams.
            beam_current_origin = beam.get_current_origin()  # (beam_size, )
            decoder_state.beam_update(beam_current_origin)

            if beam.done():
                break

        scores, ks = beam.sort_finished(minimum=num_candidates)
        hypotheses, attentions = [], []
        for times, k in ks[:num_candidates]:
            hypothesis, attention = beam.get_hypothesis(times, k)
            hypotheses.append(hypothesis)
            attentions.append(attention)

        self.attentions = attentions
        self.hypotheses = [[token.item() for token in h] for h in hypotheses]
        hs = [self.postprocess(h) for h in self.hypotheses]

        return list(reversed(hs))
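

# A minimal usage sketch (not part of the original file). The tokenizer
# helpers, the build_model constructor, and the checkpoint path below are
# assumptions made up for illustration; the real project wires these up
# in its own training/evaluation scripts.
if __name__ == '__main__':
    from models import build_model  # hypothetical import

    def preprocess(text):
        # Hypothetical: map a source sentence to a list of token ids.
        return [int(token) for token in text.split()]

    def postprocess(token_ids):
        # Hypothetical: map predicted token ids back to a string.
        return ' '.join(str(token_id) for token_id in token_ids)

    predictor = Predictor(
        preprocess=preprocess,
        postprocess=postprocess,
        model=build_model(),  # hypothetical constructor
        checkpoint_filepath='checkpoints/best.pth',  # hypothetical path
    )

    # Returns up to num_candidates postprocessed hypotheses, best first.
    for candidate in predictor.predict_one('1 2 3 4', num_candidates=3):
        print(candidate)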