Add GPNRoFormer
gonzalobenegas committed Mar 7, 2023
1 parent 4361a6c commit f4c86bd
Showing 2 changed files with 126 additions and 67 deletions.
173 changes: 106 additions & 67 deletions gpn/model.py
@@ -6,7 +6,10 @@
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutput
from typing import Optional, Tuple, Union

from .modules import TransposeLayer, ConvLayer, OneHotEmbedding, get_dilation_schedule
from .modules import (
    TransposeLayer, ConvLayer, OneHotEmbedding, get_dilation_schedule,
    GPNEmbedding,
)


class ConvNetConfig(PretrainedConfig):
@@ -136,69 +139,105 @@ def forward(self, input_ids=None, labels=None, loss_weight=None, **kwargs):
AutoModel.register(ConvNetConfig, ConvNetModel)
AutoModelForMaskedLM.register(ConvNetConfig, ConvNetForMaskedLM)

from transformers import BertForMaskedLM, RoFormerForMaskedLM


# Patch RoFormerForMaskedLM.forward to support an optional per-token loss_weight
# for the masked language modeling loss.
def RoFormerForMaskedLM_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.FloatTensor] = None,
    token_type_ids: Optional[torch.LongTensor] = None,
    head_mask: Optional[torch.FloatTensor] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    encoder_hidden_states: Optional[torch.FloatTensor] = None,
    encoder_attention_mask: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    loss_weight=None,
) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
        config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
        the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
    """
    return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    outputs = self.roformer(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        head_mask=head_mask,
        inputs_embeds=inputs_embeds,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )

    sequence_output = outputs[0]
    prediction_scores = self.cls(sequence_output)

    masked_lm_loss = None
    if labels is not None:
        logits = prediction_scores
        loss_fct = CrossEntropyLoss(reduction="none")
        labels = labels.view(-1)
        loss = loss_fct(logits.view(-1, self.config.vocab_size), labels)
        # Zero out the weight of ignored (-100) positions and normalize by the
        # total weight, so the result is a weighted mean over the remaining positions.
        loss_weight = loss_weight.view(-1)
        loss_weight[labels == -100] = 0.0
        loss = (loss * loss_weight / loss_weight.sum()).sum()
        masked_lm_loss = loss

    if not return_dict:
        output = (prediction_scores,) + outputs[1:]
        return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

    return MaskedLMOutput(
        loss=masked_lm_loss,
        logits=prediction_scores,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )

RoFormerForMaskedLM.forward = RoFormerForMaskedLM_forward
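# Editor's aside (illustration only, not part of gpn/model.py): the weighted loss above computes
#     sum_i(ce_i * w_i) / sum_i(w_i),
# i.e. a weighted mean of the per-token cross-entropies, with w_i forced to 0 wherever
# labels == -100. For example, per-token losses [2.0, 4.0] with weights [1.0, 3.0] give
# (2.0*1.0 + 4.0*3.0) / (1.0 + 3.0) = 3.5.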

from transformers import RoFormerConfig, RoFormerModel, RoFormerForMaskedLM
from transformers.models.roformer.modeling_roformer import RoFormerEncoder, RoFormerOnlyMLMHead, RoFormerSinusoidalPositionalEmbedding


class GPNRoFormerConfig(RoFormerConfig):
    model_type = "GPNRoFormer"

    def __init__(
        self,
        n_aux_features=0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.n_aux_features = n_aux_features


class GPNRoFormerPreTrainedModel(PreTrainedModel):
    config_class = GPNRoFormerConfig
    base_model_prefix = "model"

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, RoFormerSinusoidalPositionalEmbedding):
            pass
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, RoFormerEncoder):
            module.gradient_checkpointing = value


class GPNRoFormerModel(GPNRoFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.embedding = GPNEmbedding(
            config.vocab_size, config.n_aux_features, config.hidden_size,
        )
        self.encoder = RoFormerEncoder(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(self, input_ids=None, aux_features=None):
        x = self.embedding(input_ids, aux_features=aux_features)
        x = self.encoder(x)
        return x


class GPNRoFormerForMaskedLM(GPNRoFormerPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.model = GPNRoFormerModel(config)
        self.cls = RoFormerOnlyMLMHead(config)

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        labels=None,
        loss_weight=None,
        **kwargs
    ):
        hidden_state = self.model(**kwargs).last_hidden_state
        logits = self.cls(hidden_state)
        loss = None
        if labels is not None and loss_weight is None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.config.vocab_size), labels.view(-1)
            )
        if labels is not None and loss_weight is not None:
            loss_fct = CrossEntropyLoss(reduction="none")
            labels = labels.view(-1)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels)
            loss_weight = loss_weight.view(-1)
            loss_weight[labels == -100] = 0.0
            loss = (loss * loss_weight / loss_weight.sum()).sum()
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
        )


AutoConfig.register("GPNRoFormer", GPNRoFormerConfig)
AutoModel.register(GPNRoFormerConfig, GPNRoFormerModel)
AutoModelForMaskedLM.register(GPNRoFormerConfig, GPNRoFormerForMaskedLM)
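
For reference, a minimal usage sketch of the classes added above (not part of this commit). The hyperparameters, sequence length, and masking pattern are illustrative assumptions; thanks to the registrations, the model could equally be built with AutoModelForMaskedLM.from_config(config).

# Usage sketch; all sizes below are made-up example values.
import torch
from gpn.model import GPNRoFormerConfig, GPNRoFormerForMaskedLM

config = GPNRoFormerConfig(
    vocab_size=6,              # e.g. A/C/G/T plus special tokens (assumption)
    n_aux_features=0,
    hidden_size=512,
    num_hidden_layers=8,
    num_attention_heads=8,
    intermediate_size=2048,
)
model = GPNRoFormerForMaskedLM(config)

batch_size, seq_len = 2, 512
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
labels = torch.full((batch_size, seq_len), -100)  # -100 = position ignored by the loss
labels[:, 100] = input_ids[:, 100]                # pretend a single position is masked
loss_weight = torch.ones(batch_size, seq_len)     # per-position weights for the weighted loss

output = model(input_ids=input_ids, labels=labels, loss_weight=loss_weight)
print(output.loss, output.logits.shape)           # logits: (batch_size, seq_len, vocab_size)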
20 changes: 20 additions & 0 deletions gpn/modules.py
@@ -58,6 +58,26 @@ def forward(self, x):
        return F.one_hot(x, num_classes=self.hidden_size).float()


class GPNEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size=None,
        n_aux_features=None,
        hidden_size=None,
    ):
        super().__init__()
        assert vocab_size + n_aux_features <= hidden_size
        self.vocab_size = vocab_size
        self.n_aux_features = n_aux_features
        self.hidden_size = hidden_size

    def forward(self, input_ids, aux_features=None):
        res = F.one_hot(input_ids, num_classes=self.hidden_size).float()
        if aux_features is not None:
            res[:, :, self.vocab_size:self.vocab_size + self.n_aux_features] = aux_features
        return res


def get_dilation_schedule(config):
    return [
        min(config.dilation_max, 2**((i%config.dilation_cycle)//config.dilation_double_every))
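
For reference, a small sketch (not part of this commit) of the embedding layout GPNEmbedding produces: token ids become a one-hot over the first vocab_size channels of the hidden dimension, aux_features fill the next n_aux_features channels, and the remaining channels stay zero. The sizes and the interpretation of aux_features below are assumptions for illustration.

# Illustration only; sizes are made-up example values.
import torch
from gpn.modules import GPNEmbedding

emb = GPNEmbedding(vocab_size=6, n_aux_features=2, hidden_size=16)
input_ids = torch.tensor([[0, 3, 5]])    # (batch=1, seq_len=3)
aux_features = torch.rand(1, 3, 2)       # e.g. per-position auxiliary scores (assumption)

x = emb(input_ids, aux_features=aux_features)
print(x.shape)        # torch.Size([1, 3, 16])
print(x[0, 1, :6])    # one-hot of token id 3 in the first vocab_size channels
print(x[0, 1, 6:8])   # the aux features for that position; remaining channels are zero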
