Skip to content

Commit

Permalink
Merge pull request keon#24 from keon/fix
Browse files Browse the repository at this point in the history
fix attention
  • Loading branch information
keon authored Dec 26, 2019
2 parents 65ebd8e + 9cb885b commit 1ad7c09
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ This implementation relies on [torchtext](https://github.com/pytorch/text) to mi

download tokenizers by doing so:
```
sudo python3 -m spacy download de
sudo python3 -m spacy download en
python -m spacy download de
python -m spacy download en
```


Expand Down
4 changes: 2 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,11 @@ def forward(self, hidden, encoder_outputs):
h = hidden.repeat(timestep, 1, 1).transpose(0, 1)
encoder_outputs = encoder_outputs.transpose(0, 1) # [B*T*H]
attn_energies = self.score(h, encoder_outputs)
return F.relu(attn_energies, dim=1).unsqueeze(1)
return F.softmax(attn_energies, dim=1).unsqueeze(1)

def score(self, hidden, encoder_outputs):
# [B*T*2H]->[B*T*H]
energy = F.softmax(self.attn(torch.cat([hidden, encoder_outputs], 2)))
energy = F.relu(self.attn(torch.cat([hidden, encoder_outputs], 2)))
energy = energy.transpose(1, 2) # [B*H*T]
v = self.v.repeat(encoder_outputs.size(0), 1).unsqueeze(1) # [B*1*H]
energy = torch.bmm(v, energy) # [B*1*T]
Expand Down

0 comments on commit 1ad7c09

Please sign in to comment.