
Commit 812ac57
Fix hotflip bug where vocab items were not re-encoded correctly (allenai#4759)

* Fix hotflip bug where vocab items were not re-encoded correctly

* changelog

Co-authored-by: Evan Pete Walsh <[email protected]>
matt-gardner and epwalsh authored Oct 29, 2020
1 parent aeb6d36 commit 812ac57
Showing 3 changed files with 14 additions and 1 deletion.
CHANGELOG.md (2 additions, 0 deletions)
@@ -22,6 +22,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Made it possible to instantiate `TrainerCallback` from config files.
 - Fixed the remaining broken internal links in the API docs.
+- Fixed a bug where Hotflip would crash with a model that had multiple TokenIndexers and the input
+  used rare vocabulary items.
 - Fixed a bug where `BeamSearch` would fail if `max_steps` was equal to 1.
 
 ## [v1.2.0rc1](https://github.com/allenai/allennlp/releases/tag/v1.2.0rc1) - 2020-10-22
allennlp/interpret/attackers/hotflip.py (1 addition, 1 deletion)
@@ -360,7 +360,7 @@ def _first_order_taylor(self, grad: numpy.ndarray, token_idx: torch.Tensor, sign
             # This happens when we've truncated our fake embedding matrix. We need to do a dot
             # product with the word vector of the current token; if that token is out of
             # vocabulary for our truncated matrix, we need to run it through the embedding layer.
-            inputs = self._make_embedder_input([self.vocab.get_token_from_index(token_idx)])
+            inputs = self._make_embedder_input([self.vocab.get_token_from_index(token_idx.item())])
             word_embedding = self.embedding_layer(inputs)[0]
         else:
             word_embedding = torch.nn.functional.embedding(
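The fix itself is a single call added at the lookup site: after a change in the PyTorch API (see the test below), `token_idx` reaches this code as a 0-dim tensor rather than a plain Python int, and a tensor key will not match an integer key in an ordinary dict lookup, which is consistent with the crash the changelog describes for rare vocabulary items. A minimal sketch of that failure mode, using a toy dict as a stand-in rather than AllenNLP's real `Vocabulary` internals:

    import torch

    # Toy stand-in for an index-to-token mapping; illustrative only, not
    # AllenNLP's actual Vocabulary implementation.
    index_to_token = {60: "rare_word"}

    token_idx = torch.tensor(60)  # a 0-dim tensor, as _first_order_taylor now receives it

    # A tensor hashes by object identity, so it never matches the integer key 60.
    print(index_to_token.get(token_idx))     # None; plain [] indexing would raise KeyError
    print(index_to_token[token_idx.item()])  # "rare_word"; .item() extracts a Python int

`.item()` is the standard PyTorch way to pull out the Python scalar wrapped by a 0-dim tensor, which is why the one-line change above resolves the crash.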
tests/interpret/hotflip_test.py (11 additions, 0 deletions)
@@ -1,3 +1,5 @@
+import torch
+
 from allennlp.common.testing import AllenNlpTestCase
 from allennlp.data.token_indexers import TokenCharactersIndexer
 from allennlp.interpret.attackers import Hotflip
@@ -49,3 +51,12 @@ def test_with_token_characters_indexer(self):
         assert len(attack["final"][0]) == len(
             attack["original"]
         ) # hotflip replaces words without removing
+
+        # This checks for a bug that arose with a change in the pytorch API. We want to be sure we
+        # can handle the case where we have to re-encode a vocab item because we didn't save it in
+        # our fake embedding matrix (see Hotflip docstring for more info).
+        hotflipper = Hotflip(predictor, max_tokens=50)
+        hotflipper.initialize()
+        hotflipper._first_order_taylor(
+            grad=torch.rand((10,)).numpy(), token_idx=torch.tensor(60), sign=1
+        )
