Skip to content

Commit

Permalink
compute avg train accuracy across batches
Browse files Browse the repository at this point in the history
  • Loading branch information
kohpangwei committed Nov 20, 2019
1 parent e80fe20 commit 7d035c9
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions analysis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def load_log(run_dir):
return tuple(dfs)


def get_robust_acc_for_epoch_across_batches(df, epoch):
def get_accs_for_epoch_across_batches(df, epoch):
n_groups = 1 + np.max([int(col.split(':')[1]) for col in df.columns if col.startswith('avg_acc_group')])

indices = df['epoch'] == epoch
Expand All @@ -96,11 +96,12 @@ def get_robust_acc_for_epoch_across_batches(df, epoch):
correct_counts[group] += np.round(
df.loc[i, f'avg_acc_group:{group}'] * df.loc[i, f'processed_data_count_group:{group}'])

for group in range(n_groups):
accs[group] = correct_counts[group] / total_counts[group]
accs = correct_counts / total_counts

robust_acc = np.min(accs)
return robust_acc
avg_acc = accs @ total_counts / np.sum(total_counts)

return avg_acc, robust_acc


def print_accs(dfs, params=None,
Expand Down Expand Up @@ -140,8 +141,12 @@ def print_accs(dfs, params=None,
if output:
print(f"{metric_str} {split:<5} acc ({epoch_print_str} {epoch_to_eval}): Not yet run")
else:
if (metric == 'robust_acc') and (split == 'train'):
acc = get_robust_acc_for_epoch_across_batches(dfs[split], epoch)
if split == 'train':
avg_acc, robust_acc = get_accs_for_epoch_across_batches(dfs[split], epoch)
if metric == 'avg_acc':
acc = avg_acc
elif metric == 'robust_acc':
acc = robust_acc
else:
idx = np.where(dfs[split]['epoch'] == epoch)[0][-1] # Take the last batch in this epoch
acc = dfs[split].loc[idx, metric]
Expand Down Expand Up @@ -202,7 +207,7 @@ def print_best_adj_accs(dfs, params, epoch_to_eval=None, print_avg=False,
if epoch_to_eval is None:
epoch = np.argmax(adj_dfs['val']['robust_acc'].values)
else:
epoch = epoch_to_eval
epoch = epoch_to_eval
robust_accs.append(adj_dfs['val'].loc[epoch,'robust_acc'])
best_adj = params['adj_list'][np.argmax(robust_accs)]
print(f'================== DRO, adj={best_adj} ================== ')
Expand Down

0 comments on commit 7d035c9

Please sign in to comment.