Skip to content

Commit

Permalink
fix bias
Browse files Browse the repository at this point in the history
  • Loading branch information
crwhite14 committed Jul 10, 2020
1 parent 72a938b commit df2a9d4
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions post_hoc_celeba.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,26 +175,26 @@ def compute_priors(data, protected_index, prediction_index):
print('Prob. positive outcome given unprotected class', unprot_rate)


def compute_bias(y_pred, y_true, priv, metric):
def compute_bias(y_pred, y_true, prot, metric):
"""Compute bias on the dataset"""
def zero_if_nan(data):
"""Zero if there is a nan"""
return 0. if torch.isnan(data) else data

gtpr_priv = zero_if_nan(y_pred[priv * y_true == 1].mean())
gfpr_priv = zero_if_nan(y_pred[priv * (1-y_true) == 1].mean())
mean_priv = zero_if_nan(y_pred[priv == 1].mean())
gtpr_prot = zero_if_nan(y_pred[prot * y_true == 1].mean())
gfpr_prot = zero_if_nan(y_pred[prot * (1-y_true) == 1].mean())
mean_prot = zero_if_nan(y_pred[prot == 1].mean())

gtpr_unpriv = zero_if_nan(y_pred[(1-priv) * y_true == 1].mean())
gfpr_unpriv = zero_if_nan(y_pred[(1-priv) * (1-y_true) == 1].mean())
mean_unpriv = zero_if_nan(y_pred[(1-priv) == 1].mean())
gtpr_unprot = zero_if_nan(y_pred[(1-prot) * y_true == 1].mean())
gfpr_unprot = zero_if_nan(y_pred[(1-prot) * (1-y_true) == 1].mean())
mean_unprot = zero_if_nan(y_pred[(1-prot) == 1].mean())

if metric == "spd":
return mean_unpriv - mean_priv
return mean_prot - mean_unprot
elif metric == "aod":
return 0.5 * ((gfpr_unpriv - gfpr_priv) + (gtpr_unpriv - gtpr_priv))
return 0.5 * ((gfpr_prot - gfpr_unprot) + (gtpr_prot - gtpr_unprot))
elif metric == "eod":
return gtpr_unpriv - gtpr_priv
return gtpr_prot - gtpr_unprot


def get_objective_with_best_accuracy(y_true, y_pred, y_prot):
Expand Down Expand Up @@ -263,6 +263,13 @@ def main(config):
batch_size=config['batch_size']
)
if config['print_priors']:
print('train priors')
compute_priors(trainloader, protected_index, prediction_index)
print()
print('val priors')
compute_priors(valloader, protected_index, prediction_index)
print()
print('test priors')
compute_priors(testloader, protected_index, prediction_index)

net = get_resnet_model()
Expand Down

0 comments on commit df2a9d4

Please sign in to comment.