
Commit

Commit with all codes
ohayonguy committed Oct 1, 2024
1 parent 718e0e5 commit dc2ec0d
Showing 82 changed files with 13,203 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
*.ckpt
*.pth
203 changes: 203 additions & 0 deletions README.md
@@ -0,0 +1,203 @@
<div align="center">

# Posterior-Mean Rectified Flow:<br />Towards Minimum MSE Photo-Realistic Image Restoration

[[Paper]()]

[Guy Ohayon](https://ohayonguy.github.io/), [Tomer Michaeli](https://tomer.net.technion.ac.il/), [Michael Elad](https://elad.cs.technion.ac.il/)<br />
Technion—Israel Institute of Technology

</div>

> Posterior-Mean Rectified Flow (PMRF) is a novel photo-realistic image restoration algorithm: It approximates the *optimal* estimator that minimizes the MSE under a constraint of perfect perceptual index, namely where the distribution of the reconstructed images is equal to that of the ground-truth ones.
<div align="center">
<img src="assets/flow.png" width="2000">
</div>
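
At a glance, PMRF works in two stages: an MMSE model first predicts the posterior mean of the clean image, and a rectified flow then transports a slightly noised version of that prediction toward the distribution of ground-truth images. The sketch below is conceptual only (it is not the repository's implementation); `posterior_mean_model` and `flow_model` are hypothetical placeholders.

```python
import torch


def pmrf_restore(y, posterior_mean_model, flow_model, num_flow_steps=25, noise_std=0.1):
    """Conceptual PMRF inference sketch (illustrative only, not the repository's code)."""
    x = posterior_mean_model(y)                   # approximate posterior mean (MMSE prediction)
    x = x + noise_std * torch.randn_like(x)       # small perturbation: the flow's source distribution
    dt = 1.0 / num_flow_steps                     # K Euler steps of the rectified-flow ODE
    for k in range(num_flow_steps):
        t = torch.full((x.shape[0],), k * dt, device=x.device)
        x = x + flow_model(x, t) * dt             # flow_model predicts the velocity field
    return x.clamp(0.0, 1.0)
```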

---

<div align="center">

[![license](https://img.shields.io/badge/License-MIT-red.svg)](https://github.com/ohayonguy/PMRF/blob/main/LICENSE)
[![torch](https://img.shields.io/badge/PyTorch-2.3.1-DE3412)](https://github.com/pytorch/pytorch)
[![lightning](https://img.shields.io/badge/Lightning-2.3.3-8A2BE2)](https://github.com/Lightning-AI/pytorch-lightning)

</div>

# Some results from our paper
### CelebA-Test quantitative comparison

Red, blue and green indicate the best, the second best and the third best scores, respectively.
<img src="assets/celeba-test-table.png"/>


### WIDER-Test visual comparison
<img src="assets/wider.png"/>

### WebPhoto-Test visual comparison
<img src="assets/webphoto.png"/>

# ⚙️ Installation
We created a conda environment by running the following commands, in this exact order (they are also provided in `install.sh`):

```
conda create -n pmrf python=3.10
conda activate pmrf
conda install pytorch==2.3.1 torchvision==0.18.1 torchaudio==2.3.1 pytorch-cuda=11.8 -c pytorch -c nvidia
conda install lightning==2.3.3 -c conda-forge
pip install opencv-python==4.10.0.84 timm==1.0.8 wandb==0.17.5 lovely-tensors==0.1.16 torch-fidelity==0.3.0 einops==0.8.0 dctorch==0.1.2 torch-ema==0.3
pip install natten==0.17.1+torch230cu118 -f https://shi-labs.com/natten/wheels
pip install nvidia-cuda-nvcc-cu11
pip install basicsr==1.4.2
pip install git+https://github.com/toshas/torch-fidelity.git
pip install lpips==0.1.4
pip install piq==0.8.0
```

1. Note that the package `natten` is required for the HDiT architecture used by PMRF.
Make sure to replace `natten==0.17.1+torch230cu118` with the wheel that matches the CUDA version installed on your system.
Check out https://shi-labs.com/natten/ for the available versions; the snippet after this list shows one way to check which CUDA version your PyTorch build uses.
2. We installed `nvidia-cuda-nvcc-cu11` because `torch.compile` hung without it.
`torch.compile` may work on your system without this package; if so, you can simply skip it and/or remove the `torch.compile` lines from our code.
3. Due to an incompatibility between `basicsr==1.4.2` and recent versions of `torchvision`, you will need to modify one of the files in this package.
Open `/path/to/env/pmrf/lib/python3.10/site-packages/basicsr/data/degradations.py`, where `/path/to/env` is the path
where conda installed the `pmrf` environment, and change the line
```
from torchvision.transforms.functional_tensor import rgb_to_grayscale
```
to
```
from torchvision.transforms.functional import rgb_to_grayscale
```
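
The following sketch is an optional convenience of ours, not part of the repository: it prints the CUDA version your PyTorch build was compiled against (to help pick the matching `natten` wheel) and applies the `basicsr` import fix from item 3 automatically.

```python
# Optional helper (an assumption on our side, not provided by the repository).
import importlib.util
from pathlib import Path

import torch

# 1) CUDA version of the installed PyTorch build, e.g. "11.8" -> natten wheel suffix "cu118".
print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)

# 2) Patch basicsr's outdated torchvision import in place (same fix as item 3 above).
spec = importlib.util.find_spec("basicsr")
degradations = Path(spec.origin).parent / "data" / "degradations.py"
text = degradations.read_text()
degradations.write_text(text.replace(
    "from torchvision.transforms.functional_tensor import rgb_to_grayscale",
    "from torchvision.transforms.functional import rgb_to_grayscale",
))
print("patched:", degradations)
```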


# ⬇️ Download checkpoints


Our model checkpoints (from both sections 5.1 and 5.2 in the paper) can be downloaded from our [Google Drive](https://drive.google.com/drive/folders/1dfjZATcQ451uhvFH42tKnfMNHRkL6N_A?usp=sharing). Please keep the same folder structure as provided in Google Drive:

```
checkpoints/
├── blind_face_restoration_pmrf.ckpt   # Checkpoint of our blind face image restoration model.
├── swinir_restoration512_L1.pth       # Checkpoint of the SwinIR model trained by DifFace.
├── controlled_experiments/            # Checkpoints for the controlled experiments.
│   ├── colorization_gaussian_noise_025/
│   │   ├── pmrf/
│   │   │   └── epoch=999-step=273000.ckpt
│   │   ├── mmse/
│   │   │   └── epoch=999-step=273000.ckpt
. . .
. . .
. . .
```
To evaluate the landmark distance (LMD in the paper) and the identity metric (Deg in the paper), you will also need to download the `resnet18_110.pth` and `alignment_WFLW_4HG.pth` checkpoints from the [Google Drive](https://drive.google.com/drive/folders/1k3RCSliF6PsujCMIdCD1hNM63EozlDIZ) of [VQFR](https://github.com/TencentARC/VQFR). Place these checkpoints in the `evaluation/metrics_ckpt/` folder.

# 🌐 Download test data sets for blind face image restoration
1. Download WebPhoto-Test, LFW-Test, and CelebA-Test (HQ and LQ) from https://xinntao.github.io/projects/gfpgan.
2. Download WIDER-Test from https://shangchenzhou.com/projects/CodeFormer/.
3. Place these data sets anywhere on your system.



# 🧑 Blind face image restoration (section 5.1 in the paper)
## ⚡ Quick inference ⚡
```
python inference.py \
--lq_data_path /path/to/lq/images \
--output_dir /path/to/results/dir \
--batch_size 64 \
--num_flow_steps 25
```
You can also edit the `inference.sh` script and run it.
Feel free to adjust `--num_flow_steps` (this is the hyper-parameter `K` in our paper).

## 🔬 Evaluation

We downloaded the `resnet18_110.pth` and `alignment_WFLW_4HG.pth` checkpoints from the [Google Drive](https://drive.google.com/drive/folders/1k3RCSliF6PsujCMIdCD1hNM63EozlDIZ) of [VQFR](https://github.com/TencentARC/VQFR) and put them in the folder `evaluation/metrics_ckpt/` (see the download section above).

To evaluate the results on CelebA-Test, run:
```
cd evaluation
python compute_metrics_blind.py \
--parent_ffhq_512_path /path/to/parent/of/ffhq512 \
--rec_path /path/to/celeba-512-test/restored/images \
--gt_path /path/to/celeba-512-test/ground-truth/images
```
To evaluate the results on the real-world data sets, run:
```
cd evaluation
python compute_metrics_blind.py \
--parent_ffhq_512_path /path/to/parent/of/ffhq512 \
--rec_path /path/to/real-world/restored/images \
--mmse_rec_path /path/to/mmse/restored/images
```
The `--mmse_rec_path` argument is optional and allows you to compute IndRMSE, an indicator of the true RMSE for real-world degraded images.
Note that the MMSE reconstructions are saved automatically when you run `inference.py`, since the MMSE model is part of the PMRF checkpoint.

## 💻 Training
The folder `scripts/` contains the training scripts we used for blind face image restoration, as well as for training the baseline models. Each script must be executed from the root folder (where `train.py` is located). To train the model, you will need the FFHQ data set.
We downloaded the original FFHQ 1024x1024 data set and down-sampled the images to 512x512 using bi-cubic down-sampling (a minimal sketch of this step is shown below).
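
For reference, here is a minimal sketch of the down-sampling step (our own illustration, not a script from the repository; the directory names are placeholders):

```python
from pathlib import Path
from PIL import Image

src_dir, dst_dir = Path("ffhq1024"), Path("ffhq512")   # hypothetical paths
dst_dir.mkdir(parents=True, exist_ok=True)
for img_path in sorted(src_dir.glob("*.png")):
    img = Image.open(img_path).convert("RGB")
    img.resize((512, 512), Image.BICUBIC).save(dst_dir / img_path.name)   # bi-cubic down-sampling
```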

1. Copy the `train_pmrf.sh` file (located in `scripts/train/blind_face_restoration`) to the root folder.
2. Adjust the arguments `--train_data_root` and `--val_data_root` according to the location of the training and validation data in your system.
3. The SwinIR model which was trained by [DifFace](https://github.com/zsyOAOA/DifFace) is provided in the `checkpoints/` folder. We downloaded it via
```
wget https://github.com/zsyOAOA/DifFace/releases/download/V1.0/swinir_restoration512_L1.pth
```
4. Adjust the argument `--mmse_model_ckpt_path` to the path of the SwinIR model.
5. Adjust the arguments `--num_gpus` and `--num_workers` according to your system.
6. Run the script `train_pmrf.sh` to train our model.


# 👩‍🔬 Controlled experiments (section 5.2 in the paper)
We provide training and evaluation codes for the controlled experiments in our paper, where we compare PMRF with the following baseline methods:
1. **Flow conditioned on Y**: A rectified flow model which is *conditioned* on the *input measurement*, and learns to flow from pure noise to the ground-truth data distribution.
2. **Flow conditioned on the posterior mean predictor**: A rectified flow model which is *conditioned* on the *posterior mean prediction*, and learns to flow from pure noise to the ground-truth data distribution.
3. **Flow from Y**: A rectified flow model which flows from the degraded measurement to the ground-truth data distribution.
4. **Posterior mean predictor**: A model which is trained to minimize the MSE loss.

## 🔬 Evaluation
We provide checkpoints for quick evaluation of PMRF and all the baseline methods.
1. The evaluation is conducted on CelebA-Test images of size 256x256. To acquire such images, we downloaded the CelebA-Test (HQ) images from [GFPGAN](https://xinntao.github.io/projects/gfpgan), and down-sampled them to 256x256 using bi-cubic down-sampling.
2. Adjust `--test_data_root` in `test.sh` to the path of the CelebA-Test 256x256 images, and adjust `--degradation` and `--ckpt_path` to the type of degradation you wish to assess and the corresponding model checkpoint.
3. Run `test.sh`.

We automatically save the reconstructed outputs, the degraded measurements, and the samples from the source distribution (the images from which the ODE solver begins).
After running `test.sh`, you may evaluate the results via:

```
cd evaluation
python compute_metrics_controlled_experiments.py \
--parent_ffhq_256_path /path/to/parent/of/ffhq256 \
--rec_path /path/to/restored/images \
--gt_path /path/to/celeba-256-test/ground-truth/images
```

## 💻 Training

* We trained our models on FFHQ 256x256. To acquire such images, we down-sampled the original FFHQ 1024x1024 images to 256x256 using bi-cubic down-sampling.
* The training scripts of PMRF and each of these baseline models are provided in the `scripts/train/controlled_experiments/` folder.
* To run a script, copy it to the root folder where `train.py` is located, and adjust the `--degradation`, `--source_noise_std`, `--train_data_root` and `--val_data_root` arguments. For denoising we used `--source_noise_std 0.025`, and for the rest of the tasks we used `--source_noise_std 0.1`.
* To run the `train_pmrf.sh` and `train_posterior_conditioned_on_mmse_model.sh` scripts, you first need to train the MMSE model via `train_mmse.sh`. Then, adjust the `--mmse_model_ckpt_path` argument to point to the final checkpoint of the MMSE model.


## Citation
```
@article{ohayon2024pmrf,
    title={Posterior-Mean Rectified Flow: Towards Minimum MSE Photo-Realistic Image Restoration},
    author={Guy Ohayon and Tomer Michaeli and Michael Elad},
    year={2024},
    journal={arXiv preprint}
}
```

## License and acknowledgements
This project is released under the MIT license.
We borrow code from [BasicSR](https://github.com/XPixelGroup/BasicSR), [VQFR](https://github.com/TencentARC/VQFR), [DifFace](https://github.com/zsyOAOA/DifFace), [k-diffusion](https://github.com/crowsonkb/k-diffusion), and [SwinIR](https://github.com/JingyunLiang/SwinIR). We thank the authors of these repositories for their useful implementations.

## Contact
If you have any questions or inquiries, please feel free to [contact me](mailto:[email protected]).
2 changes: 2 additions & 0 deletions arch/__init__.py
@@ -0,0 +1,2 @@
from arch.hourglass.image_transformer_v2 import ImageTransformerDenoiserModelV2
from arch.swinir.swinir import SwinIR
Empty file added arch/hourglass/__init__.py
Empty file.
113 changes: 113 additions & 0 deletions arch/hourglass/axial_rope.py
@@ -0,0 +1,113 @@
"""k-diffusion transformer diffusion models, version 2.
Codes adopted from https://github.com/crowsonkb/k-diffusion
"""

import math

import torch
import torch._dynamo
from torch import nn

from . import flags

if flags.get_use_compile():
    torch._dynamo.config.suppress_errors = True


def rotate_half(x):
    x1, x2 = x[..., 0::2], x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    *shape, d, r = x.shape
    return x.view(*shape, d * r)


@flags.compile_wrap
def apply_rotary_emb(freqs, t, start_index=0, scale=1.0):
    freqs = freqs.to(t)
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim
    assert rot_dim <= t.shape[-1], f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
    t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
    t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
    return torch.cat((t_left, t, t_right), dim=-1)


def centers(start, stop, num, dtype=None, device=None):
    edges = torch.linspace(start, stop, num + 1, dtype=dtype, device=device)
    return (edges[:-1] + edges[1:]) / 2


def make_grid(h_pos, w_pos):
    grid = torch.stack(torch.meshgrid(h_pos, w_pos, indexing='ij'), dim=-1)
    h, w, d = grid.shape
    return grid.view(h * w, d)


def bounding_box(h, w, pixel_aspect_ratio=1.0):
    # Adjusted dimensions
    w_adj = w
    h_adj = h * pixel_aspect_ratio

    # Adjusted aspect ratio
    ar_adj = w_adj / h_adj

    # Determine bounding box based on the adjusted aspect ratio
    y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
    if ar_adj > 1:
        y_min, y_max = -1 / ar_adj, 1 / ar_adj
    elif ar_adj < 1:
        x_min, x_max = -ar_adj, ar_adj

    return y_min, y_max, x_min, x_max


def make_axial_pos(h, w, pixel_aspect_ratio=1.0, align_corners=False, dtype=None, device=None):
    y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
    if align_corners:
        h_pos = torch.linspace(y_min, y_max, h, dtype=dtype, device=device)
        w_pos = torch.linspace(x_min, x_max, w, dtype=dtype, device=device)
    else:
        h_pos = centers(y_min, y_max, h, dtype=dtype, device=device)
        w_pos = centers(x_min, x_max, w, dtype=dtype, device=device)
    return make_grid(h_pos, w_pos)


def freqs_pixel(max_freq=10.0):
    def init(shape):
        freqs = torch.linspace(1.0, max_freq / 2, shape[-1]) * math.pi
        return freqs.log().expand(shape)
    return init


def freqs_pixel_log(max_freq=10.0):
    def init(shape):
        log_min = math.log(math.pi)
        log_max = math.log(max_freq * math.pi / 2)
        return torch.linspace(log_min, log_max, shape[-1]).expand(shape)
    return init


class AxialRoPE(nn.Module):
    def __init__(self, dim, n_heads, start_index=0, freqs_init=freqs_pixel_log(max_freq=10.0)):
        super().__init__()
        self.n_heads = n_heads
        self.start_index = start_index
        log_freqs = freqs_init((n_heads, dim // 4))
        self.freqs_h = nn.Parameter(log_freqs.clone())
        self.freqs_w = nn.Parameter(log_freqs.clone())

    def extra_repr(self):
        dim = (self.freqs_h.shape[-1] + self.freqs_w.shape[-1]) * 2
        return f"dim={dim}, n_heads={self.n_heads}, start_index={self.start_index}"

    def get_freqs(self, pos):
        if pos.shape[-1] != 2:
            raise ValueError("input shape must be (..., 2)")
        freqs_h = pos[..., None, None, 0] * self.freqs_h.exp()
        freqs_w = pos[..., None, None, 1] * self.freqs_w.exp()
        freqs = torch.cat((freqs_h, freqs_w), dim=-1).repeat_interleave(2, dim=-1)
        return freqs.transpose(-2, -3)

    def forward(self, x, pos):
        freqs = self.get_freqs(pos)
        return apply_rotary_emb(freqs, x, self.start_index)
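
For illustration only (this snippet is not part of the commit), `AxialRoPE` can be applied to per-head query or key features laid out over a 2D token grid; the shapes below are arbitrary.

```python
# Illustrative usage sketch -- not part of this commit.
import os
os.environ["K_DIFFUSION_USE_COMPILE"] = "0"  # skip torch.compile for this quick demo

import torch
from arch.hourglass.axial_rope import AxialRoPE, make_axial_pos

n_heads, d_head = 4, 64
rope = AxialRoPE(dim=d_head, n_heads=n_heads)
pos = make_axial_pos(8, 8)                    # (64, 2) axial (y, x) positions in [-1, 1]
q = torch.randn(1, n_heads, 8 * 8, d_head)    # (batch, heads, tokens, head_dim)
q_rot = rope(q, pos)                          # rotary embedding applied over the token grid
print(q_rot.shape)                            # torch.Size([1, 4, 64, 64])
```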
60 changes: 60 additions & 0 deletions arch/hourglass/flags.py
@@ -0,0 +1,60 @@
"""k-diffusion transformer diffusion models, version 2.
Codes adopted from https://github.com/crowsonkb/k-diffusion
"""

from contextlib import contextmanager
from functools import update_wrapper
import os
import threading

import torch


def get_use_compile():
    return os.environ.get("K_DIFFUSION_USE_COMPILE", "1") == "1"


def get_use_flash_attention_2():
    return os.environ.get("K_DIFFUSION_USE_FLASH_2", "1") == "1"


state = threading.local()
state.checkpointing = False


@contextmanager
def checkpointing(enable=True):
    try:
        old_checkpointing, state.checkpointing = state.checkpointing, enable
        yield
    finally:
        state.checkpointing = old_checkpointing


def get_checkpointing():
    return getattr(state, "checkpointing", False)


class compile_wrap:
    def __init__(self, function, *args, **kwargs):
        self.function = function
        self.args = args
        self.kwargs = kwargs
        self._compiled_function = None
        update_wrapper(self, function)

    @property
    def compiled_function(self):
        if self._compiled_function is not None:
            return self._compiled_function
        if get_use_compile():
            try:
                self._compiled_function = torch.compile(self.function, *self.args, **self.kwargs)
            except RuntimeError:
                self._compiled_function = self.function
        else:
            self._compiled_function = self.function
        return self._compiled_function

    def __call__(self, *args, **kwargs):
        return self.compiled_function(*args, **kwargs)
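
A minimal usage sketch of these utilities (our own illustration, not part of the commit): `compile_wrap` lazily applies `torch.compile` only when enabled, and `checkpointing` toggles a thread-local flag.

```python
# Illustrative usage sketch -- not part of this commit.
import os
os.environ["K_DIFFUSION_USE_COMPILE"] = "0"  # opt out of torch.compile for this demo

import torch
from arch.hourglass import flags


@flags.compile_wrap
def squared_gelu(x):
    # Plain function; compile_wrap compiles it lazily only if compilation is enabled.
    return torch.nn.functional.gelu(x) ** 2


print(squared_gelu(torch.randn(4)).shape)  # torch.Size([4])

print(flags.get_checkpointing())           # False outside the context manager
with flags.checkpointing(True):
    print(flags.get_checkpointing())       # True inside
```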