Skip to content

Commit

Permalink
add strict_load option
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Jun 9, 2019
1 parent 1c2f6bc commit 876f282
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions codes/models/SRGAN_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,11 +232,11 @@ def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG)
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
load_path_D = self.opt['path']['pretrain_model_D']
if self.opt['is_train'] and load_path_D is not None:
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD)
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step)
Expand Down
4 changes: 2 additions & 2 deletions codes/models/SRRaGAN_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,11 +243,11 @@ def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG)
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
load_path_D = self.opt['path']['pretrain_model_D']
if self.opt['is_train'] and load_path_D is not None:
logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D))
self.load_network(load_path_D, self.netD)
self.load_network(load_path_D, self.netD, self.opt['path']['strict_load'])

def save(self, iter_step):
self.save_network(self.netG, 'G', iter_step)
Expand Down
2 changes: 1 addition & 1 deletion codes/models/SR_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def load(self):
load_path_G = self.opt['path']['pretrain_model_G']
if load_path_G is not None:
logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
self.load_network(load_path_G, self.netG)
self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])

def save(self, iter_label):
self.save_network(self.netG, 'G', iter_label)
1 change: 1 addition & 0 deletions codes/options/train/train_SRResNet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ network_G:
#### path
path:
pretrain_model_G: ~
strict_load: true
resume_state: ~

#### training settings: learning rate scheme, loss
Expand Down

0 comments on commit 876f282

Please sign in to comment.