Skip to content

Commit

Permalink
Update projection.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mxuax authored May 4, 2023
1 parent 2c9b1b6 commit 5ee1f22
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions baseline/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ def eval_label(pred_labels,ground_truth,input, config,type = 'NN'):
help='Name of embedding model: mpnet/sent_roberta/simcse_bert/simcse_roberta/sent_t5')
# parser.add_argument('--embed_model', type=str, default='simcse_roberta', help='Name of embedding model: mpnet/sent_roberta/simcse_bert/simcse_roberta/sent_t5')
parser.add_argument('--model_type', type=str, default='NN', help='Type of baseline model: RNN or NN')
parser.add_argument('--eval', type=str, default=False, help='True or False')

args = parser.parse_args()
config = {}
Expand All @@ -509,7 +510,8 @@ def eval_label(pred_labels,ground_truth,input, config,type = 'NN'):
config['data_type'] = args.data_type
config['embed_model'] = args.embed_model
config['model_type'] = args.model_type

config['eval'] = args.eval

config['device'] = torch.device("cuda")
config['tokenizer'] = AutoTokenizer.from_pretrained('microsoft/DialoGPT-medium')
config['eos_token'] = config['tokenizer'].eos_token
Expand All @@ -523,7 +525,5 @@ def eval_label(pred_labels,ground_truth,input, config,type = 'NN'):
batch_size=config['batch_size'],
collate_fn=dataset.collate,
drop_last=True)
print("start training")
get_embedding(dataloader, config, eval=False)
print("start evaluation")
get_embedding(dataloader, config, eval=True)

get_embedding(dataloader, config, eval=config['eval'])

0 comments on commit 5ee1f22

Please sign in to comment.