Skip to content

Commit 6e0db25

Browse files
authored
Update semi_protofeat.py
1 parent 238ffb3 commit 6e0db25

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

model/models/semi_protofeat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def get_proto(self, x_shot, x_pool):
9494
num_batch, num_shot, num_way, emb_dim = x_shot.shape
9595
num_pool_shot = x_pool.shape[1]
9696
num_pool = num_pool_shot * num_way
97-
label_support = torch.arange(self.args.way).repeat(self.args.shot).type(torch.LongTensor)
97+
label_support = torch.arange(num_way).repeat(num_shot).type(torch.LongTensor)
9898
label_support_onehot = one_hot(label_support, num_way)
9999
label_support_onehot = label_support_onehot.unsqueeze(0).repeat([num_batch, 1, 1])
100100
if torch.cuda.is_available():

0 commit comments

Comments
 (0)