feat: add 'ZEN' model.
Alexzhuan committed Jun 12, 2021
1 parent 8e1c9c4 commit fa28271
Showing 8 changed files with 169 additions and 53 deletions.
43 changes: 29 additions & 14 deletions baselines/run_cdn.py
@@ -11,18 +11,21 @@
 from cblue.trainer import CDNForCLSTrainer, CDNForNUMTrainer
 from cblue.utils import init_logger, seed_everything
 from cblue.data import CDNDataset, CDNDataProcessor
+from cblue.models import save_zen_model, ZenModel, ZenForSequenceClassification, ZenNgramDict
 
 
 MODEL_CLASS = {
     'bert': (BertTokenizer, BertModel),
     'roberta': (BertTokenizer, BertModel),
-    'albert': (BertTokenizer, AlbertModel)
+    'albert': (BertTokenizer, AlbertModel),
+    'zen': (BertTokenizer, ZenModel)
 }
 
 CLS_MODEL_CLASS = {
     'bert': BertForSequenceClassification,
     'roberta': BertForSequenceClassification,
-    'albert': AlbertForSequenceClassification
+    'albert': AlbertForSequenceClassification,
+    'zen': ZenForSequenceClassification
 }
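For orientation: these two registries are what keep the rest of `run_cdn.py` model-agnostic. `args.model_type` selects a `(tokenizer, encoder)` pair from `MODEL_CLASS` and a classification head from `CLS_MODEL_CLASS`; the `'zen'` entries are the ones this commit adds. A minimal sketch of the dispatch (illustrative, not repository code):

```python
from transformers import BertTokenizer
from cblue.models import ZenModel, ZenForSequenceClassification

# Trimmed registries mirroring the diff above; the real file also maps
# 'bert', 'roberta' and 'albert'.
MODEL_CLASS = {'zen': (BertTokenizer, ZenModel)}
CLS_MODEL_CLASS = {'zen': ZenForSequenceClassification}

model_type = 'zen'  # run_cdn.py reads this from args.model_type
tokenizer_class, model_class = MODEL_CLASS[model_type]   # BertTokenizer, ZenModel
cls_model_class = CLS_MODEL_CLASS[model_type]            # ZenForSequenceClassification
```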


@@ -107,8 +110,13 @@ def main():
     tokenizer_class, model_class = MODEL_CLASS[args.model_type]
 
     if args.do_train:
-        # logger.info('Training CLS model...')
+        logger.info('Training CLS model...')
         tokenizer = tokenizer_class.from_pretrained(os.path.join(args.model_dir, args.model_name))
+
+        ngram_dict = None
+        if args.model_type == 'zen':
+            ngram_dict = ZenNgramDict(os.path.join(args.model_dir, args.model_name), tokenizer=tokenizer)
+
         data_processor = CDNDataProcessor(root=args.data_dir, recall_k=args.recall_k,
                                           negative_sample=args.num_neg)
         train_samples, recall_orig_train_samples, recall_orig_train_samples_scores = data_processor.get_train_sample(dtype='cls', do_augment=args.do_aug)
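The guard above (build the n-gram dictionary only for ZEN, leave it `None` otherwise) recurs verbatim in the predict path below: the n-gram vocabulary ships with the pretrained ZEN checkpoint, so the dict is always built from `args.model_dir/args.model_name`. A hedged sketch of the pattern; `build_ngram_dict` is a hypothetical helper, not repository code:

```python
import os
from cblue.models import ZenNgramDict

def build_ngram_dict(model_type, pretrained_dir, tokenizer):
    """Hypothetical helper: ZEN encoders need the n-gram vocabulary stored
    with the pretrained checkpoint; all other encoders get None."""
    if model_type != 'zen':
        return None
    # Constructor usage matches the diff: positional path, keyword tokenizer.
    return ZenNgramDict(pretrained_dir, tokenizer=tokenizer)

# Usage mirroring the diff:
# ngram_dict = build_ngram_dict(args.model_type,
#                               os.path.join(args.model_dir, args.model_name),
#                               tokenizer)
```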
@@ -122,12 +130,11 @@ def main():
         model = CDNForCLSModel(model_class, encoder_path=os.path.join(args.model_dir, args.model_name),
                                num_labels=data_processor.num_labels_cls)
         cls_model_class = CLS_MODEL_CLASS[args.model_type]
-        # model = cls_model_class.from_pretrained(os.path.join(args.model_dir, args.model_name),
-        #                                         num_labels=data_processor.num_labels_cls)
         trainer = CDNForCLSTrainer(args=args, model=model, data_processor=data_processor,
                                    tokenizer=tokenizer, train_dataset=train_dataset, eval_dataset=eval_dataset,
                                    logger=logger, recall_orig_eval_samples=recall_orig_eval_samples,
-                                   model_class=cls_model_class, recall_orig_eval_samples_scores=recall_orig_train_samples_scores)
+                                   model_class=cls_model_class, recall_orig_eval_samples_scores=recall_orig_train_samples_scores,
+                                   ngram_dict=ngram_dict)
 
         global_step, best_step = trainer.train()

@@ -138,10 +145,12 @@ def main():
             torch.save(model.state_dict(), os.path.join(args.output_dir, 'pytorch_model_cls.pt'))
             if not os.path.exists(os.path.join(args.output_dir, 'cls')):
                 os.mkdir(os.path.join(args.output_dir, 'cls'))
-            # if args.model_type == 'zen':
-            #     save_zen_model(os.path.join(args.output_dir, 'er'), model.encoder, tokenizer, ngram_dict, args)
-            # else:
-            model.encoder.save_pretrained(os.path.join(args.output_dir, 'cls'))
+
+            if args.model_type == 'zen':
+                save_zen_model(os.path.join(args.output_dir, 'cls'), model.encoder, tokenizer, ngram_dict, args)
+            else:
+                model.encoder.save_pretrained(os.path.join(args.output_dir, 'cls'))
 
             tokenizer.save_vocabulary(save_directory=os.path.join(args.output_dir, 'cls'))
             logger.info('Saving models checkpoint to %s', os.path.join(args.output_dir, 'cls'))
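Saving is now symmetric with loading: `save_zen_model` keeps the n-gram dictionary next to the encoder weights, while non-ZEN encoders keep using Hugging Face's `save_pretrained`. (Note the commit also corrects the commented-out draft's `'er'` directory to `'cls'`.) A sketch of the same logic as a helper; the wrapper function itself is illustrative only:

```python
import os
from cblue.models import save_zen_model

def save_cls_encoder(args, model, tokenizer, ngram_dict):
    """Illustrative wrapper around the save logic in the diff above."""
    save_dir = os.path.join(args.output_dir, 'cls')
    os.makedirs(save_dir, exist_ok=True)
    if args.model_type == 'zen':
        # A plain save_pretrained would drop the n-gram dictionary that the
        # predict path needs, so ZEN checkpoints go through save_zen_model.
        save_zen_model(save_dir, model.encoder, tokenizer, ngram_dict, args)
    else:
        model.encoder.save_pretrained(save_dir)
    tokenizer.save_vocabulary(save_directory=save_dir)
```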

@@ -158,12 +167,17 @@ def main():
                                        num_labels=data_processor.num_labels_num)
         trainer = CDNForNUMTrainer(args=args, model=model, data_processor=data_processor,
                                    tokenizer=tokenizer, train_dataset=train_dataset, eval_dataset=eval_dataset,
-                                   logger=logger, model_class=cls_model_class)
+                                   logger=logger, model_class=cls_model_class, ngram_dict=ngram_dict)
 
         global_step, best_step = trainer.train()
 
     if args.do_predict:
         tokenizer = tokenizer_class.from_pretrained(os.path.join(args.output_dir, 'cls'))
+
+        ngram_dict = None
+        if args.model_type == 'zen':
+            ngram_dict = ZenNgramDict(os.path.join(args.model_dir, args.model_name), tokenizer=tokenizer)
+
         data_processor = CDNDataProcessor(root=args.data_dir, recall_k=args.recall_k,
                                           negative_sample=args.num_neg)
         test_samples, recall_orig_test_samples, recall_orig_test_samples_scores = data_processor.get_test_sample(dtype='cls')
@@ -177,7 +191,7 @@ def main():
         trainer = CDNForCLSTrainer(args=args, model=model, data_processor=data_processor,
                                    tokenizer=tokenizer, logger=logger,
                                    recall_orig_eval_samples=recall_orig_test_samples,
-                                   model_class=cls_model_class)
+                                   model_class=cls_model_class, ngram_dict=ngram_dict)
         cls_preds = trainer.predict(test_dataset, model)
 
         # cls_preds = np.load(os.path.join(args.result_output_dir, 'cdn_test_preds.npy'))
@@ -189,8 +203,9 @@ def main():
                                        num_labels=data_processor.num_labels_num)
         trainer = CDNForNUMTrainer(args=args, model=model, data_processor=data_processor,
                                    tokenizer=tokenizer, logger=logger,
-                                   model_class=cls_model_class)
-        trainer.predict(model, test_dataset, orig_texts, cls_preds, recall_orig_test_samples, recall_orig_test_samples_scores)
+                                   model_class=cls_model_class, ngram_dict=ngram_dict)
+        trainer.predict(model, test_dataset, orig_texts, cls_preds, recall_orig_test_samples,
+                        recall_orig_test_samples_scores)
 
 
 if __name__ == '__main__':
Binary file modified cblue/data/__pycache__/data_process.cpython-37.pyc
31 changes: 22 additions & 9 deletions cblue/models/model.py
@@ -77,19 +77,32 @@ class CDNForCLSModel(nn.Module):
     def __init__(self, encoder_class, encoder_path, num_labels):
         super(CDNForCLSModel, self).__init__()
 
-        self.encoder = encoder_class.from_pretrained(encoder_path, output_hidden_states=True)
+        self.encoder = encoder_class.from_pretrained(encoder_path)
         self.num_labels = num_labels
         self.dropout = nn.Dropout(self.encoder.config.hidden_dropout_prob)
         self.classifier = nn.Linear(self.encoder.config.hidden_size, num_labels)
 
-    def forward(self,
-                input_ids=None,
-                attention_mask=None,
-                token_type_ids=None,
-                labels=None,
-                output_hidden_states=None):
-        outputs = self.encoder(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
-                               output_hidden_states=output_hidden_states)
+    def forward(
+            self,
+            input_ids=None,
+            attention_mask=None,
+            token_type_ids=None,
+            labels=None,
+            output_hidden_states=True,
+            input_ngram_ids=None,
+            ngram_position_matrix=None,
+            ngram_token_type_ids=None,
+            ngram_attention_mask=None
+    ):
+        if isinstance(self.encoder, ZenModel):
+            outputs = self.encoder(input_ids=input_ids, input_ngram_ids=input_ngram_ids, ngram_position_matrix=ngram_position_matrix,
+                                   token_type_ids=token_type_ids, attention_mask=attention_mask,
+                                   ngram_attention_mask=ngram_attention_mask, ngram_token_type_ids=ngram_token_type_ids,
+                                   output_all_encoded_layers=output_hidden_states)
+        else:
+            outputs = self.encoder(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask,
+                                   output_hidden_states=output_hidden_states)
 
         # batch, seq, hidden
         last_hidden_states, first_hidden_states = outputs[0], outputs[2][0]
         # batch, hidden
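Why the `isinstance` branch: ZEN's forward takes the extra n-gram tensors and the older `output_all_encoded_layers` flag, while BERT-style encoders take `output_hidden_states`. What matters is that both branches satisfy the indexing contract used two lines later, `outputs[0]` for the last hidden states and `outputs[2][0]` for the earliest ones. A runnable toy check of that contract, with dummy tensors standing in for real encoder outputs (sizes are placeholders):

```python
import torch

batch, seq, hidden, num_layers = 2, 16, 768, 12  # placeholder sizes

# Stand-ins for what either branch returns: HF BertModel with
# output_hidden_states=True yields (last_hidden_state, pooler_output,
# hidden_states); ZenModel (after this commit) yields
# (sequence_output, pooled_output, encoded_layers).
sequence_output = torch.zeros(batch, seq, hidden)
pooled_output = torch.zeros(batch, hidden)
encoded_layers = [torch.zeros(batch, seq, hidden) for _ in range(num_layers)]

outputs = (sequence_output, pooled_output, encoded_layers)

# The indexing CDNForCLSModel.forward relies on for both branches:
last_hidden_states, first_hidden_states = outputs[0], outputs[2][0]
assert last_hidden_states.shape == (batch, seq, hidden)
assert first_hidden_states.shape == (batch, seq, hidden)
```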
8 changes: 7 additions & 1 deletion cblue/models/zen/data.py
@@ -116,7 +116,13 @@ def convert_examples_to_features_for_tokens(text, max_seq_length=128, tokenizer=
 
     text_list = text
 
-    tokens = tokenizer.tokenize(text_list)
+    tokens = []
+    if isinstance(text_list, list):
+        for i, word in enumerate(text_list):
+            token = tokenizer.tokenize(word)
+            tokens.extend(token)
+    else:
+        tokens = tokenizer.tokenize(text_list)
 
     if len(tokens) >= max_seq_length - 1:
         tokens = tokens[0:(max_seq_length - 2)]
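The converter now accepts either a raw string or a pre-split list of words; in the list case each word is tokenized separately, so word boundaries survive subword tokenization. A self-contained sketch of the same branch (the whitespace tokenizer is a stand-in for the real `BertTokenizer`):

```python
class WhitespaceTokenizer:
    """Stand-in for the real tokenizer, just to make the sketch runnable."""
    def tokenize(self, text):
        return text.split()

def tokenize_text_or_words(tokenizer, text):
    """Mirrors the list/str branch added in the diff."""
    if isinstance(text, list):
        tokens = []
        for word in text:
            tokens.extend(tokenizer.tokenize(word))  # per-word keeps boundaries
        return tokens
    return tokenizer.tokenize(text)

tok = WhitespaceTokenizer()
assert tokenize_text_or_words(tok, 'a b c') == ['a', 'b', 'c']
assert tokenize_text_or_words(tok, ['a b', 'c']) == ['a', 'b', 'c']
```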
14 changes: 7 additions & 7 deletions cblue/models/zen/modeling.py
@@ -921,7 +921,7 @@ def __init__(self, config, output_attentions=False, keep_multihead_output=False)
         self.embeddings = BertEmbeddings(config)
         self.word_embeddings = BertWordEmbeddings(config)
         self.encoder = ZenEncoder(config, output_attentions=output_attentions,
-                                keep_multihead_output=keep_multihead_output)
+                                  keep_multihead_output=keep_multihead_output)
         self.pooler = BertPooler(config)
         self.apply(self.init_bert_weights)

@@ -1010,8 +1010,8 @@ def forward(self, input_ids,
         if not output_all_encoded_layers:
             encoded_layers = encoded_layers[-1]
         if self.output_attentions:
-            return all_attentions, encoded_layers, pooled_output
-        return encoded_layers, pooled_output
+            return all_attentions, sequence_output, pooled_output, encoded_layers
+        return sequence_output, pooled_output, encoded_layers
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
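This is the interface change the rest of the commit revolves around: `ZenModel.forward` previously returned `(encoded_layers, pooled_output)`, collapsing `encoded_layers` to the last layer when `output_all_encoded_layers=False`; it now always returns the final sequence output first and the encoded-layers list last, which is what lets `CDNForCLSModel` index `outputs[2][0]` uniformly. A sketch of calling the new interface; the checkpoint path and tensor shapes (in particular the assumed `(batch, seq, ngram)` layout of `ngram_position_matrix`) are placeholders, and the keyword names are the ones used in `model.py` above:

```python
import torch
from cblue.models import ZenModel

batch, seq, ngram = 2, 16, 8  # placeholder sizes
input_ids = torch.zeros(batch, seq, dtype=torch.long)
token_type_ids = torch.zeros(batch, seq, dtype=torch.long)
attention_mask = torch.ones(batch, seq, dtype=torch.long)
input_ngram_ids = torch.zeros(batch, ngram, dtype=torch.long)
ngram_token_type_ids = torch.zeros(batch, ngram, dtype=torch.long)
ngram_attention_mask = torch.ones(batch, ngram, dtype=torch.long)
ngram_position_matrix = torch.zeros(batch, seq, ngram)  # assumed float alignment matrix

zen_model = ZenModel.from_pretrained('path/to/zen_checkpoint')  # placeholder path
sequence_output, pooled_output, encoded_layers = zen_model(
    input_ids=input_ids, input_ngram_ids=input_ngram_ids,
    ngram_position_matrix=ngram_position_matrix,
    token_type_ids=token_type_ids, attention_mask=attention_mask,
    ngram_attention_mask=ngram_attention_mask,
    ngram_token_type_ids=ngram_token_type_ids,
    output_all_encoded_layers=True)
# encoded_layers: list of (batch, seq, hidden) tensors, one per layer;
# with output_all_encoded_layers=False it collapses to the last layer only.
```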
@@ -1296,9 +1296,9 @@ def forward(self, input_ids, input_ngram_ids, ngram_position_matrix, ngram_atten
                                ngram_attention_mask=ngram_attention_mask, ngram_token_type_ids=ngram_token_type_ids,
                                head_mask=head_mask)
         if self.output_attentions:
-            all_attentions, _, pooled_output = outputs
+            all_attentions, _, pooled_output, _ = outputs
         else:
-            _, pooled_output = outputs
+            _, pooled_output, _ = outputs
         pooled_output = self.dropout(pooled_output)
         logits = self.classifier(pooled_output)
@@ -1368,9 +1368,9 @@ def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=No
                                ngram_attention_mask=ngram_attention_mask, ngram_token_type_ids=ngram_token_type_ids,
                                output_all_encoded_layers=False, head_mask=head_mask)
         if self.output_attentions:
-            all_attentions, sequence_output, _ = outputs
+            all_attentions, sequence_output, _, _ = outputs
         else:
-            sequence_output, _ = outputs
+            sequence_output, _, _ = outputs
 
         sequence_output = self.dropout(sequence_output)
         logits = self.classifier(sequence_output)
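Both heads follow the same mechanical update: one extra `_` for the widened tuple. Any external caller that unpacked the old two- or three-element return must be updated the same way; a hypothetical normalizing helper (not repository code) makes the new convention explicit:

```python
def unpack_zen_outputs(outputs, output_attentions=False):
    """Hypothetical helper normalizing ZenModel's post-commit return
    tuple to a fixed four-element form."""
    if output_attentions:
        all_attentions, sequence_output, pooled_output, encoded_layers = outputs
    else:
        all_attentions = None
        sequence_output, pooled_output, encoded_layers = outputs
    return all_attentions, sequence_output, pooled_output, encoded_layers
```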
