
Commit

Merge pull request facebookresearch#78 from Newbeeer/new
Adding a new algorithm (TRM)
lopezpaz authored Oct 20, 2021
2 parents d89994f + a8ab340 commit a09974e
Showing 3 changed files with 176 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -33,6 +33,7 @@ The [currently available algorithms](domainbed/algorithms.py) are:
* Self-supervised Contrastive Regularization (SelfReg, [Kim et al., 2021](https://arxiv.org/abs/2104.09841))
* Smoothed-AND mask (SAND-mask, [Shahtalebi et al., 2021](https://arxiv.org/abs/2106.02266))
* Invariant Gradient Variances for Out-of-distribution Generalization (Fishr, [Rame et al., 2021](https://arxiv.org/abs/2109.02934))
* Learning Representations that Support Robust Transfer of Predictors (TRM, [Xu et al., 2021](https://arxiv.org/abs/2110.09940))

Send us a PR to add your algorithm! Our implementations use ResNet50 / ResNet18 networks ([He et al., 2015](https://arxiv.org/abs/1512.03385)) and the hyper-parameter grids [described here](domainbed/hparams_registry.py).

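For reference, here is a minimal sketch (not part of this commit) of the interface a contributed algorithm implements, mirroring the structure of the TRM class added below. `MyAlgorithm` and its ERM-style loss are placeholders; `Algorithm`, `networks.Featurizer`, and the hparams dict come from the existing DomainBed code.

import torch
import torch.nn as nn
import torch.nn.functional as F

from domainbed import networks
from domainbed.algorithms import Algorithm


class MyAlgorithm(Algorithm):
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes)
        self.optimizer = torch.optim.Adam(
            list(self.featurizer.parameters()) + list(self.classifier.parameters()),
            lr=self.hparams["lr"], weight_decay=self.hparams["weight_decay"])

    def update(self, minibatches):
        # minibatches: one (x, y) batch per training domain
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        loss = F.cross_entropy(self.classifier(self.featurizer(all_x)), all_y)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {'loss': loss.item()}

    def predict(self, x):
        return self.classifier(self.featurizer(x))
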
171 changes: 170 additions & 1 deletion domainbed/algorithms.py
@@ -42,7 +42,8 @@
'SANDMask', # SAND-mask
'IGA',
'SelfReg',
"Fishr"
"Fishr",
'TRM'
]

def get_algorithm_class(algorithm_name):
@@ -1312,3 +1313,171 @@ def _compute_distance_grads_var(self, grads_var_per_domain):
def predict(self, x):
return self.network(x)

class TRM(Algorithm):
"""
Learning Representations that Support Robust Transfer of Predictors
<https://arxiv.org/abs/2110.09940>
"""

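    # NOTE (illustrative summary, not from the original PR): alongside the
    # shared featurizer + classifier, TRM keeps a set of per-domain linear
    # heads; each head is fit on its own domain's features and then evaluated
    # on the other domains ("swapped"). The resulting transfer risk, plus a
    # gradient-coupling term computed with a Neumann-series inverse-Hessian-
    # vector product (see neum below), is added to the ERM loss once
    # update_count reaches hparams['iters'].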
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(TRM, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.register_buffer('update_count', torch.tensor([0]))
        self.num_domains = num_domains
        self.featurizer = networks.Featurizer(input_shape, self.hparams)
        self.classifier = nn.Linear(self.featurizer.n_outputs, num_classes).cuda()
        # num_domains + 1 auxiliary linear heads, each with its own SGD
        # optimizer; they are fit per-domain to estimate the transfer risk
        self.clist = [nn.Linear(self.featurizer.n_outputs, num_classes).cuda() for i in range(num_domains+1)]
        self.olist = [torch.optim.SGD(
            self.clist[i].parameters(),
            lr=1e-1,
        ) for i in range(num_domains+1)]

self.optimizer_f = torch.optim.Adam(
self.featurizer.parameters(),
lr=self.hparams["lr"],
weight_decay=self.hparams['weight_decay']
)
self.optimizer_c = torch.optim.Adam(
self.classifier.parameters(),
lr=self.hparams["lr"],
weight_decay=self.hparams['weight_decay']
)
        # mixing weights over the other source domains: uniform off-diagonal,
        # zero on the diagonal (a domain never weights itself); row-normalized
        # in update()
        self.alpha = torch.ones((num_domains, num_domains)).cuda() - torch.eye(num_domains).cuda()

    @staticmethod
    def neum(v, model, batch):
        """Approximate the inverse-Hessian-vector product H^{-1} v with a
        truncated Neumann series, H^{-1} v ≈ sum_k (I - H)^k v, where H is the
        Hessian of the batch cross-entropy loss w.r.t. model.weight."""
        def hvp(y, w, v):
            # First backprop: gradient of the loss w.r.t. the weights
            first_grads = autograd.grad(y, w, retain_graph=True, create_graph=True, allow_unused=True)
            first_grads = torch.nn.utils.parameters_to_vector(first_grads)
            # Inner product with v
            elemwise_products = first_grads @ v
            # Second backprop: Hessian-vector product
            return_grads = autograd.grad(elemwise_products, w, create_graph=True)
            return_grads = torch.nn.utils.parameters_to_vector(return_grads)
            return return_grads

        v = v.detach()
        h_estimate = v
        cnt = 0.
        model.eval()
        n_iters = 10
        for i in range(n_iters):
            # clear any stale gradient on the head's weight
            model.weight.grad *= 0
            y = model(batch[0].detach())
            loss = F.cross_entropy(y, batch[1].detach())
            hv = hvp(loss, model.weight, v)
            # Neumann step: v <- (I - H) v, accumulate into h_estimate
            v -= hv
            v = v.detach()
            h_estimate = v + h_estimate
            h_estimate = h_estimate.detach()
            # stop early if the series is diverging
            if torch.max(abs(h_estimate)) > 10:
                break
            cnt += 1

        model.train()
        return h_estimate.detach()

def update(self, minibatches):

loss_swap = 0.0
trm = 0.0

if self.update_count >= self.hparams['iters']:
# TRM
if self.hparams['class_balanced']:
# for stability when facing unbalanced labels across environments
for classifier in self.clist:
classifier.weight.data = copy.deepcopy(self.classifier.weight.data)
self.alpha /= self.alpha.sum(1, keepdim=True)

self.featurizer.train()
all_x = torch.cat([x for x, y in minibatches])
all_y = torch.cat([y for x, y in minibatches])
all_feature = self.featurizer(all_x)
# updating original network
loss = F.cross_entropy(self.classifier(all_feature), all_y)

            # fit the per-domain heads on the current (detached) features
            # with a few SGD steps before estimating the transfer risk
            for i in range(30):
all_logits_idx = 0
loss_erm = 0.
for j, (x, y) in enumerate(minibatches):
# j-th domain
feature = all_feature[all_logits_idx:all_logits_idx + x.shape[0]]
all_logits_idx += x.shape[0]
loss_erm += F.cross_entropy(self.clist[j](feature.detach()), y)
for opt in self.olist:
opt.zero_grad()
loss_erm.backward()
for opt in self.olist:
opt.step()

# collect (feature, y)
feature_split = list()
y_split = list()
all_logits_idx = 0
for i, (x, y) in enumerate(minibatches):
feature = all_feature[all_logits_idx:all_logits_idx + x.shape[0]]
all_logits_idx += x.shape[0]
feature_split.append(feature)
y_split.append(y)

            # estimate the transfer risk: evaluate each domain's head Q on the
            # other domains (weighted by alpha) and couple the in-domain gradient
            # with the inverse-Hessian-corrected cross-domain gradient from neum()
            for Q, (x, y) in enumerate(minibatches):
sample_list = list(range(len(minibatches)))
sample_list.remove(Q)

loss_Q = F.cross_entropy(self.clist[Q](feature_split[Q]), y_split[Q])
grad_Q = autograd.grad(loss_Q, self.clist[Q].weight, create_graph=True)
vec_grad_Q = nn.utils.parameters_to_vector(grad_Q)

loss_P = [F.cross_entropy(self.clist[Q](feature_split[i]), y_split[i])*(self.alpha[Q, i].data.detach())
if i in sample_list else 0. for i in range(len(minibatches))]
loss_P_sum = sum(loss_P)
grad_P = autograd.grad(loss_P_sum, self.clist[Q].weight, create_graph=True)
vec_grad_P = nn.utils.parameters_to_vector(grad_P).detach()
vec_grad_P = self.neum(vec_grad_P, self.clist[Q], (feature_split[Q], y_split[Q]))

loss_swap += loss_P_sum - self.hparams['cos_lambda'] * (vec_grad_P.detach() @ vec_grad_Q)

                # exponentiated (GroupDRO-style) update of the mixing weights alpha
                for i in sample_list:
                    self.alpha[Q, i] *= (self.hparams["groupdro_eta"] * loss_P[i].data).exp()

loss_swap /= len(minibatches)
trm /= len(minibatches)
else:
# ERM
self.featurizer.train()
all_x = torch.cat([x for x, y in minibatches])
all_y = torch.cat([y for x, y in minibatches])
all_feature = self.featurizer(all_x)
loss = F.cross_entropy(self.classifier(all_feature), all_y)

nll = loss.item()
self.optimizer_c.zero_grad()
self.optimizer_f.zero_grad()
if self.update_count >= self.hparams['iters']:
loss_swap = (loss + loss_swap)
else:
loss_swap = loss

loss_swap.backward()
self.optimizer_f.step()
self.optimizer_c.step()

loss_swap = loss_swap.item() - nll
self.update_count += 1

return {'nll': nll, 'trm_loss': loss_swap}

def predict(self, x):
return self.classifier(self.featurizer(x))

def train(self):
self.featurizer.train()

def eval(self):
self.featurizer.eval()
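
A quick usage sketch (not part of this diff) showing how the new class is reached through the registry; it assumes the usual hparams_registry.default_hparams helper, and the shapes and dummy minibatches below are illustrative only. Note that the hard-coded .cuda() calls above mean TRM expects a GPU.

import torch
from domainbed import algorithms, hparams_registry

hparams = hparams_registry.default_hparams('TRM', 'RotatedMNIST')
algorithm_class = algorithms.get_algorithm_class('TRM')
algorithm = algorithm_class(input_shape=(1, 28, 28), num_classes=10,
                            num_domains=3, hparams=hparams).cuda()

# one (x, y) minibatch per training domain, as in domainbed/scripts/train.py
minibatches = [(torch.randn(8, 1, 28, 28).cuda(),
                torch.randint(0, 10, (8,)).cuda())
               for _ in range(3)]
stats = algorithm.update(minibatches)  # -> {'nll': ..., 'trm_loss': ...}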

5 changes: 5 additions & 0 deletions domainbed/hparams_registry.py
@@ -101,6 +101,11 @@ def _hparam(name, default_val, random_val_fn):
_hparam('penalty_anneal_iters', 1500, lambda r: int(r.uniform(0., 5000.)))
_hparam('ema', 0.95, lambda r: r.uniform(0.90, 0.99))

    elif algorithm == "TRM":
        _hparam('cos_lambda', 1e-4, lambda r: 10 ** r.uniform(-5, 0))
        _hparam('iters', 200, lambda r: int(10 ** r.uniform(0, 4)))
        _hparam('groupdro_eta', 1e-2, lambda r: 10 ** r.uniform(-3, -1))

# Dataset-and-algorithm-specific hparam definitions. Each block of code
# below corresponds to exactly one hparam. Avoid nested conditionals.

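For context (not part of the diff), a sketch of how these registrations surface through the registry helpers, assuming the usual default_hparams / random_hparams functions in domainbed/hparams_registry.py:

from domainbed import hparams_registry

defaults = hparams_registry.default_hparams('TRM', 'PACS')
# defaults['cos_lambda'] == 1e-4, defaults['iters'] == 200,
# defaults['groupdro_eta'] == 1e-2

sampled = hparams_registry.random_hparams('TRM', 'PACS', seed=0)
# cos_lambda ~ 10 ** Uniform(-5, 0), iters ~ int(10 ** Uniform(0, 4)),
# groupdro_eta ~ 10 ** Uniform(-3, -1)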
