Skip to content

Commit

Permalink
print
Browse files Browse the repository at this point in the history
  • Loading branch information
alicebizeul committed Apr 13, 2023
1 parent 95ed917 commit b88d7aa
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions simclr/modules/custom_infonce.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ class Custom_InfoNCE(nn.Module):
def __init__(self, batch_size, bound, simclr_compatibility,subsample):
super(Custom_InfoNCE, self).__init__()
self.criterion = nn.CrossEntropyLoss(reduction="sum")
self.similarity_f = lambda x1, x2: custom_similarity(x1.unsqueeze(-1),torch.transpose(x2,0,1).unsqueeze(1),bound,subsample)
self.similarity_f = lambda x1, x2: custom_similarity(x1.unsqueeze(-1),torch.transpose(x2.unsqueeze(-1),0,-1),bound,subsample)
self.simclr_compatibility=simclr_compatibility
self.symetric=True
self.subsample=subsample
self.bound = bound

def forward(self, anchor_rec, positive_rec):

print("Original inputs",anchor_rec.shape,positive_rec.shape,torch.matmul(anchor_rec.unsqueeze(-1),positive_rec.unsqueeze(-2)).shape)
print("Original inputs",anchor_rec.shape,positive_rec.shape,torch.matmul(anchor_rec.unsqueeze(-1),torch.transpose(positive_rec.unsqueeze(-1),0,-1)).shape)

sim11 = self.similarity_f(anchor_rec,anchor_rec)
sim22 = self.similarity_f(positive_rec,positive_rec)
Expand Down Expand Up @@ -85,7 +85,7 @@ def custom_similarity(p_z_zrec,p_zpos_zrecpos,bound,subsample):
keep=random.shuffle(list(np.range(p_z_zrec.shape[0])))
p_z_zrec = p_z_zrec[keep,:,:]
p_zpos_zrecpos = p_zpos_zrecpos[keep,:,:]
else: return torch.log(torch.sum(torch.matmul(p_z_zrec,p_zpos_zrecpos),dim=-1)) # log because cross entropy adds an exp
else: return torch.matmul(p_z_zrec,p_zpos_zrecpos)
else: return torch.log(torch.sum(p_z_zrec*p_zpos_zrecpos,dim=-1)) # log because cross entropy adds an exp
else: return p_z_zrec*p_zpos_zrecpos


0 comments on commit b88d7aa

Please sign in to comment.