Skip to content

Commit

Permalink
Update base_module.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Frankluox authored Dec 27, 2022
1 parent 1503028 commit 76c8dc8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion modules/base_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def shared_step(self, batch, mode):
way = getattr(self.hparams, f"{mode}_way")
logits = foward_function(batch, batch_size_per_gpu,way, shot)
label = getattr(self, f"{mode}_label")
label = torch.unsqueeze(self.label, 0).repeat(batch_size_per_gpu, 1).reshape(-1).to(logits.device)
label = torch.unsqueeze(label, 0).repeat(batch_size_per_gpu, 1).reshape(-1).to(logits.device)
logits = logits.reshape(label.size(0),-1)

loss = F.cross_entropy(logits, label)
Expand Down

0 comments on commit 76c8dc8

Please sign in to comment.