Remove redundant code from T5 encoder mask creation (huggingface#27216)
* remove redundant code

* update

* add typecasting

* make `attention_mask` float again
pietrolesci authored Nov 2, 2023
1 parent a6c82d4 commit 147e8ce
Showing 2 changed files with 12 additions and 18 deletions.
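Note for readers skimming the diffs below: both files receive the same change. The default cross-attention mask for the decoder is now created in a single place, with an explicit dtype=torch.long, and the default self-attention mask is created after past_key_values is initialized and stays a float tensor. The following standalone sketch is my own illustration of the resulting default-mask flow, not code from the repository; the helper name and signature are hypothetical.

import torch

def default_masks(inputs_embeds, attention_mask, encoder_hidden_states,
                  encoder_attention_mask, is_decoder, mask_seq_length):
    # Hypothetical helper mirroring the mask-creation flow after this commit.
    batch_size = inputs_embeds.shape[0]
    if attention_mask is None:
        # Self-attention mask defaults to a float tensor of ones (one entry per position).
        attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
    if is_decoder and encoder_hidden_states is not None and encoder_attention_mask is None:
        # Cross-attention mask is created once, as an integer tensor, right before inversion.
        enc_batch, enc_len, _ = encoder_hidden_states.size()
        encoder_attention_mask = torch.ones(
            enc_batch, enc_len, device=inputs_embeds.device, dtype=torch.long
        )
    return attention_mask, encoder_attention_mask

# Example usage (shapes are arbitrary):
emb = torch.zeros(2, 7, 512)
enc = torch.zeros(2, 11, 512)
attn, enc_attn = default_masks(emb, None, enc, None, is_decoder=True, mask_seq_length=7)
print(attn.dtype, enc_attn.dtype)  # torch.float32 torch.int64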
15 changes: 6 additions & 9 deletions src/transformers/models/mt5/modeling_mt5.py
@@ -997,18 +997,13 @@ def forward(
             if not self.is_decoder:
                 raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(
-                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
-            )
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
 
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
         # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
@@ -1019,7 +1014,9 @@ def forward(
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                )
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
15 changes: 6 additions & 9 deletions src/transformers/models/t5/modeling_t5.py
@@ -1024,18 +1024,13 @@ def forward(
             if not self.is_decoder:
                 raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
 
-        if attention_mask is None:
-            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
-        if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
-            encoder_seq_length = encoder_hidden_states.shape[1]
-            encoder_attention_mask = torch.ones(
-                batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
-            )
-
         # initialize past_key_values with `None` if past does not exist
         if past_key_values is None:
             past_key_values = [None] * len(self.block)
 
+        if attention_mask is None:
+            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
         # ourselves in which case we just need to make it broadcastable to all heads.
         extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
@@ -1046,7 +1041,9 @@ def forward(
             encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
             encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
             if encoder_attention_mask is None:
-                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
+                encoder_attention_mask = torch.ones(
+                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
+                )
             encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
         else:
             encoder_extended_attention_mask = None
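Why the dtype=torch.long typecast is safe in the second hunk: invert_attention_mask consumes a {0, 1} keep-mask and turns it into an additive attention bias (0.0 where attention is allowed, a large negative value where it is not). Below is a minimal standalone sketch of that convention under the usual additive-mask formulation; it is my own illustration, not the library's implementation.

import torch

# A {0, 1} keep-mask for a batch of 2 sequences of length 5, as an integer tensor.
encoder_attention_mask = torch.ones(2, 5, dtype=torch.long)
encoder_attention_mask[:, -2:] = 0  # pretend the last two positions are padding

# Additive bias: 0.0 where we attend, a very negative number where we do not,
# broadcastable over attention heads and query positions.
bias = (1.0 - encoder_attention_mask[:, None, None, :].float()) * torch.finfo(torch.float32).min
print(bias.shape)  # torch.Size([2, 1, 1, 5])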
