Skip to content

Commit

Permalink
Update DSN.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Yang-Bob authored May 17, 2022
1 parent b8bdbcd commit 5bc87b9
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions models/DSN.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,7 @@ def forward(self, x, sess=0, epoch=0, Mode='train', IOF='image'):
x = self.get_feature(x)
else:
x = self.get_feature(x)

if sess > 0:
with torch.no_grad():
out1 = self.fc1(x)
else:
out1 = self.fc1(x)
out1 = self.fc1(x)
out = self._l2norm(out1, dim=1)
for i in range(sess + 1):
if i == 0:
Expand All @@ -79,8 +74,7 @@ def forward(self, x, sess=0, epoch=0, Mode='train', IOF='image'):
out_aux = out_aux * self.alpha
else:
out_aux = out_aux * self.Alpha[i]

new_node = out1 * self.gamma + out_aux
new_node = out * self.gamma + out_aux
new_node = self._l2norm(new_node, dim=1)
output = torch.cat([output, F.linear(F.normalize(new_node, p=2, dim=-1), F.normalize(fc.weight, p=2, dim=-1))], dim=1) # +out_aux

Expand All @@ -102,7 +96,7 @@ def get_loss(self, pred, label, output_old=None, logits=None, compression=True):
loss_dis = self.distillation_loss(pred, output_old)
if self.sess > 0 and compression:
R1 = torch.sum(nn.ReLU()(torch.norm(self.alpha, p=1, dim=0) / self.node - self.r))
return loss_bce_seg + loss_dis + 0.1 * R1
return loss_bce_seg + 1.0*loss_dis + 0.1 * R1

def distillation_loss(self, pred_N, pred_O, T=0.5):
if pred_N.shape[1] != pred_O.shape[1]:
Expand Down

0 comments on commit 5bc87b9

Please sign in to comment.