Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
HobbitLong committed May 31, 2020
1 parent c7b2310 commit eeb1791
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions main_supcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,9 +189,7 @@ def train(train_loader, model, criterion, optimizer, epoch, opt):
for idx, (images, labels) in enumerate(train_loader):
data_time.update(time.time() - end)

images = torch.cat([images[0].unsqueeze(1), images[1].unsqueeze(1)],
dim=1)
images = images.view(-1, 3, 32, 32).cuda(non_blocking=True)
images = torch.cat([images[0], images[1]], dim=0)
labels = labels.cuda(non_blocking=True)
bsz = labels.shape[0]

Expand All @@ -200,7 +198,8 @@ def train(train_loader, model, criterion, optimizer, epoch, opt):

# compute loss
features = model(images)
features = features.view(bsz, 2, -1)
f1, f2 = torch.split(features, [bsz, bsz], dim=0)
features = torch.cat([f1.unsqueeze(1), f2.unsqueeze(1)], dim=1)
if opt.method == 'SupCon':
loss = criterion(features, labels)
elif opt.method == 'SimCLR':
Expand Down

0 comments on commit eeb1791

Please sign in to comment.