Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
misads committed Dec 23, 2020
1 parent a447578 commit e2a7e6a
Show file tree
Hide file tree
Showing 8 changed files with 30 additions and 1,191 deletions.
13 changes: 0 additions & 13 deletions .idea/AliProducts.iml

This file was deleted.

25 changes: 0 additions & 25 deletions .idea/misc.xml

This file was deleted.

8 changes: 0 additions & 8 deletions .idea/modules.xml

This file was deleted.

6 changes: 0 additions & 6 deletions .idea/vcs.xml

This file was deleted.

1,032 changes: 0 additions & 1,032 deletions .idea/workspace.xml

This file was deleted.

2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ AliProducts
② 训练模型

```bash
CUDA_VISIBLE_DEVICES=0 python train.py --tag resnest --model ResNeSt101 --epochs 20 -b 24 --lr 0.0001 # --tag用于区分每次实验,可以是任意字符串
CUDA_VISIBLE_DEVICES=0 python train.py --tag resnest --model ResNeSt101 --scheduler 2x -b 24 --lr 0.0001 # --tag用于区分每次实验,可以是任意字符串
```

  训练的中途可以在验证集上验证,添加`--val_freq 10`参数可以指定10个epoch验证一次,添加`--save_freq 10`参数可以指定10个epoch保存一次checkpoint。
Expand Down
63 changes: 7 additions & 56 deletions network/ResNeSt/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,6 @@

# criterionCE = nn.CrossEntropyLoss()


def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm2d') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)


class Model(BaseModel):
def __init__(self, opt):
super(Model, self).__init__()
Expand All @@ -48,35 +38,20 @@ def __init__(self, opt):
self.optimizer = get_optimizer(opt, self.classifier)
self.scheduler = get_scheduler(opt, self.optimizer)

# load networks
# if opt.load:
# pretrained_path = opt.load
# self.load_network(self.classifier, 'G', opt.which_epoch, pretrained_path)
# if self.training:
# self.load_network(self.discriminitor, 'D', opt.which_epoch, pretrained_path)

self.avg_meters = ExponentialMovingAverage(0.95)
self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

# with open('datasets/class_weight.pkl', 'rb') as f:
# class_weight = pickle.load(f, encoding='bytes')
# class_weight = np.array(class_weight, dtype=np.float32)
# class_weight = torch.from_numpy(class_weight).to(opt.device)
# if opt.class_weight:
# self.criterionCE = nn.CrossEntropyLoss(weight=class_weight)
# else:
self.criterionCE = nn.CrossEntropyLoss()

def update(self, input, label):

# loss_ce = self.criterionCE(predicted, label)
# loss_ce = label_smooth_loss(predicted, label)
# loss = loss_ce
predicted = self.classifier(input)
loss_ce = label_smooth_loss(predicted, label)
loss = loss_ce
# smooth_loss = label_smooth_loss(predicted, label)
ce_loss = criterionCE(predicted, label)

loss = ce_loss

self.avg_meters.update({'CE loss(label smooth)': loss_ce.item()})
self.avg_meters.update({'CE loss': ce_loss.item()})

self.optimizer.zero_grad()
loss.backward()
Expand All @@ -88,31 +63,7 @@ def forward(self, x):
return self.classifier(x)

def load(self, ckpt_path):
load_dict = torch.load(ckpt_path, map_location=opt.device)
self.classifier.load_state_dict(load_dict['classifier'])
if opt.resume:
self.optimizer.load_state_dict(load_dict['optimizer'])
self.scheduler.load_state_dict(load_dict['scheduler'])
epoch = load_dict['epoch']
utils.color_print('Load checkpoint from %s, resume training.' % ckpt_path, 3)
else:
epoch = load_dict['epoch']
utils.color_print('Load checkpoint from %s.' % ckpt_path, 3)

return epoch
return super(Model, self).load(ckpt_path)

