Skip to content

Commit

Permalink
fix a bug for binary classifier
Browse files Browse the repository at this point in the history
  • Loading branch information
binli123 committed Jul 8, 2021
1 parent e201e60 commit 1837b2b
Showing 1 changed file with 19 additions and 9 deletions.
28 changes: 19 additions & 9 deletions train_tcga.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from sklearn.datasets import load_svmlight_file
from collections import OrderedDict

import dsmil as mil

def get_bag_feats(csv_file_df, args):
if args.dataset == 'TCGA-lung-default':
feats_csv_path = 'datasets/tcga-dataset/tcga_lung_data_feats/' + csv_file_df.iloc[0].split('/')[1] + '.csv'
Expand Down Expand Up @@ -78,11 +76,18 @@ def test(test_df, milnet, criterion, optimizer, args):
test_labels = np.array(test_labels)
test_predictions = np.array(test_predictions)
auc_value, _, thresholds_optimal = multi_label_roc(test_labels, test_predictions, args.num_classes, pos_label=1)
for i in range(args.num_classes):
class_prediction_bag = test_predictions[:, i]
class_prediction_bag[class_prediction_bag>=thresholds_optimal[i]] = 1
class_prediction_bag[class_prediction_bag<thresholds_optimal[i]] = 0
test_predictions[:, i] = class_prediction_bag
if args.num_classes==1:
class_prediction_bag = test_predictions
class_prediction_bag[class_prediction_bag>=thresholds_optimal[0]] = 1
class_prediction_bag[class_prediction_bag<thresholds_optimal[0]] = 0
test_predictions = class_prediction_bag
test_labels = np.squeeze(test_labels)
else:
for i in range(args.num_classes):
class_prediction_bag = test_predictions[:, i]
class_prediction_bag[class_prediction_bag>=thresholds_optimal[i]] = 1
class_prediction_bag[class_prediction_bag<thresholds_optimal[i]] = 0
test_predictions[:, i] = class_prediction_bag
bag_score = 0
for i in range(0, len(test_df)):
bag_score = np.array_equal(test_labels[i], test_predictions[i]) + bag_score
Expand Down Expand Up @@ -122,9 +127,14 @@ def main():
parser.add_argument('--num_epochs', default=40, type=int, help='Number of total training epochs [40]')
parser.add_argument('--weight_decay', default=5e-3, type=float, help='Weight decay [5e-3]')
parser.add_argument('--dataset', default='TCGA-lung-default', type=str, help='Dataset folder name')
parser.add_argument('--split', default=0.2, type=float, help='training/validation split [0.2]')
parser.add_argument('--split', default=0.2, type=float, help='Training/Validation split [0.2]')
parser.add_argument('--model', default='dsmil', type=str, help='MIL model [dsmil]')
args = parser.parse_args()

if args.model == 'dsmil':
import dsmil as mil
elif args.model == 'abmil':
import abmil as mil

i_classifier = mil.FCLayer(in_size=args.feats_size, out_size=args.num_classes).cuda()
b_classifier = mil.BClassifier(input_size=args.feats_size, output_class=args.num_classes).cuda()
Expand Down Expand Up @@ -158,7 +168,7 @@ def main():
print('\r Epoch [%d/%d] train loss: %.4f test loss: %.4f, average score: %.4f, AUC: ' %
(epoch, args.num_epochs, train_loss_bag, test_loss_bag, avg_score) + '|'.join('class-{}>>{}'.format(*k) for k in enumerate(aucs)))
scheduler.step()
current_score = (aucs[0] + aucs[1] + avg_score + 1 - test_loss_bag)/4
current_score = (sum(aucs) + avg_score + 1 - test_loss_bag)/4
if current_score >= best_score:
best_score = current_score
save_name = os.path.join(save_path, str(run+1)+'.pth')
Expand Down

0 comments on commit 1837b2b

Please sign in to comment.