Skip to content

Commit

Permalink
fix transae bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
flow3rdown authored Jun 26, 2024
1 parent 085d522 commit ab28122
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions M-KGE/IKRL_TransAE/TransAE.py
Original file line number Diff line number Diff line change
Expand Up @@ -986,14 +986,14 @@ def create_mappings(dataset_path: str):
transe = TransE(
ent_tot = len(entity2id),
rel_tot = len(relation2id),
dim = 400,
dim = 200,
p_norm = 1,
norm_flag = True)
else:
transe = Analogy(
ent_tot = len(entity2id),
rel_tot = len(relation2id),
dim = 400,
dim = 200,
p_norm = 1,
norm_flag = True)

Expand All @@ -1005,10 +1005,10 @@ def create_mappings(dataset_path: str):

trainer = Trainer(model=model, data_loader=train_dataloader, train_times=2000, alpha=alpha, use_gpu=True)
trainer.run()
transe.save_checkpoint('ckpt/analogy/pt_trainse.ckpt')
transe.save_checkpoint('ckpt/analogy/pt_transe.ckpt')

# test the model
transe.load_checkpoint('ckpt/analogy/pt_trainse.ckpt')
transe.load_checkpoint('ckpt/analogy/pt_transe.ckpt')
tester = Tester(model=transe, data_loader= test_dataloader, use_gpu = True)
tester.run_link_prediction(type_constrain = False)

Expand Down Expand Up @@ -1066,4 +1066,4 @@ def create_mappings(dataset_path: str):
print("hit10: ", hit10)
print("hti5: ", hti5)
print("hti3: ", hti3)
print("hit1: ", hit1)
print("hit1: ", hit1)

0 comments on commit ab28122

Please sign in to comment.