From 278e702d12a8b69e2b06867aeb45af88b31d22c5 Mon Sep 17 00:00:00 2001 From: Yash Savani Date: Wed, 1 Jul 2020 18:20:33 +0000 Subject: [PATCH] Celeba bug fixes. --- post_hoc_celeba.py | 90 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/post_hoc_celeba.py b/post_hoc_celeba.py index ef0522c..7eae771 100644 --- a/post_hoc_celeba.py +++ b/post_hoc_celeba.py @@ -193,6 +193,24 @@ def _get_results(y_true, y_pred, y_prot): return _get_results +class Critic(nn.Module): + + def __init__(self, sizein, num_deep=3, hid=32): + super().__init__() + self.fc0 = nn.Linear(sizein, hid) + self.fcs = nn.ModuleList([nn.Linear(hid, hid) for _ in range(num_deep)]) + self.dropout = nn.Dropout(0.2) + self.out = nn.Linear(hid, 1) + + def forward(self, t): + t = t.reshape(1, -1) + t = self.fc0(t) + for fc in self.fcs: + t = F.relu(fc(t)) + t = self.dropout(t) + return self.out(t) + + def main(args): trainsize = args.trainsize @@ -206,9 +224,9 @@ def main(args): protected_index = descriptions.index(protected_attr) prediction_index = descriptions.index(prediction_attr) - trainset, valset, testset, trainloader, valloader, testloader = load_celeba(trainsize=trainsize, - testsize=testsize, - num_workers=num_workers) + _, _, _, trainloader, valloader, testloader = load_celeba(trainsize=trainsize, + testsize=testsize, + num_workers=num_workers) if print_priors: compute_priors(testloader, protected_index, prediction_index) @@ -218,7 +236,8 @@ def main(args): optimizer = optim.Adam(net.parameters()) train_model(net, trainloader, valloader, criterion, optimizer, protected_index, prediction_index, epochs=epochs) - rocauc_score, best_acc, bias, obj = val_model(net, testloader, get_objective_with_best_accuracy, protected_index, prediction_index) + _, best_thresh = val_model(rand_model, valloader, get_best_accuracy, protected_index, prediction_index) + rocauc_score, best_acc, bias, obj = val_model(net, testloader, get_best_objective_results(best_thresh), protected_index, prediction_index) print('roc auc', rocauc_score) print('accuracy with best thresh', best_acc) @@ -254,12 +273,71 @@ def main(args): print('aod', bias.item()) print('objective', obj.item()) + # base_model = copy.deepcopy(net) + # base_model.fc = nn.Linear(base_model.fc.in_features, base_model.fc.in_features) + + # actor = nn.Sequential(base_model, nn.Linear(base_model.fc.in_features, 2)) + # actor.to(device) + # actor_optimizer = optim.Adam(actor.parameters()) + # actor_loss_fn = nn.BCEWithLogitsLoss() + + # critic = Critic(net.fc.in_features) + # critic.to(device) + # critic_optimizer = optim.Adam(critic.parameters()) + # critic_loss_fn = nn.MSELoss() + + # for epoch in range(100): + # for param in critic.parameters(): + # param.requires_grad = True + # for param in actor.parameters(): + # param.requires_grad = False + # actor.eval() + # critic.train() + # for index, (inputs, labels) in enumerate(valloader): + # if index >= 300: + # break + # inputs, labels = inputs.to(device), labels.to(device) + # critic_optimizer.zero_grad() + + # with torch.no_grad(): + # scores = actor(inputs) + + # bias = compute_bias(scores, cy_valid.numpy(), cp_valid, config['metric']) + # res = critic(actor.trunc_forward(cX_valid)) + # loss = critic_loss_fn(torch.tensor([bias]), res[0]) + # loss.backward() + # train_loss = loss.item() + # critic_optimizer.step() + # if step % 100 == 0: + # print(f'=======> Epoch: {(epoch, step)} Critic loss: {train_loss}') + + # for param in critic.parameters(): + # param.requires_grad = False + # for param in actor.parameters(): + # param.requires_grad = True + # actor.train() + # critic.eval() + # for step in range(100): + # actor_optimizer.zero_grad() + + # lam = config['adversarial']['lambda'] + + # bias = critic(actor.trunc_forward(cX_valid)) + # loss = actor_loss_fn(actor(cX_valid)[:, 0], cy_valid) + # loss = lam*abs(bias) + (1-lam)*loss + + # loss.backward() + # train_loss = loss.item() + # actor_optimizer.step() + # if step % 100 == 0: + # print(f'=======> Epoch: {(epoch, step)} Actor loss: {train_loss}') + if __name__ == "__main__": parser = argparse.ArgumentParser(description='Args for CelebA experiments') parser.add_argument('--epochs', type=int, default=2, help='Number of epochs') - parser.add_argument('--trainsize', type=int, default=100, help='Size of training set') - parser.add_argument('--testsize', type=int, default=100, help='Size of test set') + parser.add_argument('--trainsize', type=int, default=5000, help='Size of training set') + parser.add_argument('--testsize', type=int, default=1000, help='Size of test set') parser.add_argument('--num_workers', type=int, default=2, help='Number of worker threads') parser.add_argument('--print_priors', type=bool, default=True, help='Compute the prior percents') parser.add_argument('--protected_attr', type=str, default='Black', help='Protected class')