Commit

Celeba bug fixes.
yashsavani committed Jul 1, 2020
1 parent e1c6938 commit 278e702
Showing 1 changed file with 84 additions and 6 deletions.
90 changes: 84 additions & 6 deletions post_hoc_celeba.py
@@ -193,6 +193,24 @@ def _get_results(y_true, y_pred, y_prot):
return _get_results


class Critic(nn.Module):

def __init__(self, sizein, num_deep=3, hid=32):
super().__init__()
self.fc0 = nn.Linear(sizein, hid)
self.fcs = nn.ModuleList([nn.Linear(hid, hid) for _ in range(num_deep)])
self.dropout = nn.Dropout(0.2)
self.out = nn.Linear(hid, 1)

def forward(self, t):
t = t.reshape(1, -1)
t = self.fc0(t)
for fc in self.fcs:
t = F.relu(fc(t))
t = self.dropout(t)
return self.out(t)


def main(args):

trainsize = args.trainsize
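
As a quick sanity check on the Critic class added above, the sketch below exercises it in isolation. It assumes post_hoc_celeba.py and its dependencies are importable, and the batch size (32) and feature dimension (512) are purely illustrative. Note that forward() reshapes the whole batch into a single row, so sizein has to equal batch_size * feature_dim for the shapes to line up.

import torch
from post_hoc_celeba import Critic   # the class introduced in this commit

feats = torch.randn(32, 512)          # e.g. penultimate-layer activations for one batch (assumed shapes)
critic = Critic(sizein=32 * 512)      # forward() flattens the batch, so sizein = batch size * feature dim
bias_estimate = critic(feats)         # shape (1, 1): one scalar estimate for the whole batch
print(bias_estimate.shape)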
@@ -206,9 +224,9 @@ def main(args):
protected_index = descriptions.index(protected_attr)
prediction_index = descriptions.index(prediction_attr)

-    trainset, valset, testset, trainloader, valloader, testloader = load_celeba(trainsize=trainsize,
-                                                                                testsize=testsize,
-                                                                                num_workers=num_workers)
+    _, _, _, trainloader, valloader, testloader = load_celeba(trainsize=trainsize,
+                                                              testsize=testsize,
+                                                              num_workers=num_workers)

if print_priors:
compute_priors(testloader, protected_index, prediction_index)
@@ -218,7 +236,8 @@ def main(args):
optimizer = optim.Adam(net.parameters())
train_model(net, trainloader, valloader, criterion, optimizer, protected_index, prediction_index, epochs=epochs)

-    rocauc_score, best_acc, bias, obj = val_model(net, testloader, get_objective_with_best_accuracy, protected_index, prediction_index)
+    _, best_thresh = val_model(rand_model, valloader, get_best_accuracy, protected_index, prediction_index)
+    rocauc_score, best_acc, bias, obj = val_model(net, testloader, get_best_objective_results(best_thresh), protected_index, prediction_index)

print('roc auc', rocauc_score)
print('accuracy with best thresh', best_acc)
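
The change above first picks the decision threshold on the validation loader and only then scores the held-out test loader with that fixed threshold. Below is a minimal sketch of the same select-on-validation, apply-to-test pattern in plain numpy; the helper name and the random data are illustrative, not the script's actual val_model / get_best_accuracy API.

import numpy as np

def pick_best_threshold(scores, labels, grid=np.linspace(0.01, 0.99, 99)):
    # Threshold that maximizes accuracy on held-out (validation) predictions.
    accs = [((scores > t).astype(int) == labels).mean() for t in grid]
    return grid[int(np.argmax(accs))]

rng = np.random.default_rng(0)                                    # toy validation/test splits
val_scores, val_labels = rng.random(1000), rng.integers(0, 2, 1000)
test_scores, test_labels = rng.random(1000), rng.integers(0, 2, 1000)

best_thresh = pick_best_threshold(val_scores, val_labels)         # chosen on validation only
test_acc = ((test_scores > best_thresh).astype(int) == test_labels).mean()   # reused unchanged on test
print(best_thresh, test_acc)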
@@ -254,12 +273,71 @@ def main(args):
print('aod', bias.item())
print('objective', obj.item())

# base_model = copy.deepcopy(net)
# base_model.fc = nn.Linear(base_model.fc.in_features, base_model.fc.in_features)

# actor = nn.Sequential(base_model, nn.Linear(base_model.fc.in_features, 2))
# actor.to(device)
# actor_optimizer = optim.Adam(actor.parameters())
# actor_loss_fn = nn.BCEWithLogitsLoss()

# critic = Critic(net.fc.in_features)
# critic.to(device)
# critic_optimizer = optim.Adam(critic.parameters())
# critic_loss_fn = nn.MSELoss()

# for epoch in range(100):
# for param in critic.parameters():
# param.requires_grad = True
# for param in actor.parameters():
# param.requires_grad = False
# actor.eval()
# critic.train()
# for index, (inputs, labels) in enumerate(valloader):
# if index >= 300:
# break
# inputs, labels = inputs.to(device), labels.to(device)
# critic_optimizer.zero_grad()

# with torch.no_grad():
# scores = actor(inputs)

# bias = compute_bias(scores, cy_valid.numpy(), cp_valid, config['metric'])
# res = critic(actor.trunc_forward(cX_valid))
# loss = critic_loss_fn(torch.tensor([bias]), res[0])
# loss.backward()
# train_loss = loss.item()
# critic_optimizer.step()
# if index % 100 == 0:
# print(f'=======> Epoch: {(epoch, index)} Critic loss: {train_loss}')

# for param in critic.parameters():
# param.requires_grad = False
# for param in actor.parameters():
# param.requires_grad = True
# actor.train()
# critic.eval()
# for step in range(100):
# actor_optimizer.zero_grad()

# lam = config['adversarial']['lambda']

# bias = critic(actor.trunc_forward(cX_valid))
# loss = actor_loss_fn(actor(cX_valid)[:, 0], cy_valid)
# loss = lam*abs(bias) + (1-lam)*loss

# loss.backward()
# train_loss = loss.item()
# actor_optimizer.step()
# if step % 100 == 0:
# print(f'=======> Epoch: {(epoch, step)} Actor loss: {train_loss}')


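The commented-out block above sketches an adversarial fine-tuning idea: a critic learns to predict the bias of the actor's scores, and the actor is then trained against a weighted sum of its task loss and the critic's bias estimate (weighted by lambda). Below is a self-contained toy version of that alternating loop on synthetic features; every name, shape, and the demographic-parity-style bias proxy are assumptions for illustration, not the script's actual compute_bias / config machinery.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

# Synthetic stand-ins for penultimate-layer features, task labels, and a protected attribute.
N, D = 256, 32
feats = torch.randn(N, D)
labels = torch.randint(0, 2, (N,)).float()
protected = torch.randint(0, 2, (N,)).float()

actor = nn.Sequential(nn.Linear(D, 16), nn.ReLU(), nn.Linear(16, 1))   # one logit per example
critic = nn.Sequential(nn.Linear(N, 16), nn.ReLU(), nn.Linear(16, 1))  # whole batch of scores -> one bias estimate
actor_opt = optim.Adam(actor.parameters(), lr=1e-3)
critic_opt = optim.Adam(critic.parameters(), lr=1e-3)
task_loss_fn = nn.BCEWithLogitsLoss()
lam = 0.75  # plays the role of config['adversarial']['lambda'] in the commented code

def toy_bias(scores, prot):
    # Gap in mean predicted probability between protected groups (demographic-parity style proxy).
    probs = torch.sigmoid(scores)
    return probs[prot == 1].mean() - probs[prot == 0].mean()

for epoch in range(50):
    # Critic step: regress the measured bias of the actor's (detached) scores.
    scores = actor(feats).squeeze(1).detach()
    critic_opt.zero_grad()
    critic_loss = F.mse_loss(critic(scores), toy_bias(scores, protected).reshape(1))
    critic_loss.backward()
    critic_opt.step()

    # Actor step: trade off task loss against the critic's bias estimate.
    actor_opt.zero_grad()
    scores = actor(feats).squeeze(1)
    actor_loss = lam * critic(scores).abs() + (1 - lam) * task_loss_fn(scores, labels)
    actor_loss.backward()
    actor_opt.step()

    if epoch % 10 == 0:
        print(f'epoch {epoch}: critic {critic_loss.item():.4f} actor {actor_loss.item():.4f}')

The real block additionally toggles requires_grad on each network between phases and pulls the bias from the script's compute_bias helper; the sketch only keeps the alternating critic/actor structure.
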
if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Args for CelebA experiments')
parser.add_argument('--epochs', type=int, default=2, help='Number of epochs')
-    parser.add_argument('--trainsize', type=int, default=100, help='Size of training set')
-    parser.add_argument('--testsize', type=int, default=100, help='Size of test set')
+    parser.add_argument('--trainsize', type=int, default=5000, help='Size of training set')
+    parser.add_argument('--testsize', type=int, default=1000, help='Size of test set')
parser.add_argument('--num_workers', type=int, default=2, help='Number of worker threads')
parser.add_argument('--print_priors', type=bool, default=True, help='Compute the prior percents')
parser.add_argument('--protected_attr', type=str, default='Black', help='Protected class')
