Skip to content

Commit

Permalink
Bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
itewqq committed May 19, 2020
1 parent ab5fca2 commit d66d33b
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions CUB_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def __init__(self, root, split='train', transform=None):
# std = 255.
# means = [109.97 , 127.34 , 123.88 ]

# split_list = open(os.path.join(root, 'train_test_split.txt')).readlines()
split_list = open(os.path.join(root, 'tts2.txt')).readlines()
split_list = open(os.path.join(root, 'train_test_split.txt')).readlines()
# split_list = open(os.path.join(root, 'tts2.txt')).readlines()
self.idx2name = []
classes = open(os.path.join(root, 'classes.txt')).readlines()
self._imgpath = []
Expand Down
2 changes: 1 addition & 1 deletion GCN_CBLN.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def _accuracy(net, data_loader):
(t + 1, solver.param_groups[0]['lr'],sum(epoch_loss) / len(epoch_loss), train_acc, test_acc))
logger.handlers[1].flush()

torch.save(net.state_dict(),'/data/GCN_CBLN_DiffLr.pth')
torch.save(net.state_dict(),'/data/GCN_CBLN_DiffLr.pth')



Expand Down
6 changes: 3 additions & 3 deletions GCN_STN_BLN.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def main():
solver = torch.optim.SGD([
{'params': gcn_params, 'lr': 0.02},
{'params': bcn_params, 'lr': 0.02}
], lr=0.001, momentum=0.9, weight_decay=1e-8)
], lr=0.001, momentum=0.9, weight_decay=5e-4)
lrscheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
solver, mode='max', factor=0.2, patience=3, verbose=True,
threshold=1e-4)
Expand Down Expand Up @@ -430,7 +430,7 @@ def _accuracy(net, data_loader):
solver.step()


if (num_total >= cnt * 500):
if (num_total >= cnt * 2000):
cnt += 1
logger.info("Train Acc: " + str((100 * num_correct / num_total).item()) + "%" + "\n" + str(
num_correct) + " " + str(num_total) + "\n" + str(prediction) + " " + str(y.data) + "\n" + str(
Expand All @@ -449,7 +449,7 @@ def _accuracy(net, data_loader):
(t + 1, solver.param_groups[0]['lr'],sum(epoch_loss) / len(epoch_loss), train_acc, test_acc))
logger.handlers[1].flush()

torch.save(net.state_dict(),'/data/GCN_CBLN_DiffLr.pth')
torch.save(net.state_dict(),'/data/GCN_CBLN_DiffLr.pth')



Expand Down

0 comments on commit d66d33b

Please sign in to comment.