Skip to content

Commit

Permalink
fixes here and there
Browse files Browse the repository at this point in the history
  • Loading branch information
arseny committed Jun 12, 2021
1 parent 6b231fc commit dda565b
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 15 deletions.
2 changes: 1 addition & 1 deletion aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def get_transforms(size: int, scope: str = 'geometric', crop='random'):
'center': albu.CenterCrop(size, size, always_apply=True)}[crop]
pad = albu.PadIfNeeded(size, size)

pipeline = albu.Compose([aug_fn, pad, crop_fn,], additional_targets={'target': 'image'})
pipeline = albu.Compose([aug_fn, pad, crop_fn], additional_targets={'target': 'image'})

def process(a, b):
r = pipeline(image=a, target=b)
Expand Down
19 changes: 12 additions & 7 deletions config/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@ project: deblur_gan
experiment_desc: fpn

train:
# files_a: &FILES_A /datasets/my_dataset/**/*.jpg
files_a: &FILES_A /home/arseny/kp_data/train2017/*.jpg
files_a: &FILES_A /datasets/my_dataset/**/*.jpg
files_b: *FILES_A
size: &SIZE 256
crop: random
Expand All @@ -13,13 +12,20 @@ train:
bounds: [0, .9]
scope: geometric
corrupt: &CORRUPT
- name: cutout
prob: 0.5
num_holes: 3
max_h_size: 25
max_w_size: 25
- name: jpeg
quality_lower: 50
quality_lower: 70
quality_upper: 90
- name: motion_blur
- name: median_blur
- name: gamma
- name: rgb_shift
- name: hsv_shift
- name: sharpen
- name: pixelize

val:
files_a: *FILES_A
Expand All @@ -35,7 +41,7 @@ val:
phase: train
warmup_num: 3
model:
g_name: fpn_dense
g_name: fpn_inception
blocks: 9
d_name: double_gan # may be no_gan, patch_gan, double_gan, multi_scale
d_layers: 3
Expand All @@ -58,5 +64,4 @@ optimizer:
scheduler:
name: linear
start_epoch: 50
min_lr: 0.0000001

min_lr: 0.0000001
2 changes: 1 addition & 1 deletion models/fpn_densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def forward(self, x):
smoothed = nn.functional.upsample(smoothed, scale_factor=2, mode="nearest")

final = self.final(smoothed)
return nn.functional.tanh(final)
return nn.tanh(final)


class FPN(nn.Module):
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
torch==1.0.1
torch>=1.0.1
torchvision
torchsummary
pretrainedmodels
numpy
opencv-python-headless
joblib
albumentations
scikit-image
albumentations>=1.0.0
scikit-image==0.18.1
tqdm
glog
tensorboardx
Expand Down
13 changes: 10 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from models.models import get_model
from models.networks import get_nets
from schedulers import LinearDecay, WarmRestart
from fire import Fire

cv2.setNumThreads(0)

Expand Down Expand Up @@ -168,16 +169,22 @@ def _init_params(self):
self.scheduler_D = self._get_scheduler(self.optimizer_D)


if __name__ == '__main__':
with open('config/config.yaml', 'r') as f:
def main(config_path='config/config.yaml'):
with open(config_path, 'r') as f:
config = yaml.load(f)

batch_size = config.pop('batch_size')
get_dataloader = partial(DataLoader, batch_size=batch_size, num_workers=0 if os.environ['DEBUG'] else cpu_count(),
get_dataloader = partial(DataLoader,
batch_size=batch_size,
num_workers=0 if os.environ.get('DEBUG') else cpu_count(),
shuffle=True, drop_last=True)

datasets = map(config.pop, ('train', 'val'))
datasets = map(PairedDataset.from_config, datasets)
train, val = map(get_dataloader, datasets)
trainer = Trainer(config, train=train, val=val)
trainer.train()


if __name__ == '__main__':
Fire(main)

0 comments on commit dda565b

Please sign in to comment.