Skip to content

Commit

Permalink
Fix diversity rate bug (PaddlePaddle#1477)
Browse files Browse the repository at this point in the history
* update perf

* fix doc and constrains for FasterGeneration

* update readme

* fix diversity rate bug
  • Loading branch information
smallv0221 authored Dec 16, 2021
1 parent 4e59ce0 commit cc58a23
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/faster/faster_generation/samples/unimo_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def postprocess_response(token_ids, tokenizer):
add_start_token_for_decoding=True,
return_tensors=True,
is_split_into_words=False)
model.eval()

outputs, _ = model.generate(
input_ids=inputs_ids['input_ids'],
token_type_ids=inputs_ids['token_type_ids'],
Expand Down
6 changes: 3 additions & 3 deletions paddlenlp/ops/faster_transformer/transformer/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ def forward(self,
_bos_id=bos_token_id,
_eos_id=eos_token_id,
_max_out_len=max_out_len,
_diversity_rate=diversity_rate,
_diversity_rate=-diversity_rate,
_unk_id=self._unk_id,
_mask_id=self._mask_id,
_temperature=temperature,
Expand Down Expand Up @@ -1625,7 +1625,7 @@ def forward(self,
self.linear_weight, self.linear_bias, self.pos_emb,
decoding_strategy, beam_size, top_k, top_p, self._n_head,
int(self._d_model / self._n_head), self._num_decoder_layers,
bos_token_id, eos_token_id, max_out_len, diversity_rate, rel_len,
bos_token_id, eos_token_id, max_out_len, -diversity_rate, rel_len,
alpha, early_stopping)

ids = finalize(
Expand Down Expand Up @@ -1877,7 +1877,7 @@ def forward(self,
self.linear_bias, self.pos_emb, trg_word, decoding_strategy,
beam_size, top_k, top_p, self._n_head,
int(self._d_model / self._n_head), self._num_decoder_layers,
bos_token_id, eos_token_id, max_out_len, diversity_rate, rel_len,
bos_token_id, eos_token_id, max_out_len, -diversity_rate, rel_len,
alpha, temperature, early_stopping, self._hidden_act)

ids = finalize(
Expand Down

0 comments on commit cc58a23

Please sign in to comment.