Commit

change causal mask to -INF
smallv0221 committed Jun 9, 2021
1 parent d10e479 commit 1eacd0a
Showing 3 changed files with 63 additions and 67 deletions.
3 changes: 1 addition & 2 deletions examples/language_model/gpt/dataset.py
@@ -369,8 +369,7 @@ def _construct_sample(self, tokens):
         loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
         position_ids = np.arange(0, seq_length, dtype="int64")

-        # Optional mask method: -INF mask value attention_mask = (attention_mask - 1.0) * 1e9
-        # Bool mask of attention
+        attention_mask = (attention_mask - 1.0) * 1e9
         attention_mask = attention_mask.astype("float32")
         return [tokens, loss_mask, attention_mask, position_ids, labels]

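For orientation, a minimal standalone numpy sketch of the masking scheme the new dataset code adopts; it is not the repository's code, and `seq_len` and the lower-triangular layout are illustrative assumptions:

import numpy as np

seq_len = 4
# 0/1 attention mask as built by the dataset: 1 = visible, 0 = blocked.
attention_mask = np.tril(np.ones((seq_len, seq_len)))

# The committed change: map 1 -> 0 and 0 -> -1e9 so the mask can be added
# to the attention logits directly; -1e9 acts as -INF after softmax.
attention_mask = (attention_mask - 1.0) * 1e9
attention_mask = attention_mask.astype("float32")

# Visible positions carry 0.0, blocked positions carry -1e9.
assert attention_mask[0, 0] == 0.0 and attention_mask[0, 1] == -1e9
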
8 changes: 5 additions & 3 deletions examples/language_model/gpt/run_generation.py
@@ -36,7 +36,8 @@ def parse_args():
         '--top_k',
         type=int,
         default=5,
-        help='The number of highest probability vocabulary tokens to keep for top-k sampling.'
+        help=
+        'The number of highest probability vocabulary tokens to keep for top-k sampling.'
     )
     parser.add_argument(
         '--temperature',
@@ -62,7 +63,8 @@ def parse_args():
         '--early_stopping',
         type=eval,
         default=False,
-        help='Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.'
+        help=
+        'Whether to stop the beam search when at least `num_beams` sentences are finished per batch or not.'
     )
     parser.add_argument(
         '--min_dec_len',
@@ -166,5 +168,5 @@ def main(args, input_text):

 if __name__ == "__main__":
     args = parse_args()
-    input_text = '默写古诗: 大漠孤烟直,长河落日圆。\n举杯邀明月,'
+    input_text = '花间一壶酒,独酌无相亲。举杯邀明月,'
     main(args, input_text)
119 changes: 57 additions & 62 deletions paddlenlp/transformers/gpt/modeling.py
@@ -178,7 +178,6 @@ def forward(self,
                 query,
                 key,
                 value,
-                causal_mask=None,
                 attn_mask=None,
                 use_cache=False,
                 cache=None):
@@ -197,10 +196,7 @@
         # scale dot product attention
         product = layers.matmul(
             x=q, y=k, transpose_y=True, alpha=self.head_dim**-0.5)
-        if causal_mask is not None:
-            product = product * causal_mask
-            mask_score = (causal_mask - 1.0) * 10000.0
-            product = product + mask_score
+
         if attn_mask is not None:
             product = product + attn_mask

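With the multiplicative bool-mask branch removed above, masking reaches the scores only as an additive term. A rough numpy sketch of additive masking in scaled dot-product attention, under assumed toy shapes (not the module's tensors):

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

seq_len, head_dim = 4, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((seq_len, head_dim))
k = rng.standard_normal((seq_len, head_dim))

# Scaled dot-product scores, mirroring layers.matmul(..., alpha=head_dim**-0.5).
product = (q @ k.T) * head_dim ** -0.5

# attn_mask is purely additive: 0 where attention is allowed, -1e9 otherwise.
attn_mask = np.triu(np.full((seq_len, seq_len), -1e9), k=1)
product = product + attn_mask

weights = softmax(product)
# Future positions end up with (numerically) zero attention weight.
assert np.allclose(np.triu(weights, k=1), 0.0)
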
@@ -254,7 +250,6 @@ def __init__(self,

     def forward(self,
                 tgt,
-                causal_mask,
                 memory,
                 tgt_mask=None,
                 memory_mask=None,
@@ -272,28 +267,29 @@
         for i, mod in enumerate(self.layers):
             if cache is None:
                 if use_cache:
-                    output, new_cache = mod(output,
-                                            causal_mask,
-                                            memory,
-                                            tgt_mask=tgt_mask,
-                                            use_cache=use_cache,
-                                            cache=cache)
+                    output, new_cache = mod(
+                        output,
+                        memory,
+                        tgt_mask=tgt_mask,
+                        use_cache=use_cache,
+                        cache=cache)
                     new_caches.append(new_cache)
                 else:
-                    output = mod(output,
-                                 causal_mask,
-                                 memory,
-                                 tgt_mask=tgt_mask,
-                                 use_cache=use_cache,
-                                 cache=cache)
+                    output = mod(
+                        output,
+                        causal_mask,
+                        memory,
+                        tgt_mask=tgt_mask,
+                        use_cache=use_cache,
+                        cache=cache)

             else:
-                output, new_cache = mod(output,
-                                        causal_mask,
-                                        memory,
-                                        tgt_mask=tgt_mask,
-                                        use_cache=use_cache,
-                                        cache=cache[i])
+                output, new_cache = mod(
+                    output,
+                    memory,
+                    tgt_mask=tgt_mask,
+                    use_cache=use_cache,
+                    cache=cache[i])
                 new_caches.append(new_cache)
         self.checkpoints.append(output.name)

@@ -382,24 +378,17 @@ def __init__(self,
         self.dropout2 = nn.Dropout(act_dropout, mode="upscale_in_train")
         self.activation = getattr(F, activation)

-    def forward(self,
-                tgt,
-                causal_mask,
-                memory,
-                tgt_mask=None,
-                use_cache=False,
-                cache=None):
+    def forward(self, tgt, memory, tgt_mask=None, use_cache=False, cache=None):
         residual = tgt

         if self.normalize_before:
             tgt = self.norm1(tgt)

         if use_cache is False:
-            tgt = self.self_attn(tgt, tgt, tgt, causal_mask, tgt_mask,
-                                 use_cache, cache)
+            tgt = self.self_attn(tgt, tgt, tgt, tgt_mask, use_cache, cache)
         else:
-            tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, causal_mask,
-                                                    tgt_mask, use_cache, cache)
+            tgt, incremental_cache = self.self_attn(tgt, tgt, tgt, tgt_mask,
+                                                    use_cache, cache)
         tgt = residual + self.dropout1(tgt)
         if not self.normalize_before:
             tgt = self.norm1(tgt)
@@ -408,8 +397,7 @@ def forward(self,
         if self.normalize_before:
             tgt = self.norm2(tgt)
         tgt = self.dropout2(
-            self.linear2(F.gelu(
-                self.linear1(tgt), approximate=True)))
+            self.linear2(F.gelu(self.linear1(tgt), approximate=True)))
         tgt = residual + tgt

         if not self.normalize_before:
@@ -606,9 +594,9 @@ def init_weights(self, layer):
             layer.weight.set_value(
                 paddle.tensor.normal(
                     mean=0.0,
-                    std=self.initializer_range
-                    if hasattr(self, "initializer_range") else
-                    self.gpt.config["initializer_range"],
+                    std=self.initializer_range if hasattr(
+                        self, "initializer_range") else self.gpt.config[
+                            "initializer_range"],
                     shape=layer.weight.shape))


@@ -639,9 +627,10 @@ def __init__(self,
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
         self.register_buffer("bias",
-                             paddle.tensor.tril(
+                             paddle.tensor.triu(
                                  paddle.ones((max_position_embeddings,
-                                              max_position_embeddings))))
+                                              max_position_embeddings)) * -1e9,
+                                 diagonal=1))
         self.topo = topo
         self.hidden_size = hidden_size
         self.vocab_size = vocab_size
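
A small standalone numpy illustration of the new `bias` buffer versus the old one (`n` is an arbitrary toy size, not `max_position_embeddings`; numpy's `k=` plays the role of paddle's `diagonal=`):

import numpy as np

n = 4

# Old buffer: tril(ones) -> 1 where a position may be attended, 0 elsewhere.
old_bias = np.tril(np.ones((n, n)))

# New buffer: triu(ones * -1e9, diagonal=1) -> 0 on and below the diagonal,
# -1e9 strictly above it, so future positions are pushed to ~-INF additively.
new_bias = np.triu(np.ones((n, n)) * -1e9, k=1)

# Both encode the same causal pattern, just in different conventions.
assert np.array_equal(old_bias == 0, new_bias == -1e9)
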
@@ -677,8 +666,8 @@ def __init__(self,
                     topo=topo))

         if self.pipline_mode:
-            Decoder = paddlenlp.ops.guard('gpu:{}'.format(
-                self.topo.pp_info.size - 1))(TransformerDecoder)
+            Decoder = paddlenlp.ops.guard(
+                'gpu:{}'.format(self.topo.pp_info.size - 1))(TransformerDecoder)
         else:
             Decoder = TransformerDecoder

@@ -709,15 +698,18 @@ def forward(self,
                 dtype='int64')
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
-            position_ids = paddle.fluid.layers.expand_as(position_ids,
-                                                         input_ids)
+            position_ids = paddle.fluid.layers.expand_as(
+                position_ids, input_ids)
         embedding_output = self.embeddings(
             input_ids=input_ids, position_ids=position_ids)
         causal_mask = self.bias[:paddle.shape(input_ids)[-1], :paddle.shape(
             input_ids)[-1]]
+        if attention_mask is not None:
+            attention_mask = attention_mask + causal_mask
+        else:
+            attention_mask = causal_mask
         encoder_outputs = self.decoder(
             embedding_output,
-            causal_mask,
             memory=None,
             tgt_mask=attention_mask,
             use_cache=use_cache,
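
Since the decoder call above no longer takes causal_mask separately, the model folds the sliced bias into attention_mask by addition before passing it as tgt_mask. A rough numpy sketch of that combination (the padding-mask shape here is an illustrative assumption, not the model's exact input format):

import numpy as np

max_len, seq_len = 8, 4

# Precomputed causal bias, as registered in __init__ (-1e9 above the diagonal).
bias = np.triu(np.ones((max_len, max_len)) * -1e9, k=1)

# Slice to the current length, mirroring self.bias[:L, :L].
causal_mask = bias[:seq_len, :seq_len]

# Additive padding mask supplied by the caller: 0 for real tokens,
# -1e9 for padded ones (here the last position is padding).
attention_mask = np.zeros((seq_len, seq_len))
attention_mask[:, -1] = -1e9

# Combined mask handed to the decoder as tgt_mask; if no attention_mask is
# given, causal_mask alone is used. Doubled -1e9 entries are still ~-INF.
attention_mask = attention_mask + causal_mask
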
@@ -745,11 +737,12 @@ def forward(self,
                 masked_positions=None,
                 use_cache=False,
                 cache=None):
-        outputs = self.gpt(input_ids,
-                           position_ids=position_ids,
-                           attention_mask=attention_mask,
-                           use_cache=use_cache,
-                           cache=cache)
+        outputs = self.gpt(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            use_cache=use_cache,
+            cache=cache)
         if use_cache:
             encoder_outputs, cached_kvs = outputs[:2]
         else:
@@ -805,11 +798,12 @@ def model(self,
                 masked_positions=None,
                 use_cache=False,
                 cache=None):
-        outputs = self.gpt(input_ids,
-                           position_ids=position_ids,
-                           attention_mask=attention_mask,
-                           use_cache=use_cache,
-                           cache=cache)
+        outputs = self.gpt(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            use_cache=use_cache,
+            cache=cache)
         if use_cache:
             encoder_outputs, cached_kvs = outputs[:2]
         else:
@@ -871,11 +865,12 @@ def forward(self,
                 attention_mask=None,
                 use_cache=False,
                 cache=None):
-        outputs = self.gpt(input_ids,
-                           position_ids=position_ids,
-                           attention_mask=attention_mask,
-                           use_cache=use_cache,
-                           cache=cache)
+        outputs = self.gpt(
+            input_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask,
+            use_cache=use_cache,
+            cache=cache)

         if use_cache:
             encoder_outputs, cached_kvs = outputs[:2]
