Skip to content

Commit

Permalink
Merge branch 'master' into patch-3
Browse files Browse the repository at this point in the history
  • Loading branch information
eriklindernoren authored Apr 26, 2019
2 parents e4701b5 + d01fa6b commit 22d8b45
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
7 changes: 7 additions & 0 deletions implementations/esrgan/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@
std = np.array([0.229, 0.224, 0.225])


def denormalize(tensors):
""" Denormalizes image tensors using mean and std """
for c in range(3):
tensors[:, c].mul_(std[c]).add_(mean[c])
return torch.clamp(tensors, 0, 255)


class ImageDataset(Dataset):
def __init__(self, root, hr_shape):
hr_height, hr_width = hr_shape
Expand Down
5 changes: 2 additions & 3 deletions implementations/esrgan/esrgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@

if opt.epoch != 0:
# Load pretrained models
print("loading pretrain model")
generator.load_state_dict(torch.load("saved_models/generator_%d.pth"%opt.epoch))
discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth"%opt.epoch))
generator.load_state_dict(torch.load("saved_models/generator_%d.pth" % opt.epoch))
discriminator.load_state_dict(torch.load("saved_models/discriminator_%d.pth" % opt.epoch))

# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
Expand Down
1 change: 1 addition & 0 deletions implementations/esrgan/test_on_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch.autograd import Variable
import argparse
import os
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image

Expand Down

0 comments on commit 22d8b45

Please sign in to comment.