Skip to content

Commit

Permalink
[W2V2 with LM] Fix decoder test with params (huggingface#21277)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanchit-gandhi authored Jan 24, 2023
1 parent 94a7edd commit 14d058b
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,6 @@ def test_decoder_batch(self, pool_context):
self.assertListEqual(logit_scores_decoder, decoded_processor.logit_score)
self.assertListEqual(lm_scores_decoder, decoded_processor.lm_score)

@unittest.skip("Fix me Sanchit")
def test_decoder_with_params(self):
feature_extractor = self.get_feature_extractor()
tokenizer = self.get_tokenizer()
Expand All @@ -240,7 +239,7 @@ def test_decoder_with_params(self):

logits = self._get_dummy_logits()

beam_width = 20
beam_width = 15
beam_prune_logp = -20.0
token_min_logp = -4.0

Expand All @@ -264,9 +263,17 @@ def test_decoder_with_params(self):
)

decoded_decoder = [d[0][0] for d in decoded_decoder_out]
logit_scores = [d[0][2] for d in decoded_decoder_out]
lm_scores = [d[0][3] for d in decoded_decoder_out]

self.assertListEqual(decoded_decoder, decoded_processor)
self.assertListEqual(["<s> </s> </s>", "<s> <s> </s>"], decoded_processor)
self.assertListEqual(["</s> <s> <s>", "<s> <s> <s>"], decoded_processor)

self.assertTrue(np.array_equal(logit_scores, decoded_processor_out.logit_score))
self.assertTrue(np.allclose([-20.054, -18.447], logit_scores, atol=1e-3))

self.assertTrue(np.array_equal(lm_scores, decoded_processor_out.lm_score))
self.assertTrue(np.allclose([-15.554, -13.9474], lm_scores, atol=1e-3))

def test_decoder_with_params_of_lm(self):
feature_extractor = self.get_feature_extractor()
Expand Down

0 comments on commit 14d058b

Please sign in to comment.