Commit
Added huggingface checkpoint availability for blind face image restoration
ohayonguy committed Oct 3, 2024
1 parent 9b8b1cb commit 8e1b206
Showing 5 changed files with 50 additions and 23 deletions.
25 changes: 19 additions & 6 deletions README.md
@@ -53,6 +53,7 @@ pip install basicsr==1.4.2
pip install git+https://github.com/toshas/torch-fidelity.git
pip install lpips==0.1.4
pip install piq==0.8.0
pip install huggingface_hub==0.24.5
```

1. Note that the package `natten` is required for the HDiT architecture used by PMRF.
@@ -73,11 +74,11 @@ to
from torchvision.transforms.functional import rgb_to_grayscale
```


# ⬇️ Download checkpoints


Our model checkpoints (from both sections 5.1 and 5.2 in the paper) can be downloaded from our [Google Drive](https://drive.google.com/drive/folders/1dfjZATcQ451uhvFH42tKnfMNHRkL6N_A?usp=sharing). Please keep the same folder structure as provided in Google Drive:
We provide our blind face image restoration model checkpoint on [Hugging Face](https://huggingface.co/ohayonguy/PMRF_blind_face_image_restoration) and on [Google Drive](https://drive.google.com/drive/folders/1dfjZATcQ451uhvFH42tKnfMNHRkL6N_A?usp=sharing).
The checkpoints for section 5.2 in the paper (the controlled experiments) can be downloaded from [Google Drive](https://drive.google.com/drive/folders/1dfjZATcQ451uhvFH42tKnfMNHRkL6N_A?usp=sharing). Please keep the same folder structure as provided in Google Drive:

```
checkpoints/
@@ -101,18 +102,30 @@ To evaluate the landmark distance (LMD in the paper) and the identity metric (De
3. Put these data sets wherever you want in your system.



# 🧑 Blind face image restoration (section 5.1 in the paper)
## ⚡ Quick inference ⚡

To quickly use our model, we provide a [Hugging Face checkpoint](https://huggingface.co/ohayonguy/PMRF_blind_face_image_restoration), which is downloaded automatically. Simply run
```
python inference.py \
--ckpt_path ohayonguy/PMRF_blind_face_image_restoration \
--ckpt_path_is_huggingface \
--lq_data_path /path/to/lq/images \
--output_dir /path/to/results/dir \
--batch_size 64 \
--num_flow_steps 25
```
Please alter `--num_flow_steps` as you wish (this is the hyper-parameter `K` in our paper).
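
If you prefer to load the Hugging Face checkpoint directly from Python (e.g., in a notebook), here is a minimal sketch; it assumes you run it from the repository root so that `lightning_models` is importable, and that a CUDA device is available:
```
from lightning_models.mmse_rectified_flow import MMSERectifiedFlow

# Downloads the checkpoint from the Hub on first use and caches it locally.
model = MMSERectifiedFlow.from_pretrained("ohayonguy/PMRF_blind_face_image_restoration").cuda()
model.eval()
```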

You may also provide a local model checkpoint (e.g., if you train your own PMRF model, or if you wish to use our [Google Drive](https://drive.google.com/drive/folders/1dfjZATcQ451uhvFH42tKnfMNHRkL6N_A?usp=sharing) checkpoint instead of the Hugging Face one). Simply run
```
python inference.py \
--ckpt_path ./checkpoints/blind_face_restoration_pmrf.ckpt \
--lq_data_path /path/to/lq/images \
--output_dir /path/to/results/dir \
--batch_size 64 \
--num_flow_steps 25
```
You can also alter the `inference.sh` file and run it.
You may alter `--num_flow_steps` as you wish (this is the hyper-parameter `K` in our paper).

## 🔬 Evaluation

31 changes: 18 additions & 13 deletions inference.py
@@ -18,23 +18,26 @@ def main(args):
output_path = os.path.join(args.output_dir, 'restored_images')
os.makedirs(output_path, exist_ok=True)

ckpt = torch.load(args.ckpt_path, map_location="cpu")
mmse_model_arch = ckpt['hyper_parameters']['mmse_model_arch']
model = MMSERectifiedFlow.load_from_checkpoint(args.ckpt_path,
                                               # Need to provide mmse_model_arch to
                                               # make sure the model initializes it.
                                               mmse_model_arch=mmse_model_arch,
                                               mmse_model_ckpt_path=None,  # Will ignore the original path of the
                                               # MMSE model used for training,
                                               # and instead load it from the model checkpoint.
                                               map_location='cpu').cuda()
if args.ckpt_path_is_huggingface:
    model = MMSERectifiedFlow.from_pretrained(args.ckpt_path).cuda()
else:
    ckpt = torch.load(args.ckpt_path, map_location="cpu")
    mmse_model_arch = ckpt['hyper_parameters']['mmse_model_arch']
    model = MMSERectifiedFlow.load_from_checkpoint(args.ckpt_path,
                                                   # Need to provide mmse_model_arch to
                                                   # make sure the model initializes it.
                                                   mmse_model_arch=mmse_model_arch,
                                                   mmse_model_ckpt_path=None,  # Will ignore the original path of the
                                                   # MMSE model used for training,
                                                   # and instead load it from the model checkpoint.
                                                   map_location='cpu').cuda()
    if model.ema_wanted:
        model.ema.load_state_dict(ckpt['ema'])
        model.ema.copy_to()
if model.mmse_model is not None:
    output_path_mmse = os.path.join(args.output_dir, 'restored_images_posterior_mean')
    os.makedirs(output_path_mmse, exist_ok=True)

if model.ema_wanted:
    model.ema.load_state_dict(ckpt['ema'])
    model.ema.copy_to()

torch.compile(model, mode='max-autotune')
print("Compiled model")
@@ -60,6 +63,8 @@ def main(args):
parser.add_argument('--ckpt_path', type=str, required=False,
                    default='./checkpoints/blind_face_restoration_pmrf.ckpt',
                    help='Path to the model checkpoint.')
parser.add_argument('--ckpt_path_is_huggingface', action='store_true', required=False, default=False,
                    help='Whether the ckpt path is a huggingface model or a path to a local file.')
parser.add_argument('--lq_data_path', type=str, required=True,
                    help='Path to a folder that contains low quality images.')
parser.add_argument('--output_dir', type=str, required=True,
6 changes: 4 additions & 2 deletions inference.sh
@@ -1,7 +1,9 @@
#!/bin/bash

python inference.py \
--lq_data_path /path/to/lq/images \
--output_dir /path/to/results/dir \
--ckpt_path ohayonguy/PMRF_blind_face_image_restoration \
--ckpt_path_is_huggingface \
--lq_data_path /home/ohayonguy/projects/mmse_rectified_flow/data/celeba_512_validation_lq \
--output_dir ./results_huggingface \
--batch_size 64 \
--num_flow_steps 25
3 changes: 2 additions & 1 deletion install.sh
@@ -11,4 +11,5 @@ pip install nvidia-cuda-nvcc-cu11 --no-input
pip install basicsr==1.4.2 --no-input
pip install git+https://github.com/toshas/torch-fidelity.git --no-input
pip install lpips==0.1.4 --no-input
pip install piq==0.8.0 --no-input
pip install piq==0.8.0 --no-input
pip install huggingface_hub==0.24.5 --no-input
8 changes: 7 additions & 1 deletion lightning_models/mmse_rectified_flow.py
@@ -14,9 +14,15 @@

from utils.create_arch import create_arch
from utils.img_utils import create_grid
from huggingface_hub import PyTorchModelHubMixin


class MMSERectifiedFlow(LightningModule):

class MMSERectifiedFlow(LightningModule,
                        PyTorchModelHubMixin,
                        pipeline_tag="image-to-image",
                        license="mit",
                        ):
    def __init__(self,
                 stage,
                 arch,
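
For context, inheriting from `PyTorchModelHubMixin` is what gives `MMSERectifiedFlow` its `from_pretrained` method (together with `save_pretrained` and `push_to_hub`). The following is a minimal, self-contained sketch of that pattern, using a hypothetical `TinyModel` rather than the actual PMRF class:
```
import torch
from huggingface_hub import PyTorchModelHubMixin


class TinyModel(torch.nn.Module,
                PyTorchModelHubMixin,
                pipeline_tag="image-to-image",
                license="mit",
                ):
    def __init__(self, hidden_dim: int = 16):
        super().__init__()
        self.layer = torch.nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.layer(x)


model = TinyModel(hidden_dim=16)
model.save_pretrained("./tiny-model")       # saves the weights plus a config.json holding the init kwargs
# model.push_to_hub("username/tiny-model")  # would upload the same files to the Hub (requires login)
reloaded = TinyModel.from_pretrained("./tiny-model")
```
The init kwargs (`hidden_dim` here) are serialized to `config.json`, which is how `from_pretrained` can rebuild the model before loading the weights.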
