Commit: support 768 model

IceClear committed Jun 28, 2023
1 parent b124eb8 commit 4a0bc32
Showing 7 changed files with 295 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
logs/*
models/*
src/
+results/
wandb/

*.DS_Store
5 changes: 4 additions & 1 deletion README.md
@@ -19,8 +19,9 @@ S-Lab, Nanyang Technological University
:star: If StableSR is helpful to your images or projects, please help star this repo. Thanks! :hugs:

### Update
+- **2023.06.28**: Support training on SD-2.1-768v.
- **2023.05.22**: :whale: Improve the code to save more GPU memory, now 128 --> 512 needs 8.9G. Enable start from intermediate steps.
-- **2023.05.20**: The [WebUI Demo](https://github.com/pkuliyi2015/sd-webui-stablesr) of StableSR is avaliable. Thank [Li Yi](https://github.com/pkuliyi2015) for the implementation!
+- **2023.05.20**: :whale: The [**WebUI Demo**](https://github.com/pkuliyi2015/sd-webui-stablesr) [![GitHub Stars](https://img.shields.io/github/stars/pkuliyi2015/sd-webui-stablesr?style=social)](https://github.com/pkuliyi2015/sd-webui-stablesr) of StableSR is available. Thanks to [Li Yi](https://github.com/pkuliyi2015) for the implementation!
- **2023.05.13**: Add Colab demo of StableSR. <a href="https://colab.research.google.com/drive/11SE2_oDvbYtcuHDbaLAxsKk_o3flsO1T?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>
- **2023.05.11**: Repo is released.

@@ -137,6 +138,8 @@ Note the min tile size is 512 and stride should be smaller than tile size. Small
python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py --config configs/stableSRNew/v2-finetune_text_T_512.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain
```

+For testing with the 768 model, you need to set ```--input_size 768```. You can also adjust ```--tile_overlap```, ```--vqgantile_size``` and ```--vqgantile_stride``` accordingly.
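
A minimal example invocation (checkpoint and I/O paths are placeholders, as above; the 768 config ships with this commit):

```
python scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas_tile.py --config configs/stableSRNew/v2-finetune_text_T_768v.yaml --ckpt CKPT_PATH --vqgan_ckpt VQGANCKPT_PATH --init-img INPUT_PATH --outdir OUT_DIR --ddpm_steps 200 --dec_w 0.5 --colorfix_type adain --input_size 768
```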

### Citation
If our work is useful for your research, please consider citing:

12 changes: 10 additions & 2 deletions basicsr/data/realesrgan_dataset.py
@@ -66,8 +66,16 @@ def __init__(self, opt):
            for class_file in class_list:
                self.paths.extend(sorted([str(x) for x in Path(os.path.join(opt['imagenet_path'], class_file)).glob('*.'+'JPEG')]))
        if 'face_gt_path' in opt:
-            face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])])
-            self.paths.extend(face_list[:opt['num_face']])
+            if isinstance(opt['face_gt_path'], str):
+                face_list = sorted([str(x) for x in Path(opt['face_gt_path']).glob('*.'+opt['image_type'])])
+                self.paths.extend(face_list[:opt['num_face']])
+            else:
+                face_list = sorted([str(x) for x in Path(opt['face_gt_path'][0]).glob('*.'+opt['image_type'])])
+                self.paths.extend(face_list[:opt['num_face']])
+                if len(opt['face_gt_path']) > 1:
+                    for i in range(len(opt['face_gt_path'])-1):
+                        # glob each remaining face folder in turn
+                        self.paths.extend(sorted([str(x) for x in Path(opt['face_gt_path'][i+1]).glob('*.'+opt['image_type'])])[:opt['num_face']])

        # limit number of pictures for test
        if 'num_pic' in opt:
            if 'val' in opt or 'test' in opt:
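
The branch above lets `face_gt_path` be either a single folder (str) or a list of folders, taking up to `num_face` images from each. A standalone sketch of that behavior (function name and defaults are illustrative, not part of the repo):

```python
from pathlib import Path

def gather_face_paths(face_gt_path, image_type='png', num_face=5000):
    # accept one folder (str) or several (list); cap each folder at num_face images
    folders = [face_gt_path] if isinstance(face_gt_path, str) else list(face_gt_path)
    paths = []
    for folder in folders:
        paths.extend(sorted(str(x) for x in Path(folder).glob('*.' + image_type))[:num_face])
    return paths

# e.g. gather_face_paths(['FFHQ/1024/', 'FFHQ/ffhq_wild/'], image_type='png')
```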
247 changes: 247 additions & 0 deletions configs/stableSRNew/v2-finetune_text_T_768v.yaml
@@ -0,0 +1,247 @@
sf: 4
model:
  base_learning_rate: 5.0e-05
  target: ldm.models.diffusion.ddpm.LatentDiffusionSRTextWT
  params:
    parameterization: "v"
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: image
    cond_stage_key: caption
    image_size: 768
    channels: 4
    cond_stage_trainable: False   # Note: different from the one we trained before
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    # for training only
    # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_768-ema-pruned.ckpt
    unfrozen_diff: False
    random_size: False
    time_replace: 1000
    use_usm: False
    # P2 weighting, we do not use in final version
    p2_gamma: ~
    p2_k: ~
    # ignore_keys: []

    unet_config:
      target: ldm.modules.diffusionmodules.openaimodel.UNetModelDualcondV2
      params:
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        use_checkpoint: False
        legacy: False
        semb_channels: 256

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        # for training only
        # ckpt_path: /mnt/lustre/jywang/code/stable_diffmodels/v2-1_768-ema-pruned.ckpt
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          double_z: true
          z_channels: 4
          resolution: 768
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"

    structcond_stage_config:
      target: ldm.modules.diffusionmodules.openaimodel.EncoderUNetModelWT
      params:
        image_size: 96
        in_channels: 4
        model_channels: 256
        out_channels: 256
        num_res_blocks: 2
        attention_resolutions: [ 4, 2, 1 ]
        dropout: 0
        channel_mult: [ 1, 1, 2, 2 ]
        conv_resample: True
        dims: 2
        use_checkpoint: False
        use_fp16: False
        num_heads: 4
        num_head_channels: -1
        num_heads_upsample: -1
        use_scale_shift_norm: False
        resblock_updown: False
        use_new_attention_order: False


degradation:
  # the first degradation process
  resize_prob: [0.2, 0.7, 0.1]  # up, down, keep
  resize_range: [0.3, 1.5]
  gaussian_noise_prob: 0.5
  noise_range: [1, 15]
  poisson_scale_range: [0.05, 2.0]
  gray_noise_prob: 0.4
  jpeg_range: [60, 95]

  # the second degradation process
  second_blur_prob: 0.5
  resize_prob2: [0.3, 0.4, 0.3]  # up, down, keep
  resize_range2: [0.6, 1.2]
  gaussian_noise_prob2: 0.5
  noise_range2: [1, 12]
  poisson_scale_range2: [0.05, 1.0]
  gray_noise_prob2: 0.4
  jpeg_range2: [60, 95]

  gt_size: 768
  no_degradation_prob: 0

data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 3
    num_workers: 6
    wrap: false
    train:
      target: basicsr.data.realesrgan_dataset.RealESRGANDataset
      params:
        queue_size: 180
        gt_path: ['/mnt/lustre/share/jywang/dataset/DIV8K/train_HR/', '/mnt/lustre/share/jywang/dataset/df2k_ost/GT/']
        face_gt_path: ['/mnt/lustre/share/jywang/dataset/FFHQ/1024/', '/mnt/lustre/share/jywang/dataset/FFHQ/ffhq_wild/']
        num_face: 5000
        crop_size: 768
        io_backend:
          type: disk

        blur_kernel_size: 21
        kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        sinc_prob: 0.1
        blur_sigma: [0.2, 1.5]
        betag_range: [0.5, 2.0]
        betap_range: [1, 1.5]

        blur_kernel_size2: 11
        kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        sinc_prob2: 0.1
        blur_sigma2: [0.2, 1.0]
        betag_range2: [0.5, 2.0]
        betap_range2: [1, 1.5]

        final_sinc_prob: 0.8

        gt_size: 768
        use_hflip: True
        use_rot: False
    validation:
      target: basicsr.data.realesrgan_dataset.RealESRGANDataset
      params:
        gt_path: /mnt/lustre/share/jywang/dataset/ImageSR/DIV2K/DIV2K_train_HR/
        crop_size: 768
        io_backend:
          type: disk

        blur_kernel_size: 21
        kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        sinc_prob: 0.1
        blur_sigma: [0.2, 1.5]
        betag_range: [0.5, 2.0]
        betap_range: [1, 1.5]

        blur_kernel_size2: 11
        kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        sinc_prob2: 0.1
        blur_sigma2: [0.2, 1.0]
        betag_range2: [0.5, 2.0]
        betap_range2: [1, 1.5]

        final_sinc_prob: 0.8

        gt_size: 768
        use_hflip: True
        use_rot: False

test_data:
  target: main.DataModuleFromConfig
  params:
    batch_size: 1
    num_workers: 6
    wrap: false
    test:
      target: basicsr.data.realesrgan_dataset.RealESRGANDataset
      params:
        gt_path: ['/mnt/lustre/jywang/dataset/ImageSR/Set5/HR/', '/mnt/lustre/jywang/dataset/ImageSR/Set14/HR/']
        crop_size: 768
        io_backend:
          type: disk

        blur_kernel_size: 21
        kernel_list: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        kernel_prob: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        sinc_prob: 0.1
        blur_sigma: [0.2, 1.5]
        betag_range: [0.5, 2.0]
        betap_range: [1, 1.5]

        blur_kernel_size2: 11
        kernel_list2: ['iso', 'aniso', 'generalized_iso', 'generalized_aniso', 'plateau_iso', 'plateau_aniso']
        kernel_prob2: [0.45, 0.25, 0.12, 0.03, 0.12, 0.03]
        sinc_prob2: 0.1
        blur_sigma2: [0.2, 1.0]
        betag_range2: [0.5, 2.0]
        betap_range2: [1, 1.5]

        final_sinc_prob: 0.8

        gt_size: 768
        use_hflip: True
        use_rot: False

lightning:
  modelcheckpoint:
    params:
      every_n_train_steps: 1000
  callbacks:
    image_logger:
      target: main.ImageLogger
      params:
        batch_frequency: 1000
        max_images: 2
        increase_log_steps: False

  trainer:
    benchmark: True
    max_steps: 800000
    accumulate_grad_batches: 4
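
This config plugs into the repo's standard training entry point. A hedged launch sketch (the flags follow the usual latent-diffusion `main.py` interface and the run name is arbitrary; verify against the README's training instructions):

```
python main.py --train --base configs/stableSRNew/v2-finetune_text_T_768v.yaml --gpus 0, --name stablesr_768v --scale_lr False
```

With `batch_size: 3` and `accumulate_grad_batches: 4`, the effective batch size is 12 per GPU.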
18 changes: 18 additions & 0 deletions ldm/models/diffusion/ddpm.py
@@ -355,6 +355,8 @@ def p_mean_variance(self, x, t, clip_denoised: bool):
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
+        elif self.parameterization == "v":
+            x_recon = self.predict_start_from_z_and_v(x, model_out, t)
        if clip_denoised:
            x_recon.clamp_(-1., 1.)

@@ -408,6 +410,12 @@ def get_v(self, x, noise, t):
            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * x
        )

+    def predict_start_from_z_and_v(self, x, v, t):
+        return (
+            extract_into_tensor(self.sqrt_alphas_cumprod, t, x.shape) * x -
+            extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x.shape) * v
+        )

    def get_loss(self, pred, target, mean=True):
        if self.loss_type == 'l1':
            loss = (target - pred).abs()
@@ -433,6 +441,8 @@ def p_losses(self, x_start, t, noise=None):
            target = noise
        elif self.parameterization == "x0":
            target = x_start
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError(f"Parameterization {self.parameterization} not yet supported")

@@ -1172,6 +1182,8 @@ def p_losses(self, x_start, cond, t, noise=None):
            target = x_start
        elif self.parameterization == "eps":
            target = noise
+        elif self.parameterization == "v":
+            target = self.get_v(x_start, noise, t)
        else:
            raise NotImplementedError()

@@ -1211,6 +1223,8 @@ def p_mean_variance(self, x, c, t, clip_denoised: bool, return_codebook_ids=Fals
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
+        elif self.parameterization == "v":
+            x_recon = self.predict_start_from_z_and_v(x, model_out, t)
        else:
            raise NotImplementedError()

@@ -2504,6 +2518,8 @@ def p_mean_variance(self, x, c, struct_cond, t, clip_denoised: bool, return_code
            x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
+        elif self.parameterization == "v":
+            x_recon = self.predict_start_from_z_and_v(x, model_out, t)
        else:
            raise NotImplementedError()

@@ -2634,6 +2650,8 @@ def p_mean_variance_canvas(self, x, c, struct_cond, t, clip_denoised: bool, retu
            x_recon = self.predict_start_from_noise(x, t=t[:model_out.size(0)], noise=model_out)
        elif self.parameterization == "x0":
            x_recon = model_out
+        elif self.parameterization == "v":
+            x_recon = self.predict_start_from_z_and_v(x, model_out, t[:model_out.size(0)])
        else:
            raise NotImplementedError()

6 changes: 5 additions & 1 deletion scripts/sr_val_ddpm_text_T_vqganfin_oldcanvas.py
@@ -319,7 +319,7 @@ def main():
                x_T = model.q_sample_respace(x_start=init_latent, t=t, sqrt_alphas_cumprod=sqrt_alphas_cumprod, sqrt_one_minus_alphas_cumprod=sqrt_one_minus_alphas_cumprod, noise=noise)
                # x_T = noise

-                samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_image.size(0), timesteps=opt.ddpm_steps, time_replace=opt.ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=64, tile_overlap=opt.tile_overlap, batch_size_sample=opt.n_samples)
+                samples, _ = model.sample_canvas(cond=semantic_c, struct_cond=init_latent, batch_size=init_image.size(0), timesteps=opt.ddpm_steps, time_replace=opt.ddpm_steps, x_T=x_T, return_intermediates=True, tile_size=int(opt.input_size/8), tile_overlap=opt.tile_overlap, batch_size_sample=opt.n_samples)
                _, enc_fea_lq = vq_model.encode(init_template)
                x_samples = vq_model.decode(samples * 1. / model.scale_factor, enc_fea_lq)
                if ori_size is not None:
if ori_size is not None:
@@ -336,6 +336,10 @@
                x_sample = 255. * rearrange(x_samples[i].cpu().numpy(), 'c h w -> h w c')
                Image.fromarray(x_sample.astype(np.uint8)).save(
                    os.path.join(outpath, basename+'.png'))
+                init_image = torch.clamp((init_image + 1.0) / 2.0, min=0.0, max=1.0)
+                init_image = 255. * rearrange(init_image[i].cpu().numpy(), 'c h w -> h w c')
+                Image.fromarray(init_image.astype(np.uint8)).save(
+                    os.path.join(outpath, basename+'_lq.png'))

toc = time.time()
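
The `tile_size=int(opt.input_size/8)` change replaces the hard-coded latent tile of 64: sampling runs in the VAE latent space, which is downsampled 8x from pixels (three stride-2 stages under `ch_mult: [1, 2, 4, 4]`), so the latent tile must match the model's native resolution. A small illustrative helper (not part of the script):

```python
def latent_tile_size(input_size: int, vae_downsample: int = 8) -> int:
    # pixel-space tile -> latent-space tile
    return input_size // vae_downsample

assert latent_tile_size(512) == 64   # the previously hard-coded value
assert latent_tile_size(768) == 96   # SD-2.1-768v
```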
