Commit

Add SelfReg

dnap512 authored Apr 25, 2021
1 parent 264c7d2 commit f852bef
Showing 1 changed file with 92 additions and 1 deletion.

93 changes: 92 additions & 1 deletion domainbed/algorithms.py
@@ -31,7 +31,8 @@
    'RSC',
    'SD',
    'ANDMask',
    'IGA',
    'SelfReg'
]
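
A quick sanity check that the new entry resolves through the registry lookup below (a trivial sketch, assuming the module imports cleanly):

    assert get_algorithm_class('SelfReg') is SelfReg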

def get_algorithm_class(algorithm_name):
@@ -1023,3 +1024,93 @@ def update(self, minibatches, unlabeled=False):


return {'loss': mean_loss.item(), 'penalty': penalty_value.item()}



class SelfReg(ERM):
    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(SelfReg, self).__init__(input_shape, num_classes, num_domains,
                                      hparams)
        self.num_classes = num_classes
        self.MSEloss = nn.MSELoss()
        input_feat_size = self.featurizer.n_outputs
        # Keep the hidden width for 2048-d featurizers (e.g. ResNet-50);
        # otherwise double it.
        hidden_size = input_feat_size if input_feat_size == 2048 else input_feat_size * 2

        # Projection MLP applied to featurizer outputs before the
        # feature-level regularization losses in update().
        self.cdpl = nn.Sequential(
            nn.Linear(input_feat_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, input_feat_size),
            nn.BatchNorm1d(input_feat_size)
        )

    def update(self, minibatches, unlabeled=None):
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for _, y in minibatches])

        # Mixing coefficient for the heterogeneous (mixup) pairs below.
        lam = np.random.beta(0.5, 0.5)

        batch_size = all_y.size()[0]

        # cluster and order features into same-class groups
        with torch.no_grad():
            sorted_y, indices = torch.sort(all_y)
            sorted_x = all_x[indices]
            # Record the exclusive end index of each class's block.
            intervals = []
            ex = 0
            for idx, val in enumerate(sorted_y):
                if ex == val:
                    continue
                intervals.append(idx)
                ex = val
            intervals.append(batch_size)
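            # e.g. sorted_y = [0, 0, 0, 1, 1, 2] gives intervals = [3, 5, 6]:
            # one exclusive end index per contiguous same-class block.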

        all_x = sorted_x
        all_y = sorted_y

        feat = self.featurizer(all_x)
        proj = self.cdpl(feat)

        output = self.classifier(feat)

        # shuffle: for each class block, draw two independent intra-class
        # permutations of the logits and projected features
        output_2 = torch.zeros_like(output)
        feat_2 = torch.zeros_like(proj)
        output_3 = torch.zeros_like(output)
        feat_3 = torch.zeros_like(proj)
        ex = 0
        for end in intervals:
            shuffle_indices = torch.randperm(end - ex) + ex
            shuffle_indices2 = torch.randperm(end - ex) + ex
            output_2[ex:end] = output[shuffle_indices]
            feat_2[ex:end] = proj[shuffle_indices]
            output_3[ex:end] = output[shuffle_indices2]
            feat_3[ex:end] = proj[shuffle_indices2]
            ex = end

        # mixup: interpolate the two shuffled copies
        output_3 = lam * output_2 + (1 - lam) * output_3
        feat_3 = lam * feat_2 + (1 - lam) * feat_3
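        # After mixing, (output_2, feat_2) pair each example with one
        # same-class partner, while (output_3, feat_3) blend two of them.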

        # regularization: MSE between each example and its same-class
        # partner (ind) and its mixed partner (hdl), on logits and features
        L_ind_logit = self.MSEloss(output, output_2)
        L_hdl_logit = self.MSEloss(output, output_3)
        L_ind_feat = 0.3 * self.MSEloss(feat, feat_2)
        L_hdl_feat = 0.3 * self.MSEloss(feat, feat_3)

        cl_loss = F.cross_entropy(output, all_y)
        # Cap the regularization weight at 1 and let it shrink as the
        # classification loss drops.
        C_scale = min(cl_loss.item(), 1.)
        loss = cl_loss + C_scale * (lam * (L_ind_logit + L_ind_feat)
                                    + (1 - lam) * (L_hdl_logit + L_hdl_feat))

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return {'loss': loss.item()}
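
For reviewers who want to smoke-test the new class, a minimal sketch using DomainBed's registry follows; the dataset name, input shape, and batch sizes are illustrative assumptions, not part of this commit:

    import torch
    from domainbed import algorithms, hparams_registry

    # Assumed setup: RotatedMNIST-style shapes; any registered dataset works.
    hparams = hparams_registry.default_hparams('SelfReg', 'RotatedMNIST')
    algo_cls = algorithms.get_algorithm_class('SelfReg')
    algo = algo_cls(input_shape=(1, 28, 28), num_classes=10,
                    num_domains=3, hparams=hparams)

    # One (x, y) minibatch per training domain, as the trainer supplies them.
    minibatches = [(torch.randn(16, 1, 28, 28), torch.randint(0, 10, (16,)))
                   for _ in range(3)]
    print(algo.update(minibatches))  # -> {'loss': <float>}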