def save(self, which_epoch):
# self.save_network(self.classifier, 'G', which_epoch)
save_filename = f'{which_epoch}_{opt.model}.pt'
save_path = os.path.join(self.save_dir, save_filename)
save_dict = OrderedDict()
save_dict['classifier'] = self.classifier.state_dict()
# save_dict['discriminitor'] = self.discriminitor.state_dict()
save_dict['optimizer'] = self.optimizer.state_dict()
save_dict['scheduler'] = self.scheduler.state_dict()
save_dict['epoch'] = which_epoch
torch.save(save_dict, save_path)
utils.color_print(f'Save checkpoint "{save_path}".', 3)

# self.save_network(self.discriminitor, 'D', which_epoch)

super(Model, self).save(which_epoch)
72 changes: 22 additions & 50 deletions network/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,61 +15,33 @@ def __init__(self):
def forward(self, x):
pass

@abstractmethod
def load(self, ckpt_path):
pass

@abstractmethod
def save(self, which_epoch):
pass

@abstractmethod
def update(self, *args, **kwargs):
pass

# helper saving function that can be used by subclasses
def save_network(self, network, network_label, epoch_label):
save_filename = '%s_net_%s.pt' % (epoch_label, network_label)
save_path = os.path.join(self.save_dir, save_filename)
torch.save(network.state_dict(), save_path)

# helper loading function that can be used by subclasses
def load_network(self, network, network_label, epoch_label, save_dir=''):
save_filename = '%s_net_%s.pt' % (epoch_label, network_label)
if not save_dir:
save_dir = self.save_dir
save_path = os.path.join(save_dir, save_filename)
if not os.path.isfile(save_path):
color_print("Exception: Checkpoint '%s' not found" % save_path, 1)
if network_label == 'G':
raise Exception("Generator must exist!,file '%s' not found" % save_path)
def load(self, ckpt_path):
load_dict = torch.load(ckpt_path, map_location=opt.device)
self.classifier.load_state_dict(load_dict['classifier'])
if opt.resume:
self.optimizer.load_state_dict(load_dict['optimizer'])
self.scheduler.load_state_dict(load_dict['scheduler'])
epoch = load_dict['epoch']
utils.color_print('Load checkpoint from %s, resume training.' % ckpt_path, 3)
else:
# network.load_state_dict(torch.load(save_path))
try:
network.load_state_dict(torch.load(save_path, map_location=opt.device))
color_print('Load checkpoint from %s.' % save_path, 3)

except:
pretrained_dict = torch.load(save_path, map_location=opt.device)
model_dict = network.state_dict()
try:
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
network.load_state_dict(pretrained_dict)
if self.opt.verbose:
print(
'Pretrained network %s has excessive layers; Only loading layers that are used' % network_label)
except:
print('Pretrained network %s has fewer layers; The following are not initialized:' % network_label)
for k, v in pretrained_dict.items():
if v.size() == model_dict[k].size():
model_dict[k] = v
epoch = load_dict['epoch']
utils.color_print('Load checkpoint from %s.' % ckpt_path, 3)

not_initialized = set()
return epoch

for k, v in model_dict.items():
if k not in pretrained_dict or v.size() != pretrained_dict[k].size():
not_initialized.add(k.split('.')[0])

print(sorted(not_initialized))
network.load_state_dict(model_dict)
def save(self, which_epoch):
save_filename = f'{which_epoch}_{opt.model}.pt'
save_path = os.path.join(self.save_dir, save_filename)
save_dict = OrderedDict()
save_dict['classifier'] = self.classifier.state_dict()

save_dict['optimizer'] = self.optimizer.state_dict()
save_dict['scheduler'] = self.scheduler.state_dict()
save_dict['epoch'] = which_epoch
torch.save(save_dict, save_path)
utils.color_print(f'Save checkpoint "{save_path}".', 3)

0 comments on commit e2a7e6a

Please sign in to comment.