diff --git a/domainbed/algorithms.py b/domainbed/algorithms.py
index 17a42d3b..9a8410e6 100644
--- a/domainbed/algorithms.py
+++ b/domainbed/algorithms.py
@@ -31,7 +31,8 @@
     'RSC',
     'SD',
     'ANDMask',
-    'IGA'
+    'IGA',
+    'SelfReg'
 ]
 
 def get_algorithm_class(algorithm_name):
@@ -1023,3 +1024,93 @@ def update(self, minibatches, unlabeled=False):
         self.optimizer.step()
 
         return {'loss': mean_loss.item(), 'penalty': penalty_value.item()}
+
+
+
+class SelfReg(ERM):
+    def __init__(self, input_shape, num_classes, num_domains, hparams):
+        super(SelfReg, self).__init__(input_shape, num_classes, num_domains,
+                                      hparams)
+        self.num_classes = num_classes
+        self.MSEloss = nn.MSELoss()
+        input_feat_size = self.featurizer.n_outputs
+        hidden_size = input_feat_size if input_feat_size==2048 else input_feat_size*2
+
+        self.cdpl = nn.Sequential(
+            nn.Linear(input_feat_size, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, hidden_size),
+            nn.BatchNorm1d(hidden_size),
+            nn.ReLU(inplace=True),
+            nn.Linear(hidden_size, input_feat_size),
+            nn.BatchNorm1d(input_feat_size)
+        )
+
+    def update(self, minibatches, unlabeled=None):
+
+        all_x = torch.cat([x for x, y in minibatches])
+        all_y = torch.cat([y for _, y in minibatches])
+
+        lam = np.random.beta(0.5, 0.5)
+
+        batch_size = all_y.size()[0]
+
+        # cluster and order features into same-class group
+        with torch.no_grad():
+            sorted_y, indices = torch.sort(all_y)
+            sorted_x = torch.zeros_like(all_x)
+            for idx, order in enumerate(indices):
+                sorted_x[idx] = all_x[order]
+            intervals = []
+            ex = 0
+            for idx, val in enumerate(sorted_y):
+                if ex==val:
+                    continue
+                intervals.append(idx)
+                ex = val
+            intervals.append(batch_size)
+
+            all_x = sorted_x
+            all_y = sorted_y
+
+        feat = self.featurizer(all_x)
+        proj = self.cdpl(feat)
+
+        output = self.classifier(feat)
+
+        # shuffle
+        output_2 = torch.zeros_like(output)
+        feat_2 = torch.zeros_like(proj)
+        output_3 = torch.zeros_like(output)
+        feat_3 = torch.zeros_like(proj)
+        ex = 0
+        for end in intervals:
+            shuffle_indices = torch.randperm(end-ex)+ex
+            shuffle_indices2 = torch.randperm(end-ex)+ex
+            for idx in range(end-ex):
+                output_2[idx+ex] = output[shuffle_indices[idx]]
+                feat_2[idx+ex] = proj[shuffle_indices[idx]]
+                output_3[idx+ex] = output[shuffle_indices2[idx]]
+                feat_3[idx+ex] = proj[shuffle_indices2[idx]]
+            ex = end
+
+        # mixup
+        output_3 = lam*output_2 + (1-lam)*output_3
+        feat_3 = lam*feat_2 + (1-lam)*feat_3
+
+        # regularization
+        L_ind_logit = self.MSEloss(output, output_2)
+        L_hdl_logit = self.MSEloss(output, output_3)
+        L_ind_feat = 0.3 * self.MSEloss(feat, feat_2)
+        L_hdl_feat = 0.3 * self.MSEloss(feat, feat_3)
+
+        cl_loss = F.cross_entropy(output, all_y)
+        C_scale = min(cl_loss.item(), 1.)
+        loss = cl_loss + C_scale*(lam*(L_ind_logit + L_ind_feat)+(1-lam)*(L_hdl_logit + L_hdl_feat))
+
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+
+        return {'loss': loss.item()}
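
Note on the objective (reviewer commentary, not part of the patch): update() minimizes

    loss = cl_loss + min(cl_loss, 1) * ( lam * (L_ind_logit + L_ind_feat)
                                         + (1 - lam) * (L_hdl_logit + L_hdl_feat) )

where cl_loss is the classification cross-entropy, the feature-space dissimilarity terms carry a fixed 0.3 weight, lam ~ Beta(0.5, 0.5) stochastically trades off the two in-batch dissimilarity losses (the L_ind_* terms compare against one in-class permutation, the L_hdl_* terms against a mixup of two permutations), and the min(cl_loss, 1) factor caps the regularization strength as the classification loss shrinks.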
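
The class-grouping and in-class shuffling are written as element-wise Python loops. Below is a minimal vectorized sketch of the same logic, offered as commentary rather than a requested change; the helper names are hypothetical, and the usage lines assume the same tensors (all_x, all_y, output, proj) as in the patch. Note that the patch applies one shared permutation to each logit/feature pair, which the helper preserves by permuting several tensors together.

    import torch

    def class_intervals(sorted_y):
        # End offset of each same-class run in a sorted 1-D label tensor;
        # produces the same non-empty segments as the boundary-scan loop
        # in SelfReg.update.
        _, counts = torch.unique_consecutive(sorted_y, return_counts=True)
        return torch.cumsum(counts, dim=0).tolist()

    def shuffle_within_classes(tensors, intervals):
        # Permute rows independently inside each class interval, applying
        # the SAME per-interval permutation to every tensor given (as the
        # patch does for its logit/feature pairs).
        outs = [torch.empty_like(t) for t in tensors]
        ex = 0
        for end in intervals:
            perm = torch.randperm(end - ex) + ex
            for t, out in zip(tensors, outs):
                out[ex:end] = t[perm]
            ex = end
        return outs

    # Usage, mirroring the patch:
    sorted_y, indices = torch.sort(all_y)
    sorted_x = all_x[indices]              # one gather instead of a copy loop
    intervals = class_intervals(sorted_y)
    output_2, feat_2 = shuffle_within_classes([output, proj], intervals)
    output_3, feat_3 = shuffle_within_classes([output, proj], intervals)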
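
Once the patch is applied, 'SelfReg' resolves by name through get_algorithm_class like the other entries. A typical smoke test, assuming the stock domainbed.scripts.train entry point (flag spellings follow the DomainBed README; verify against your checkout) and a placeholder data directory:

    python3 -m domainbed.scripts.train \
        --data_dir=./domainbed/data/ \
        --algorithm SelfReg \
        --dataset PACS \
        --test_env 0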