Skip to content

Commit

Permalink
resnet50
Browse files Browse the repository at this point in the history
  • Loading branch information
sunbing7 committed Jun 14, 2023
1 parent 47cb34b commit b12370f
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions src_torch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def moth():
nesterov=True)

# load data
train_loader = get_dataloader(args.dataset, True, args.data_ratio)
test_loader = get_dataloader(args.dataset, False, 0.05)
train_loader = get_dataloader(args.dataset, True, args.data_ratio, batch_size=args.batch_size)
test_loader = get_dataloader(args.dataset, False, 0.05, batch_size=args.batch_size)

# a subset for loss calculation during warmup
for idx, (x_batch, y_batch) in enumerate(train_loader):
Expand Down Expand Up @@ -420,6 +420,7 @@ def moth():

# train model
optimizer.zero_grad()

output = model(x_batch)
loss = criterion(output, y_batch.to(args.device))
loss.backward()
Expand Down Expand Up @@ -479,7 +480,7 @@ def test():
model.to(args.device)
model.eval()

test_loader = get_dataloader(args.dataset, False)
test_loader = get_dataloader(args.dataset, False, batch_size=args.batch_size)

total = 0
correct = 0
Expand Down Expand Up @@ -583,7 +584,7 @@ def validate():
model.eval()

# load data
test_loader = get_dataloader(args.dataset, False)
test_loader = get_dataloader(args.dataset, False, batch_size=args.batch_size)
num_classes = get_num(args.dataset)

for batch_idx, (inputs, targets) in enumerate(test_loader):
Expand Down

0 comments on commit b12370f

Please sign in to comment.