Skip to content

Commit

Permalink
feat: update CDN task.
Browse files: browse the repository at this point in the history
  • Loading branch information
Alexzhuan committed Jun 12, 2021
1 parent fb46ac4 commit 8e1c9c4
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 4 deletions.
5 changes: 3 additions & 2 deletions baselines/run_cdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def main():
global_step, best_step = trainer.train()

if args.do_predict:
tokenizer = tokenizer_class.from_pretrained(args.output_dir)
tokenizer = tokenizer_class.from_pretrained(os.path.join(args.output_dir, 'cls'))
data_processor = CDNDataProcessor(root=args.data_dir, recall_k=args.recall_k,
negative_sample=args.num_neg)
test_samples, recall_orig_test_samples, recall_orig_test_samples_scores = data_processor.get_test_sample(dtype='cls')
Expand All @@ -183,13 +183,14 @@ def main():
# cls_preds = np.load(os.path.join(args.result_output_dir, 'cdn_test_preds.npy'))

test_samples = data_processor.get_test_sample(dtype='num')
orig_texts = data_processor.get_test_orig_text()
test_dataset = CDNDataset(test_samples, data_processor, dtype='num', mode='test')
model = cls_model_class.from_pretrained(os.path.join(args.output_dir, 'num'),
num_labels=data_processor.num_labels_num)
trainer = CDNForNUMTrainer(args=args, model=model, data_processor=data_processor,
tokenizer=tokenizer, logger=logger,
model_class=cls_model_class)
trainer.predict(model, test_dataset, cls_preds, recall_orig_test_samples, recall_orig_test_samples_scores)
trainer.predict(model, test_dataset, orig_texts, cls_preds, recall_orig_test_samples, recall_orig_test_samples_scores)


if __name__ == '__main__':
Expand Down
Binary file modified cblue/data/__pycache__/data_process.cpython-37.pyc
Binary file not shown.
5 changes: 5 additions & 0 deletions cblue/data/data_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,11 @@ def get_test_sample(self, dtype='cls'):
outputs = self._get_num_samples(orig_sample=samples, is_predict=True)
return outputs

def get_test_orig_text(self):
    """Return the ``'text'`` field of every sample in the test file.

    The samples are read from ``self.test_path`` with ``load_json``; the
    resulting list keeps the order in which the samples are iterated.
    """
    return [record['text'] for record in load_json(self.test_path)]

def _pre_process(self, path, is_predict=False):
samples = load_json(path)
outputs = {'text': [], 'label': []}
Expand Down
4 changes: 2 additions & 2 deletions cblue/trainer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,7 @@ def evaluate(self, model):
logger.info("%s-%s f1: %s", args.task_name, args.model_name, f1)
return f1

def predict(self, model, test_dataset, cls_preds, recall_labels, recall_scores):
def predict(self, model, test_dataset, orig_texts, cls_preds, recall_labels, recall_scores):
args = self.args
logger = self.logger
test_dataloader = self.get_test_dataloader(test_dataset)
Expand Down Expand Up @@ -1978,7 +1978,7 @@ def predict(self, model, test_dataset, cls_preds, recall_labels, recall_scores):

recall_labels = np.array(recall_labels['recall_label'])
recall_scores = recall_scores
cdn_commit_prediction(test_dataset.text1, cls_preds, preds, recall_labels, recall_scores,
cdn_commit_prediction(orig_texts, cls_preds, preds, recall_labels, recall_scores,
args.result_output_dir, self.data_processor.id2label)

def _save_checkpoint(self, model, step):
Expand Down

0 comments on commit 8e1c9c4

Please sign in to comment.