forked from Ree1s/IDM
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit f0946e9
Showing
127 changed files
with
126,733 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
<TOC> | ||
|
||
# Implicit Diffusion Models for Continuous Super-Resolution | ||
|
||
This repository is an official implementation of the paper "Implicit Diffusion Models for Continuous Super-Resolution" from CVPR 2023. | ||
|
||
This repository is still under development. | ||
|
||
|
||
## Environment configuration | ||
|
||
###### **The code requires Python 3.7+ and CUDA 11.0+. The specific configuration steps are as follows:** | ||
|
||
1. Create conda environment | ||
|
||
```shell | ||
conda create -n idm python=3.7.10 | ||
conda activate idm | ||
``` | ||
|
||
2. Install pytorch | ||
|
||
```shell | ||
conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 | ||
``` | ||
|
||
3. Install the dependencies and the package | ||
|
||
```shell | ||
pip install -r requirements.txt | ||
python setup.py develop | ||
``` | ||
## Pre-trained checkpoint | ||
|
||
The pre-trained checkpoint and the validation dataset for 8X face super-resolution can be found at the following anonymous link: [link](https://1drv.ms/u/s!AraiW_uJqO8vhnlIa-8nd0PEH4Ur?e=qDfSep). Download and unzip `checkpoint_dataset.zip`, then move `checkpoint_dataset/best_psnr_gen.pth` and `checkpoint_dataset/dataset` to `./`. | ||
|
||
## Validation | ||
Run the following command for the validation: | ||
|
||
```shell | ||
sh run.sh | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
1.4.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# https://github.com/xinntao/BasicSR | ||
# flake8: noqa | ||
from .archs import * | ||
from .data import * | ||
from .losses import * | ||
from .metrics import * | ||
from .models import * | ||
from .ops import * | ||
from .test import * | ||
from .train import * | ||
from .utils import * | ||
from .version import __gitsha__, __version__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import importlib | ||
from copy import deepcopy | ||
from os import path as osp | ||
|
||
from basicsr.utils import get_root_logger, scandir | ||
from basicsr.utils.registry import ARCH_REGISTRY | ||
|
||
__all__ = ['build_network'] | ||
|
||
# Populate the architecture registry as a side effect of importing this
# package: every file in this folder whose name ends with '_arch.py' is
# imported as 'basicsr.archs.<module name>'.
arch_folder = osp.dirname(osp.abspath(__file__))
arch_filenames = [
    osp.splitext(osp.basename(path))[0]
    for path in scandir(arch_folder)
    if path.endswith('_arch.py')
]
# Keep references to the imported modules.
_arch_modules = [
    importlib.import_module(f'basicsr.archs.{module_name}')
    for module_name in arch_filenames
]
|
||
|
||
def build_network(opt):
    """Instantiate a registered network architecture from an options dict.

    Args:
        opt (dict): Network options. It must contain the key 'type', which is
            looked up in ARCH_REGISTRY; all remaining items are forwarded to
            the registered entry as keyword arguments.

    Returns:
        The object created by calling the registered entry.
    """
    # Work on a copy so that popping 'type' never mutates the caller's dict.
    net_kwargs = deepcopy(opt)
    arch_cls = ARCH_REGISTRY.get(net_kwargs.pop('type'))
    net = arch_cls(**net_kwargs)
    get_root_logger().info(f'Network [{net.__class__.__name__}] is created.')
    return net
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
import math | ||
import torch | ||
from torch import nn as nn | ||
from torch.nn import functional as F | ||
from torch.nn import init as init | ||
from torch.nn.modules.batchnorm import _BatchNorm | ||
|
||
from basicsr.ops.dcn import ModulatedDeformConvPack, modulated_deform_conv | ||
from basicsr.utils import get_root_logger | ||
|
||
|
||
@torch.no_grad()
def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
    """Initialize network weights.

    Every ``nn.Conv2d`` and ``nn.Linear`` found in the given modules gets a
    Kaiming-normal weight that is then multiplied by ``scale``; its bias (if
    any) is filled with ``bias_fill``. Every batch-norm layer gets weight 1
    and bias ``bias_fill`` when those parameters exist.

    Args:
        module_list (list[nn.Module] | tuple[nn.Module] | nn.Module): Modules
            to be initialized. A single module or any iterable of modules is
            accepted.
        scale (float): Scale initialized weights, especially for residual
            blocks. Default: 1.
        bias_fill (float): The value to fill bias. Default: 0
        kwargs (dict): Other arguments for ``init.kaiming_normal_``.
    """
    # A lone module is wrapped; anything else is treated as an iterable of
    # modules (lists as before, but also tuples or generators).
    if isinstance(module_list, nn.Module):
        module_list = [module_list]
    for module in module_list:
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Batch-norm layers created with affine=False have neither
                # weight nor bias, so both must be checked before use.
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.data.fill_(bias_fill)
|
||
|
||
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly constructed copies of one block type.

    Args:
        basic_block (nn.module): nn.module class (or factory) for the block.
        num_basic_block (int): How many blocks to create.
        **kwarg: Keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: The blocks chained in creation order.
    """
    # Each iteration calls the constructor again, so no parameters are shared.
    blocks = [basic_block(**kwarg) for _ in range(num_basic_block)]
    return nn.Sequential(*blocks)
|
||
|
||
class ResidualBlockNoBN(nn.Module):
    """Two-convolution residual block with no batch normalization.

    Data flow:
        x ---Conv-ReLU-Conv--(* res_scale)--+--> output
          |_________________________________|

    Args:
        num_feat (int): Channel number of the input, intermediate and output
            features. Default: 64.
        res_scale (float): Factor applied to the residual branch before it is
            added to the input. Default: 1.
        pytorch_init (bool): If True, keep PyTorch's default initialization;
            otherwise both convolutions are re-initialized through
            default_init_weights with scale 0.1. Default: False.
    """

    def __init__(self, num_feat=64, res_scale=1, pytorch_init=False):
        super().__init__()
        self.res_scale = res_scale
        # 3x3 convolutions, stride 1, padding 1: spatial size is preserved.
        self.conv1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.conv2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=True)
        self.relu = nn.ReLU(inplace=True)

        if pytorch_init:
            return
        default_init_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        """Return ``x`` plus the scaled output of the Conv-ReLU-Conv branch."""
        residual = self.conv1(x)
        # In-place ReLU acts on the conv1 output, never on the input ``x``.
        residual = self.relu(residual)
        residual = self.conv2(residual)
        return x + residual * self.res_scale
|
||
|
||
class Upsample(nn.Sequential):
    """Upsample module built from convolution + pixel-shuffle stages.

    For a power-of-two scale, log2(scale) stages of
    ``Conv2d(num_feat -> 4 * num_feat)`` + ``PixelShuffle(2)`` are stacked
    (scale 1 therefore yields an empty, identity-like sequence). For scale 3,
    a single ``Conv2d(num_feat -> 9 * num_feat)`` + ``PixelShuffle(3)`` stage
    is used.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    Raises:
        ValueError: If ``scale`` is neither a positive power of two nor 3.
    """

    def __init__(self, scale, num_feat):
        layers = []
        # ``scale > 0`` matters: 0 also satisfies the bit trick below and
        # would otherwise fail later with an unrelated math error.
        if scale > 0 and (scale & (scale - 1)) == 0:  # scale = 2^n
            # bit_length() - 1 is the exact integer log2 of a power of two,
            # avoiding floating-point rounding of math.log.
            for _ in range(int(scale).bit_length() - 1):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
        super().__init__(*layers)
|
||
|
||
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Each output location (y, x) is sampled from the input at
    (y + flow[..., 1], x + flow[..., 0]) using ``F.grid_sample``.

    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2) holding per-pixel
            displacements in pixel units; the last dimension is ordered
            (horizontal, vertical).
        interp_mode (str): 'nearest' or 'bilinear'. Default: 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Before pytorch 1.3, the default value is
            align_corners=True. After pytorch 1.3, the default value is
            align_corners=False. Here, we use the True as default.

    Returns:
        Tensor: Warped image or feature map.
    """
    assert x.size()[-2:] == flow.size()[1:3], 'x and flow must have the same spatial size'
    _, _, h, w = x.size()
    # Build the base pixel grid directly on x's device. indexing='ij' is the
    # historical default behaviour, stated explicitly because calling
    # torch.meshgrid without it is deprecated.
    rows = torch.arange(0, h, device=x.device).type_as(x)
    cols = torch.arange(0, w, device=x.device).type_as(x)
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    grid = torch.stack((grid_x, grid_y), 2).float()  # (h, w, 2), last dim is (x, y)

    # Absolute sampling positions; the grid broadcasts over the batch dim.
    vgrid = grid + flow
    # Scale to [-1, 1] as expected by grid_sample. Mapping pixel 0 to -1 and
    # pixel (size - 1) to +1 matches the align_corners=True convention.
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(w - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(h - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode, align_corners=align_corners)

    # TODO, what if align_corners=False
    return output
|
||
|
||
def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=False):
    """Resize a flow according to ratio or shape.

    Besides resampling the flow map, the flow values themselves are rescaled
    (channel 0 by the width ratio, channel 1 by the height ratio) so that the
    displacements stay valid at the new resolution.

    Args:
        flow (Tensor): Precomputed flow. shape [N, 2, H, W].
        size_type (str): 'ratio' or 'shape'.
        sizes (list[int | float]): the ratio for resizing or the final output
            shape.
            1) The order of ratio should be [ratio_h, ratio_w]. For
            downsampling, the ratio should be smaller than 1.0 (i.e., ratio
            < 1.0). For upsampling, the ratio should be larger than 1.0 (i.e.,
            ratio > 1.0).
            2) The order of output_size should be [out_h, out_w].
        interp_mode (str): The mode of interpolation for resizing.
            Default: 'bilinear'.
        align_corners (bool): Whether align corners. Only forwarded for the
            interpolating modes that accept it; ignored otherwise.
            Default: False.

    Returns:
        Tensor: Resized flow.

    Raises:
        ValueError: If ``size_type`` is neither 'ratio' nor 'shape'.
    """
    _, _, flow_h, flow_w = flow.size()
    if size_type == 'ratio':
        output_h, output_w = int(flow_h * sizes[0]), int(flow_w * sizes[1])
    elif size_type == 'shape':
        output_h, output_w = sizes[0], sizes[1]
    else:
        raise ValueError(f'Size type should be ratio or shape, but got type {size_type}.')

    # F.interpolate raises if align_corners is given for a mode that does not
    # interpolate (e.g. 'nearest', 'area'), so drop it for those modes.
    if interp_mode not in ('linear', 'bilinear', 'bicubic', 'trilinear'):
        align_corners = None

    # Clone first so the in-place scaling never touches the caller's tensor.
    input_flow = flow.clone()
    # Ratios are recomputed from the final integer size, not taken from
    # ``sizes``, so they reflect the truncation done in 'ratio' mode.
    ratio_h = output_h / flow_h
    ratio_w = output_w / flow_w
    input_flow[:, 0, :, :] *= ratio_w
    input_flow[:, 1, :, :] *= ratio_h
    resized_flow = F.interpolate(
        input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
    return resized_flow
|
||
|
||
# TODO: may write a cpp file | ||
def pixel_unshuffle(x, scale):
    """Fold spatial blocks of size ``scale`` x ``scale`` into channels.

    Args:
        x (Tensor): Input feature with shape (b, c, hh, hw); both hh and hw
            must be divisible by ``scale``.
        scale (int): Downsample ratio.

    Returns:
        Tensor: Feature with shape (b, c * scale^2, hh / scale, hw / scale).
    """
    batch, channels, in_h, in_w = x.size()
    assert in_h % scale == 0 and in_w % scale == 0
    out_h, out_w = in_h // scale, in_w // scale
    # Split each spatial axis into (coarse position, offset inside the block).
    blocks = x.view(batch, channels, out_h, scale, out_w, scale)
    # Move both in-block offsets next to the channel axis before merging.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4)
    return blocks.reshape(batch, channels * (scale**2), out_h, out_w)
|
||
|
||
class DCNv2Pack(ModulatedDeformConvPack):
    """Modulated deformable conv for deformable alignment.

    Unlike the official DCNv2Pack, which derives offsets and masks from the
    features being convolved, this variant derives them from a separate
    feature tensor passed to ``forward``.

    Ref:
        Delving Deep into Deformable Alignment in Video Super-Resolution.
    """

    def forward(self, x, feat):
        """Apply the deformable convolution to ``x``, guided by ``feat``.

        Args:
            x: Input passed to ``modulated_deform_conv``.
            feat: Input of ``self.conv_offset``, from which the offsets and
                the modulation mask are predicted.

        Returns:
            The result of ``modulated_deform_conv``.
        """
        # conv_offset output is split channel-wise into three equal parts:
        # the first two form the offsets, the third is the raw mask.
        offset_a, offset_b, raw_mask = torch.chunk(self.conv_offset(feat), 3, dim=1)
        offset = torch.cat((offset_a, offset_b), dim=1)
        mask = torch.sigmoid(raw_mask)

        # Large average offsets are only reported, never clipped.
        mean_abs_offset = offset.abs().mean()
        if mean_abs_offset > 50:
            get_root_logger().warning(f'Offset abs mean is {mean_abs_offset}, larger than 50.')

        return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
                                     self.dilation, self.groups, self.deformable_groups)
Oops, something went wrong.