argmin to argmax
Kthyeon committed May 11, 2021
1 parent f819848 commit e86ebe3
Showing 2 changed files with 9 additions and 9 deletions.
dividemix/Train_cifar.py: 8 additions & 8 deletions
```diff
@@ -320,7 +320,7 @@ def fit_mixture(scores, labels, p_threshold=0.5):
 
     gmm.fit(feats_)
     prob = gmm.predict_proba(feats_)
-    prob = prob[:,gmm.means_.argmin()]
+    prob = prob[:,gmm.means_.argmax()]
     # weights, means, covars = g.weights_, g.means_, g.covariances_
 
     # # boundary? QDA!
```
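Why this is the fix: scikit-learn assigns GMM component indices arbitrarily, so the clean-probability column has to be picked by component mean rather than by a fixed index, and with FINE-style scores (higher score = more likely clean) the clean component is the one with the larger mean. A minimal sketch of the selection logic under that assumption; the wrapper name is illustrative, not the repository's API:

```python
import numpy as np
from sklearn.mixture import GaussianMixture

def clean_posterior(scores):
    """Hypothetical helper mirroring the fit_mixture snippet above."""
    feats_ = np.asarray(scores).reshape(-1, 1)        # GMM expects 2-D input
    gmm = GaussianMixture(n_components=2).fit(feats_)
    prob = gmm.predict_proba(feats_)                  # shape (n_samples, 2)
    # means_ has shape (2, 1), so argmax() over the flattened array is also
    # the column index of the higher-mean component. With scores where HIGH
    # means clean (FINE alignment scores), that is the clean component; with
    # per-sample losses where LOW means clean (vanilla DivideMix), argmin()
    # would be the right choice instead.
    return prob[:, gmm.means_.argmax()]
```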
```diff
@@ -336,7 +336,7 @@ def fit_mixture(scores, labels, p_threshold=0.5):
     return np.array(clean_labels, dtype=np.int64)
 
 
-def fine(current_features, current_labels, fit = 'kmeans', prev_features=None, prev_labels=None, p_threshold=0.5):
+def fine(current_features, current_labels, fit = 'kmeans', prev_features=None, prev_labels=None, p_threshold=0.7):
     '''
     prev_features, prev_labels: data from the previous round
     current_features, current_labels: current round's data
```
```diff
@@ -358,7 +358,7 @@ def fine(current_features, current_labels, fit = 'kmeans', prev_features=None, prev_labels=None, p_threshold=0.7):
         clean_labels = cleansing(scores, current_labels)
     elif 'gmm' in fit:
         # fit a two-component GMM to the loss
-        clean_labels = fit_mixture(scores, current_labels)
+        clean_labels = fit_mixture(scores, current_labels, p_threshold)
     else:
         raise NotImplemented
 
```
```diff
@@ -383,14 +383,14 @@ def cleansing(scores, labels):
 
     return np.array(clean_labels, dtype=np.int64)
 
-def extract_cleanidx(model, loader, mode='fine-kmeans'):
+def extract_cleanidx(model, loader, mode='fine-kmeans', p_threshold=0.6):
     model.eval()
     for params in model.parameters(): params.requires_grad = False
 
     # get teacher_idx
     if 'fine' in mode:
         features, labels = get_features(model, loader)
-        teacher_idx = fine(current_features=features, current_labels=labels, fit = 'fine-gmm')
+        teacher_idx = fine(current_features=features, current_labels=labels, fit = 'fine-gmm', p_threshold=p_threshold)
     else: # get teacher _idx via kmeans
         teacher_idx = get_loss_list(model, loader)
 
```
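A reading aid for the threshold plumbing above: p_threshold now flows through three layers, each with its own default. A stub-level sketch of how the values resolve, with signatures copied from the hunks and bodies elided:

```python
# Stub sketch only: bodies are stand-ins, not the repository's implementation.

def fit_mixture(scores, labels, p_threshold=0.5):
    ...  # keeps indices whose clean-component posterior exceeds p_threshold

def fine(current_features, current_labels, fit='kmeans',
         prev_features=None, prev_labels=None, p_threshold=0.7):
    ...  # the 'gmm' branch now forwards p_threshold to fit_mixture

def extract_cleanidx(model, loader, mode='fine-kmeans', p_threshold=0.6):
    ...  # 'fine' modes now forward p_threshold to fine()

# A caller that sets p_threshold at the top level overrides every default:
#   extract_cleanidx(model, loader)                  -> 0.6 reaches fit_mixture
#   extract_cleanidx(model, loader, p_threshold=0.5) -> 0.5 reaches fit_mixture
```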
```diff
@@ -415,7 +415,7 @@ def extract_cleanidx(model, loader, mode='fine-kmeans'):
 
 
 if args.dataset=='cifar10':
-    warm_up = 10
+    warm_up = 0
 elif args.dataset=='cifar100':
     warm_up = 30
 
```
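Context for the warm_up change: in DivideMix-style training, warm_up is the number of initial epochs trained with plain cross-entropy before the co-training phase, so cifar10 now skips warm-up entirely. A hedged sketch of the assumed consumption site (loop structure and function names follow the usual DivideMix layout and are not part of this diff):

```python
# Assumed consumption site (typical DivideMix structure; not in this diff):
for epoch in range(args.num_epochs):
    if epoch < warm_up:          # with warm_up = 0, never taken on cifar10
        warmup(epoch, net1, optimizer1, warmup_trainloader)
    else:
        train(epoch, net1, net2, optimizer1, labeled_loader, unlabeled_loader)
```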
```diff
@@ -463,8 +463,8 @@ def extract_cleanidx(model, loader, mode='fine-kmeans'):
         root_dir=args.data_path,log=stats_log,noise_file='%s/%.1f_%s.json'%(args.data_path,args.r,args.noise_mode))
     all_loader = loader.run('warmup')
 
-    teacher_idx_1 = extract_cleanidx(net1, all_loader, mode=args.distill_mode)
-    teacher_idx_2 = extract_cleanidx(net2, all_loader, mode=args.distill_mode)
+    teacher_idx_1 = extract_cleanidx(net1, all_loader, mode=args.distill_mode, p_threshold=args.p_threshold)
+    teacher_idx_2 = extract_cleanidx(net2, all_loader, mode=args.distill_mode, p_threshold=args.p_threshold)
 
     pred1, prob1 = None, None
     pred2, prob2 = None, None
```
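With the call sites updated, the command-line threshold governs teacher selection for both networks. A hedged usage sketch follows; flag names are inferred from the args references visible in this diff, and their defaults live in the script's argparse setup:

```python
# Hedged usage sketch. Flag names come from the args references visible in
# this diff (args.dataset, args.r, args.noise_mode, args.distill_mode,
# args.p_threshold); they are assumptions about the argparse setup.
#
#   python Train_cifar.py --dataset cifar10 --r 0.8 --noise_mode sym \
#       --distill_mode fine-gmm --p_threshold 0.6
#
# Any 'fine' distill_mode now selects both teacher index sets with the GMM
# at the requested threshold instead of the layered defaults.
```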
dynamic_selection/selection/gmm.py: 1 addition & 1 deletion
```diff
@@ -31,7 +31,7 @@ def fit_mixture(scores, labels, p_threshold=0.5):
 
     gmm.fit(feats_)
     prob = gmm.predict_proba(feats_)
-    prob = prob[:,gmm.means_.argmin()]
+    prob = prob[:,gmm.means_.argmax()]
     # weights, means, covars = g.weights_, g.means_, g.covariances_
 
     # # boundary? QDA!
```
