infer AUC fixed
canallee committed Jan 29, 2023
1 parent ea7ef30 commit e5d3ceb
Showing 15 changed files with 80 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -6,4 +6,5 @@ dist/
esm/
build/
results/
model/
model/
gmm_test/
4 changes: 2 additions & 2 deletions README.md
@@ -67,7 +67,7 @@ python
>>> infer_pvalue(train_data, test_data, p_value=1e-5, nk_random=20,
report_metrics=True, pretrained=True)
```
This should produce following results:
This should produce similar results (depending on the version of ESM-1b weights):
```
The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])
Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers
@@ -87,7 +87,7 @@ python
>>> test_data = "new"
>>> infer_maxsep(train_data, test_data, report_metrics=True, pretrained=True)
```
This should produce following results:
This should produce similar results (depending on the version of ESM-1b weights):
```
The embedding sizes for train and test: torch.Size([241025, 128]) torch.Size([392, 128])
Calculating eval distance map, between 392 test ids and 5242 train EC cluster centers
2 changes: 1 addition & 1 deletion gmm.py
@@ -18,7 +18,7 @@

counter = 0

dist_map = pickle.load(open('./data/distance_map/uniref100_full.pkl', 'rb'))
dist_map = pickle.load(open('./data/distance_map/split100.pkl', 'rb'))
negative = mine_hard_negative(dist_map, 5)

for i in range(40):
4 changes: 4 additions & 0 deletions inference.py
@@ -0,0 +1,4 @@
from src.CLEAN.infer import infer_maxsep
train_data = "split100"
test_data = "new"
infer_maxsep(train_data, test_data, report_metrics=True, pretrained=True)
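For the p-value mode, a companion driver would look like the sketch below; it mirrors inference.py and borrows the argument values from the README example earlier in this commit (the script itself is hypothetical and not part of the commit):
```python
# Hypothetical companion to inference.py for the p-value inference mode;
# argument values follow the README example in this commit.
from src.CLEAN.infer import infer_pvalue

train_data = "split100"
test_data = "new"
infer_pvalue(train_data, test_data, p_value=1e-5, nk_random=20,
             report_metrics=True, pretrained=True)
```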
Binary file added src/CLEAN/__pycache__/__init__.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/dataloader.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/distance_map.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/evaluate.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/infer.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/losses.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/model.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/uncertainty.cpython-38.pyc
Binary file added src/CLEAN/__pycache__/utils.cpython-38.pyc
70 changes: 65 additions & 5 deletions src/CLEAN/evaluate.py
@@ -214,6 +214,27 @@ def get_pred_labels(out_filename, pred_type="_maxsep"):
pred_label.append(preds_ec_lst)
return pred_label

def get_pred_probs(out_filename, pred_type="_maxsep"):
file_name = out_filename+pred_type
result = open(file_name+'.csv', 'r')
csvreader = csv.reader(result, delimiter=',')
pred_probs = []
for row in csvreader:
preds_ec_lst = []
preds_with_dist = row[1:]
probs = torch.zeros(len(preds_with_dist))
count = 0
for pred_ec_dist in preds_with_dist:
# parse the distance 10.8359 from an entry like EC:3.5.2.6/10.8359
ec_i = float(pred_ec_dist.split(":")[1].split("/")[1])
probs[count] = ec_i
#preds_ec_lst.append(probs)
count += 1
# squash distances into (0, 1): (1 - e^(-1/d)) / (1 + e^(-1/d)) = tanh(1/(2d))
probs = (1 - torch.exp(-1/probs)) / (1 + torch.exp(-1/probs))
probs = probs/torch.sum(probs)
pred_probs.append(probs)
return pred_probs
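
get_pred_probs turns each candidate's embedding distance into a score in (0, 1) and renormalizes per query, so closer EC cluster centers receive more probability mass. A minimal sketch of that transform on a made-up prediction row (the query id and the second entry are invented; only the first distance comes from the comment above):
```python
# Sketch of the distance-to-probability transform in get_pred_probs.
import torch

row = ["query_1", "EC:3.5.2.6/10.8359", "EC:1.1.1.1/12.5000"]  # hypothetical row
dists = torch.tensor([float(e.split(":")[1].split("/")[1]) for e in row[1:]])

# (1 - e^(-1/d)) / (1 + e^(-1/d)) = tanh(1/(2d)); smaller distance -> larger score
scores = (1 - torch.exp(-1 / dists)) / (1 + torch.exp(-1 / dists))
probs = scores / scores.sum()   # normalize per query
print(probs)                    # roughly tensor([0.5356, 0.4644])
```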

def get_pred_labels_prc(out_filename, cutoff, pred_type="_maxsep"):
file_name = out_filename+pred_type
@@ -232,20 +253,59 @@ def get_pred_labels_prc(out_filename, cutoff, pred_type="_maxsep"):
return pred_label


def get_eval_metrics(pred_label, true_label, all_label):
# def get_eval_metrics(pred_label, true_label, all_label):
# mlb = MultiLabelBinarizer()
# mlb.fit([list(all_label)])
# n_test = len(pred_label)
# pred_m = np.zeros((n_test, len(mlb.classes_)))
# true_m = np.zeros((n_test, len(mlb.classes_)))
# for i in range(n_test):
# pred_m[i] = mlb.transform([pred_label[i]])
# true_m[i] = mlb.transform([true_label[i]])
# pre = precision_score(true_m, pred_m, average='weighted', zero_division=0)
# rec = recall_score(true_m, pred_m, average='weighted')
# f1 = f1_score(true_m, pred_m, average='weighted')
# roc = roc_auc_score(true_m, pred_m, average='weighted')
# acc = accuracy_score(true_m, pred_m)
# return pre, rec, f1, roc, acc

def get_ec_pos_dict(mlb, true_label, pred_label):
ec_list = []
pos_list = []
for i in range(len(true_label)):
ec_list += list(mlb.inverse_transform(mlb.transform([true_label[i]]))[0])
pos_list += list(np.nonzero(mlb.transform([true_label[i]]))[1])
for i in range(len(pred_label)):
ec_list += list(mlb.inverse_transform(mlb.transform([pred_label[i]]))[0])
pos_list += list(np.nonzero(mlb.transform([pred_label[i]]))[1])
label_pos_dict = {}
for i in range(len(ec_list)):
ec, pos = ec_list[i], pos_list[i]
label_pos_dict[ec] = pos

return label_pos_dict
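
get_ec_pos_dict records, for every EC label seen in the true or predicted sets, the column that label occupies in the MultiLabelBinarizer matrix, so get_eval_metrics below can write each probability into the right column of pred_m_auc. A small sketch of the mechanics, with arbitrary EC numbers:
```python
# Sketch of the label-to-column lookup behind get_ec_pos_dict;
# the three EC numbers are arbitrary examples.
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

mlb = MultiLabelBinarizer()
mlb.fit([["1.1.1.1", "2.3.1.4", "3.5.2.6"]])   # plays the role of all_label

onehot = mlb.transform([["3.5.2.6"]])          # (1, 3) binary row
col = np.nonzero(onehot)[1][0]                 # column index of that EC
print(col, mlb.classes_[col])                  # -> 2 3.5.2.6
```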

def get_eval_metrics(pred_label, pred_probs, true_label, all_label):
mlb = MultiLabelBinarizer()
mlb.fit([list(all_label)])
n_test = len(pred_label)
pred_m = np.zeros((n_test, len(mlb.classes_)))
true_m = np.zeros((n_test, len(mlb.classes_)))
# for including probability
pred_m_auc = np.zeros((n_test, len(mlb.classes_)))
label_pos_dict = get_ec_pos_dict(mlb, true_label, pred_label)
for i in range(n_test):
pred_m[i] = mlb.transform([pred_label[i]])
true_m[i] = mlb.transform([true_label[i]])
# fill in probabilities for prediction
labels, probs = pred_label[i], pred_probs[i]
for label, prob in zip(labels, probs):
if label in all_label:
pos = label_pos_dict[label]
pred_m_auc[i, pos] = prob
pre = precision_score(true_m, pred_m, average='weighted', zero_division=0)
rec = recall_score(true_m, pred_m, average='weighted')
f1 = f1_score(true_m, pred_m, average='weighted')
roc = roc_auc_score(true_m, pred_m, average='weighted')
roc = roc_auc_score(true_m, pred_m_auc, average='weighted')
acc = accuracy_score(true_m, pred_m)
return pre, rec, f1, roc, acc


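This is the fix the commit title refers to: ROC-AUC was previously computed on the same 0/1 matrix used for precision and recall, which throws away ranking information, whereas pred_m_auc carries the per-EC probabilities. A toy illustration of the difference (all numbers invented):
```python
# Toy multilabel example: AUC on hard 0/1 predictions vs. on graded
# probabilities. Labels and scores are invented for illustration.
import numpy as np
from sklearn.metrics import roc_auc_score

true_m = np.array([[1, 0, 0],
                   [0, 1, 0],
                   [0, 0, 1]])
hard   = np.array([[1, 1, 0],          # binarized predictions (old behavior)
                   [0, 1, 0],
                   [0, 0, 1]])
graded = np.array([[0.6, 0.4, 0.0],    # per-EC probabilities (new behavior)
                   [0.1, 0.9, 0.0],
                   [0.0, 0.2, 0.8]])

print(roc_auc_score(true_m, hard, average='weighted'))    # ~0.917: ties hurt
print(roc_auc_score(true_m, graded, average='weighted'))  # 1.0: ranking intact
```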
8 changes: 6 additions & 2 deletions src/CLEAN/infer.py
Original file line number Diff line number Diff line change
@@ -58,8 +58,10 @@ def infer_pvalue(train_data, test_data, p_value = 1e-5, nk_random = 20,
# optionally report prediction precision/recall/...
if report_metrics:
pred_label = get_pred_labels(out_filename, pred_type='_pvalue')
pred_probs = get_pred_probs(out_filename, pred_type='_pvalue')
true_label, all_label = get_true_labels('./data/' + test_data)
pre, rec, f1, roc, __annotations__ = get_eval_metrics(pred_label, true_label, all_label)
pre, rec, f1, roc, acc = get_eval_metrics(
pred_label, pred_probs, true_label, all_label)
print(f'############ EC calling results using random '
f'chosen {nk_random}k samples ############')
print('-' * 75)
@@ -113,8 +115,10 @@ def infer_maxsep(train_data, test_data, report_metrics = False,
write_max_sep_choices(eval_df, out_filename)
if report_metrics:
pred_label = get_pred_labels(out_filename, pred_type='_maxsep')
pred_probs = get_pred_probs(out_filename, pred_type='_maxsep')
true_label, all_label = get_true_labels('./data/' + test_data)
pre, rec, f1, roc, _ = get_eval_metrics(pred_label, true_label, all_label)
pre, rec, f1, roc, acc = get_eval_metrics(
pred_label, pred_probs, true_label, all_label)
print("############ EC calling results using maximum separation ############")
print('-' * 75)
print(f'>>> total samples: {len(true_label)} | total ec: {len(all_label)} \n'
