Skip to content

Commit

Permalink
成功的训练版本,跑的epoch数太少,学习率过大,效果不好。
Browse files Browse the repository at this point in the history
  • Loading branch information
Polytechwangchao committed Jun 21, 2022
1 parent e399901 commit 7fdac82
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
12 changes: 9 additions & 3 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,12 @@ project: deblur_gan
experiment_desc: fpn

train:
files_a: &FILES_A /datasets/my_dataset/**/*.jpg
# files_a: &FILES_A /datasets/my_dataset/**/*.jpg
#可以使用绝对路径
files_a: &FILES_A D:/deblur/goprol_large/**/*.png
files_b: *FILES_A
# files_a: &FILES_A ./dataset1/blur/*.png
# files_b: &FILES_B ./dataset1/sharp/*.png
size: &SIZE 256
crop: random
preload: &PRELOAD false
Expand All @@ -30,6 +34,8 @@ train:
val:
files_a: *FILES_A
files_b: *FILES_A
# files_a: &FILES_A
# files_b: &FILES_B
size: *SIZE
scope: geometric
crop: center
Expand All @@ -52,15 +58,15 @@ model:
norm_layer: instance
dropout: True

num_epochs: 200
num_epochs: 10
train_batches_per_epoch: 1000
val_batches_per_epoch: 100
batch_size: 1
image_size: [256, 256]

optimizer:
name: adam
lr: 0.0001
lr: 0.01
scheduler:
name: linear
start_epoch: 50
Expand Down
4 changes: 2 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class Predictor:
def __init__(self, weights_path: str, model_name: str = ''):
with open('config/config.yaml') as cfg:
with open('config/config.yaml',encoding='utf-8') as cfg:
config = yaml.load(cfg, Loader=yaml.FullLoader)
model = get_generator(model_name or config['model'])
model.load_state_dict(torch.load(weights_path)['model'])
Expand Down Expand Up @@ -122,4 +122,4 @@ def sorted_glob(pattern):

if __name__ == '__main__':
# Fire(main)
main('submit/t6.jpg')
main('test_img/000027.png')
7 changes: 5 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,16 @@ def _init_params(self):


def main(config_path='config/config.yaml'):
with open(config_path, 'r') as f:
with open(config_path, 'r',encoding='utf-8') as f:
config = yaml.load(f, Loader=yaml.SafeLoader)

batch_size = config.pop('batch_size')
# get_dataloader = partial(DataLoader,
# batch_size=batch_size,
# num_workers=0 if os.environ.get('DEBUG') else cpu_count(),
# shuffle=True, drop_last=True)
get_dataloader = partial(DataLoader,
batch_size=batch_size,
num_workers=0 if os.environ.get('DEBUG') else cpu_count(),
shuffle=True, drop_last=True)

datasets = map(config.pop, ('train', 'val'))
Expand Down

0 comments on commit 7fdac82

Please sign in to comment.