-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
82 changed files
with
13,203 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
*.ckpt | ||
*.pth |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,203 @@ | ||
<div align="center"> | ||
|
||
# Posterior-Mean Rectified Flow:<br />Towards Minimum MSE Photo-Realistic Image Restoration | ||
|
||
[[Paper]()] | ||
|
||
[Guy Ohayon](https://ohayonguy.github.io/), [Tomer Michaeli](https://tomer.net.technion.ac.il/), [Michael Elad](https://elad.cs.technion.ac.il/)<br /> | ||
Technion—Israel Institute of Technology | ||
|
||
</div> | ||
|
||
> Posterior-Mean Rectified Flow (PMRF) is a novel photo-realistic image restoration algorithm: It approximates the *optimal* estimator that minimizes the MSE under a constraint of perfect perceptual index, namely where the distribution of the reconstructed images is equal to that of the ground-truth ones. | ||
<div align="center"> | ||
<img src="assets/flow.png" width="2000"> | ||
</div> | ||
|
||
--- | ||
|
||
<div align="center"> | ||
|
||
[](https://github.com/ohayonguy/PMRF/blob/main/LICENSE) | ||
[](https://github.com/pytorch/pytorch) | ||
[](https://github.com/Lightning-AI/pytorch-lightning) | ||
|
||
</div> | ||
|
||
# Some results from our paper | ||
### CelebA-Test quantitative comparison | ||
|
||
Red, blue and green indicate the best, the second best and the third best scores, respectively. | ||
<img src="assets/celeba-test-table.png"/> | ||
|
||
|
||
### WIDER-Test visual comparison | ||
<img src="assets/wider.png"/> | ||
|
||
### WebPhoto-Test visual comparison | ||
<img src="assets/webphoto.png"/> | ||
|
||
# ⚙️ Installation | ||
We created a conda environment by running the following commands, exactly in the following order (these are also given in the `install.sh` file): | ||
|
||
``` | ||
conda create -n pmrf python=3.10 | ||
conda activate pmrf | ||
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia | ||
conda install lightning==2.3.3 -c conda-forge | ||
pip install opencv-python==4.10.0.84 timm==1.0.8 wandb==0.17.5 lovely-tensors==0.1.16 torch-fidelity==0.3.0 einops==0.8.0 dctorch==0.1.2 torch-ema==0.3 | ||
pip install natten==0.17.1+torch230cu118 -f https://shi-labs.com/natten/wheels | ||
pip install nvidia-cuda-nvcc-cu11 | ||
pip install basicsr==1.4.2 | ||
pip install git+https://github.com/toshas/torch-fidelity.git | ||
pip install lpips==0.1.4 | ||
pip install piq==0.8.0 | ||
``` | ||
|
||
1. Note that the package `natten` is required for the HDiT architecture used by PMRF. | ||
Make sure to replace `natten==0.17.1+torch230cu118` with the correct CUDA version installed on your system. | ||
Check out https://shi-labs.com/natten/ for the available versions. | ||
2. We installed `nvidia-cuda-nvcc-cu11` because otherwise `torch.compile` got hanging for some reason. | ||
`torch.compile` may work in your system without this package. In any case, if you wish to do so, you may simply skip | ||
this package and/or remove all the `torch.compile` lines from our code. | ||
3. Due to some issue in the `basicsr` package, you will need to modify one of the files in this package. | ||
Open `/path/to/env/pmrf/lib/python3.10/site-packages/basicsr/data/degradations.py`, where `/path/to/env` is the path | ||
where your conda installed the `pmrf` environment. | ||
Then, change the line | ||
``` | ||
from torchvision.transforms.functional_tensor import rgb_to_grayscale | ||
``` | ||
to | ||
``` | ||
from torchvision.transforms.functional import rgb_to_grayscale | ||
``` | ||
|
||
|
||
# ⬇️ Download checkpoints | ||
|
||
|
||
Our model checkpoints (from both sections 5.1 and 5.2 in the paper) can be downloaded from our [Google Drive](https://drive.google.com/drive/folders/1dfjZATcQ451uhvFH42tKnfMNHRkL6N_A?usp=sharing). Please keep the same folder structure as provided in Google Drive: | ||
|
||
``` | ||
checkpoints/ | ||
├── blind_face_restoration_pmrf.ckpt # Checkpoint of our blind face image restoration model. | ||
├── swinir_restoration512_L1.pth # Checkpoint of the SwinIR model trained by DifFace | ||
├── controlled_experiments/ # Checkpoints for the controlled experiments | ||
│ ├── colorization_gaussian_noise_025/ | ||
│ │ ├── pmrf/ | ||
│ │ │ └── epoch=999-step=273000.ckpt | ||
│ │ ├── mmse/ | ||
│ │ │ └── epoch=999-step=273000.ckpt | ||
. . . | ||
. . . | ||
. . . | ||
``` | ||
To evaluate the landmark distance (LMD in the paper) and the identity metric (Deg in the paper), you will also need to download the `resnet18_110.pth` and `alignment_WFLW_4HG.pth` checkpoints from the [Google Drive](https://drive.google.com/drive/folders/1k3RCSliF6PsujCMIdCD1hNM63EozlDIZ) of [VQFR](https://github.com/TencentARC/VQFR). Place these checkpoints in the `evaluation/metrics_ckpt/` folder. | ||
|
||
# 🌐 Download test data sets for blind face image restoration | ||
1. Download WebPhoto-Test, LFW-Test, and CelebA-Test (HQ and LQ) from https://xinntao.github.io/projects/gfpgan. | ||
2. Download WIDER-Test from https://shangchenzhou.com/projects/CodeFormer/. | ||
3. Put these data sets wherever you want in your system. | ||
|
||
|
||
|
||
# 🧑 Blind face image restoration (section 5.1 in the paper) | ||
## ⚡ Quick inference ⚡ | ||
``` | ||
python inference.py \ | ||
--lq_data_path /path/to/lq/images \ | ||
--output_dir /path/to/results/dir \ | ||
--batch_size 64 \ | ||
--num_flow_steps 25 | ||
``` | ||
You can also alter the `inference.sh` file and run it. | ||
You may alter the `--num_flow_steps` as you wish (this is the hyper-parameter `K` in our paper) | ||
|
||
## 🔬 Evaluation | ||
|
||
1. We downloaded the `resnet18_110.pth` and `alignment_WFLW_4HG.pth` checkpoints from the [Google Drive](https://drive.google.com/drive/folders/1k3RCSliF6PsujCMIdCD1hNM63EozlDIZ) of [VQFR](https://github.com/TencentARC/VQFR), and put these in the folder `evaluation/metrics_ckpt/`. | ||
To evaluate the results on CelebA-Test, run: | ||
``` | ||
cd evaluation | ||
python compute_metrics_blind.py \ | ||
--parent_ffhq_512_path /path/to/parent/of/ffhq512 \ | ||
--rec_path /path/to/celeba-512-test/restored/images \ | ||
--gt_path /path/to/celeba-512-test/ground-truth/images | ||
``` | ||
To evaluate the results on the real-world data sets, run: | ||
``` | ||
cd evaluation | ||
python compute_metrics_blind.py \ | ||
--parent_ffhq_512_path /path/to/parent/of/ffhq512 \ | ||
--rec_path /path/to/real-world/restored/images \ | ||
--mmse_rec_path /path/to/mmse/restored/images | ||
``` | ||
The `--mmse_rec_path` argument is optional, and allows you to compute IndRMSE, as an indicator of the true RMSE for real-world degraded images. | ||
Note that the MMSE reconstructions are saved automatically when you run `inference.py`, since the MMSE model | ||
is also in the PMRF checkpoint. | ||
|
||
## 💻 Training | ||
In the folder `scripts/` we provide the training scripts we used for blind face image restoration and for training | ||
the baseline models as well. If you want to run a script, you need to execute it in the root folder | ||
(where `train.py` is located). To train the model, you will need the FFHQ data set. | ||
We downloaded the original FFHQ 1024x1024 data set and down-sampled the images to size 512x512 using bi-cubic down-sampling. | ||
|
||
1. Copy the `train_pmrf.sh` file (located in `scripts/train/blind_face_restoration`) to the root folder. | ||
2. Adjust the arguments `--train_data_root` and `--val_data_root` according to the location of the training and validation data in your system. | ||
3. The SwinIR model which was trained by [DifFace](https://github.com/zsyOAOA/DifFace) is provided in the `checkpoints/` folder. We downloaded it via | ||
``` | ||
wget https://github.com/zsyOAOA/DifFace/releases/download/V1.0/swinir_restoration512_L1.pth | ||
``` | ||
4. Adjust the argument `--mmse_model_ckpt_path` to the path of the SwinIR model. | ||
5. Adjust the arguments `--num_gpus` and `--num_workers` according to your system. | ||
6. Run the script `train_pmrf.sh` to train our model. | ||
|
||
|
||
# 👩🔬 Controlled experiments (section 5.2 in the paper) | ||
We provide training and evaluation codes for the controlled experiments in our paper, where we compare PMRF with the following baseline methods: | ||
1. **Flow conditioned on Y**: A rectified flow model which is *conditioned* on the *input measurement*, and learns to flow from pure noise to the ground-truth data distribution. | ||
2. **Flow conditioned on the posterior mean predictor**: A rectified flow model which is *conditioned* on the *posterior mean prediction*, and learns to flow from pure noise to the ground-truth data distribution. | ||
3. **Flow from Y**: A rectified flow model which flows from the degraded measurement to the ground-truth data distribution. | ||
4. **Posterior mean predictor**: A model which is trained to minimize the MSE loss. | ||
|
||
## 🔬 Evaluation | ||
We provide checkpoints for quick evaluation of PMRF and all the baseline methods. | ||
1. The evaluation is conducted on CelebA-Test images of size 256x256. To acquire such images, we downloaded the CelebA-Test (HQ) images from [GFPGAN](https://xinntao.github.io/projects/gfpgan), and down-sampled them to 256x256 using bi-cubic down-sampling. | ||
2. Adjust `--test_data_root` in `test.sh` to the path of the CelebA-Test 256x256 images, and adjust `--degradation` and `--ckpt_path` to the type of degradation you wish to assess and the corresponding model checkpoint. | ||
3. Run `test.sh`. | ||
|
||
We automatically save the reconstructed outputs, the degraded measurements, as well as the samples from the source distribution (the images from which the ODE solver begins). | ||
After running `test.sh`, you may evaluate the results via : | ||
|
||
``` | ||
cd evaluation | ||
python compute_metrics_controlled_experiments.py \ | ||
--parent_ffhq_256_path /path/to/parent/of/ffhq256 \ | ||
--rec_path /path/to/restored/images \ | ||
--gt_path /path/to/celeba-256-test/ground-truth/images | ||
``` | ||
|
||
## 💻 Training | ||
|
||
* We trained our models on FFHQ 256x256. To acquire such images, with down-sampled the original FFHQ 1024x1024 images using bi-cubic down-sampling. | ||
* The training scripts of PMRF and each of these baseline models are provided in the `scripts/train/controlled_experiments/` folder. | ||
* To run each of these scripts, you need to copy it to the root folder where `train.py` is located. All you need to do is adjust the `--degradation`, `--source_noise_std`, `--train_data_root` and `--val_data_root` arguments in each script. For denoising, we used `--source_noise_std 0.025`, and for the rest of the tasks we used `--source_noise_std 0.1`. | ||
* To run the `train_pmrf.sh` and `train_posterior_conditioned_on_mmse_model.sh` scripts, you first need to train the MMSE model via `train_mmse.sh`. Then, adjust the `--mmse_model_ckpt_path` argument according to the path of the MMSE model final checkpoint. | ||
|
||
|
||
## Citation | ||
@article{ | ||
ohayon2024pmrf, | ||
title={Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration}, | ||
author={Guy Ohayon and Tomer Michaeli and Michael Elad}, | ||
year={2024}, | ||
journal={arXiv preprint} | ||
} | ||
|
||
## License and acknowledgements | ||
This project is released under the MIT license. | ||
We borrow codes from [BasicSR](https://github.com/XPixelGroup/BasicSR), [VQFR](https://github.com/TencentARC/VQFR), [DifFace](https://github.com/zsyOAOA/DifFace), [k-diffusion](https://github.com/crowsonkb/k-diffusion), and [SwinIR](https://github.com/JingyunLiang/SwinIR). We thank the authors of such repositories for their useful implementations. | ||
|
||
## Contact | ||
If you have any questions or inquiries, please feel free to [contact me](mailto:[email protected]). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2 | ||
from arch.swinir.swinir import SwinIR |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
"""k-diffusion transformer diffusion models, version 2. | ||
Codes adopted from https://github.com/crowsonkb/k-diffusion | ||
""" | ||
|
||
import math | ||
|
||
import torch | ||
import torch._dynamo | ||
from torch import nn | ||
|
||
from . import flags | ||
|
||
if flags.get_use_compile(): | ||
torch._dynamo.config.suppress_errors = True | ||
|
||
|
||
def rotate_half(x): | ||
x1, x2 = x[..., 0::2], x[..., 1::2] | ||
x = torch.stack((-x2, x1), dim=-1) | ||
*shape, d, r = x.shape | ||
return x.view(*shape, d * r) | ||
|
||
|
||
@flags.compile_wrap | ||
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0): | ||
freqs = freqs.to(t) | ||
rot_dim = freqs.shape[-1] | ||
end_index = start_index + rot_dim | ||
assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" | ||
t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:] | ||
t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale) | ||
return torch.cat((t_left, t, t_right), dim=-1) | ||
|
||
|
||
def centers(start, stop, num, dtype=None, device=None): | ||
edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device) | ||
return (edges[:-1] + edges[1:]) / 2 | ||
|
||
|
||
def make_grid(h_pos, w_pos): | ||
grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1) | ||
h, w, d = grid.shape | ||
return grid.view(h * w, d) | ||
|
||
|
||
def bounding_box(h, w, pixel_aspect_ratio=1.0): | ||
# Adjusted dimensions | ||
w_adj = w | ||
h_adj = h * pixel_aspect_ratio | ||
|
||
# Adjusted aspect ratio | ||
ar_adj = w_adj / h_adj | ||
|
||
# Determine bounding box based on the adjusted aspect ratio | ||
y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0 | ||
if ar_adj > 1: | ||
y_min, y_max = -1 / ar_adj, 1 / ar_adj | ||
elif ar_adj < 1: | ||
x_min, x_max = -ar_adj, ar_adj | ||
|
||
return y_min, y_max, x_min, x_max | ||
|
||
|
||
def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None): | ||
y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio) | ||
if align_corners: | ||
h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device) | ||
w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device) | ||
else: | ||
h_pos = centers(y_min, y_max, h, dtype=dtype, device=device) | ||
w_pos = centers(x_min, x_max, w, dtype=dtype, device=device) | ||
return make_grid(h_pos, w_pos) | ||
|
||
|
||
def freqs_pixel(max_freq=10.0): | ||
def init(shape): | ||
freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi | ||
return freqs.log().expand(shape) | ||
return init | ||
|
||
|
||
def freqs_pixel_log(max_freq=10.0): | ||
def init(shape): | ||
log_min = math.log(math.pi) | ||
log_max = math.log(max_freq * math.pi / 2) | ||
return torch.linspace(log_min, log_max, shape[-1]).expand(shape) | ||
return init | ||
|
||
|
||
class AxialRoPE(nn.Module): | ||
def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)): | ||
super().__init__() | ||
self.n_heads = n_heads | ||
self.start_index = start_index | ||
log_freqs = freqs_init((n_heads, dim // 4)) | ||
self.freqs_h = nn.Parameter(log_freqs.clone()) | ||
self.freqs_w = nn.Parameter(log_freqs.clone()) | ||
|
||
def extra_repr(self): | ||
dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2 | ||
return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}" | ||
|
||
def get_freqs(self, pos): | ||
if pos.shape[-1] != 2: | ||
raise ValueError("input shape must be (..., 2)") | ||
freqs_h = pos[..., None, None, 0] * self.freqs_h.exp() | ||
freqs_w = pos[..., None, None, 1] * self.freqs_w.exp() | ||
freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1) | ||
return freqs.transpose(-2, -3) | ||
|
||
def forward(self, x, pos): | ||
freqs = self.get_freqs(pos) | ||
return apply_rotary_emb(freqs, x, self.start_index) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
"""k-diffusion transformer diffusion models, version 2. | ||
Codes adopted from https://github.com/crowsonkb/k-diffusion | ||
""" | ||
|
||
from contextlib import contextmanager | ||
from functools import update_wrapper | ||
import os | ||
import threading | ||
|
||
import torch | ||
|
||
|
||
def get_use_compile(): | ||
return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1" | ||
|
||
|
||
def get_use_flash_attention_2(): | ||
return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1" | ||
|
||
|
||
state = threading.local() | ||
state.checkpointing = False | ||
|
||
|
||
@contextmanager | ||
def checkpointing(enable=True): | ||
try: | ||
old_checkpointing, state.checkpointing = state.checkpointing, enable | ||
yield | ||
finally: | ||
state.checkpointing = old_checkpointing | ||
|
||
|
||
def get_checkpointing(): | ||
return getattr(state, "checkpointing", False) | ||
|
||
|
||
class compile_wrap: | ||
def __init__(self, function, *args, **kwargs): | ||
self.function = function | ||
self.args = args | ||
self.kwargs = kwargs | ||
self._compiled_function = None | ||
update_wrapper(self, function) | ||
|
||
@property | ||
def compiled_function(self): | ||
if self._compiled_function is not None: | ||
return self._compiled_function | ||
if get_use_compile(): | ||
try: | ||
self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs) | ||
except RuntimeError: | ||
self._compiled_function = self.function | ||
else: | ||
self._compiled_function = self.function | ||
return self._compiled_function | ||
|
||
def __call__(self, *args, **kwargs): | ||
return self.compiled_function(*args, **kwargs) |
Oops, something went wrong.