Commit

Update preview of script (v1)
haku-huang committed Jul 15, 2022
0 parents commit 301cddf
Showing 36 changed files with 1,517 additions and 0 deletions.
28 changes: 28 additions & 0 deletions .github/workflows/sync.yml
@@ -0,0 +1,28 @@
name: Mirror to DUT DIMT

on: [push, delete, create]

jobs:
  git-mirror:
    runs-on: ubuntu-latest
    steps:
      -
        name: Configure Private Key
        env:
          SSH_PRIVATE_KEY: ${{ secrets.PRIVATE_KEY }}
        run: |
          mkdir -p ~/.ssh
          echo "$SSH_PRIVATE_KEY" > ~/.ssh/id_rsa
          chmod 600 ~/.ssh/id_rsa
          echo "StrictHostKeyChecking no" >> ~/.ssh/config
      -
        name: Push Mirror
        env:
          SOURCE_REPO: 'https://github.com/MisakiCoca/ReCoNet.git'
          DESTINATION_REPO: 'git@github.com:dlut-dimt/ReCoNet.git'
        run: |
          git clone --mirror "$SOURCE_REPO" && cd `basename "$SOURCE_REPO"`
          git remote set-url --push origin "$DESTINATION_REPO"
          git fetch -p origin
          git for-each-ref --format 'delete %(refname)' refs/pull | git update-ref --stdin
          git push --mirror
8 changes: 8 additions & 0 deletions .gitignore
@@ -0,0 +1,8 @@
# JetBrains
.idea/*

# macOS
.DS_*

# data
data/*
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Zhanbo Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
121 changes: 121 additions & 0 deletions README.md
@@ -0,0 +1,121 @@
# ReCoNet

![visitors](https://visitor-badge.glitch.me/badge?page_id=MisakiCoca.ReCoNet)

Zhanbo Huang, Jinyuan Liu, Xin Fan*, Risheng Liu, Wei Zhong, Zhongxuan Luo.
**"Recurrent Correction Network for Fast and Efficient Multi-modality Image Fusion"**, European Conference on Computer
Vision **(ECCV)**, 2022.

## Milestone

In the near future, we will publish the following materials.

* v0 [ECCV]: Fuse network (ReCo) with pre-trained parameters for generating the results in the paper. **Finished**
* v1: A new script & architecture of ReCo+ for fast training & prediction. **Building**
* v1: Highly robust pre-trained parameters for ReCo+ based on realistic-scene training. (We are collecting realistic scene data.)

## Update

[2022-07-13] Preview of micro-register is available!

[2022-07-12] The ReCo(v0) is available!

## Requirements

* Python 3.10
* PyTorch 1.12
* TorchVision 0.13.0
* PyTorch Lightning 0.8.5
* Kornia 0.6.5

## Extended Experiments

### Generate fake visible images

To generate fake visible images as described in our paper, you can refer to our companion
repository [complex-deformation](https://github.com/MisakiCoca/complex-deformation), which is a component of
this work.

It shows how we can deform the image and generate a restored field that **approximates** the ground truth.
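
As a rough illustration only (this is **not** the complex-deformation implementation; `random_deform` and its parameters are made up for this sketch), a misaligned fake image can be produced by warping with a smooth random displacement field, and that field plays the role of the approximate ground truth:

```python
# Illustrative sketch: warp an image with a smooth random displacement field
# using plain PyTorch. `random_deform` is a hypothetical helper, not part of ReCo.
import torch
import torch.nn.functional as F


def random_deform(image: torch.Tensor, strength: float = 0.02):
    """image: (1, C, H, W) in [0, 1]; returns (warped image, displacement field)."""
    _, _, h, w = image.shape
    # coarse random offsets, upsampled so the deformation stays smooth
    coarse = torch.randn(1, 2, 8, 8) * strength
    flow = F.interpolate(coarse, size=(h, w), mode='bilinear', align_corners=True)
    # identity sampling grid in [-1, 1] (x, y order, as grid_sample expects)
    ys, xs = torch.meshgrid(torch.linspace(-1, 1, h), torch.linspace(-1, 1, w), indexing='ij')
    grid = torch.stack((xs, ys), dim=-1).unsqueeze(0)
    warped = F.grid_sample(image, grid + flow.permute(0, 2, 3, 1), align_corners=True)
    return warped, flow
```

During training, the register then has to predict a field that undoes such a warp.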

### Have a quick preview of our micro-register

For a quick preview of our micro-register module, you can try training & prediction on
the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.

Activate your conda environment and enter folder `exp/test_register`.

1. To train the register yourself, just run the following commands.

```shell
export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
python test_register.py --backbone $BACKBONE --dst $DST
```

`$RECO_ROOT` is the root path of the ReCo repository (e.g. `~/lab/reco`), and `$BACKBONE` selects the architecture:
`m` for `micro` or `u` for `unet`.

The script automatically downloads the MNIST dataset, trains the register, and saves predictions in `$DST`.

2. If you just want to test the performance, we offer pre-trained parameters for both the `micro`- and `unet`-based registers.

```shell
export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
python test_register.py --backbone $BACKBONE --dst $DST --only_pred
```

The prediction results will be saved in `$DST`; in each output the patches from left to right are `moving`, `fixed`,
and `moved`, respectively (see the toy sketch below for what these terms mean).
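
If the `moving`/`fixed`/`moved` naming is unfamiliar: a register predicts a dense flow from the image pair and warps the moving image toward the fixed one; the warped result is the `moved` patch. The toy module below is only a sketch of that idea, **not** the actual micro-register or unet backbone; all names are illustrative.

```python
# Toy registration sketch (not the ReCo micro-register): predict a flow from the
# (moving, fixed) pair and warp `moving` with it to obtain `moved`.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyRegister(nn.Module):
    def __init__(self):
        super().__init__()
        # tiny conv net: two stacked grayscale images in, 2-channel flow out
        self.flow_net = nn.Sequential(
            nn.Conv2d(2, 16, 3, padding=1), nn.ReLU(),
            nn.Conv2d(16, 2, 3, padding=1),
        )

    def forward(self, moving, fixed):
        flow = self.flow_net(torch.cat([moving, fixed], dim=1))  # (B, 2, H, W)
        b, _, h, w = moving.shape
        ys, xs = torch.meshgrid(torch.linspace(-1, 1, h, device=moving.device),
                                torch.linspace(-1, 1, w, device=moving.device), indexing='ij')
        grid = torch.stack((xs, ys), dim=-1).unsqueeze(0).expand(b, -1, -1, -1)  # identity grid
        moved = F.grid_sample(moving, grid + flow.permute(0, 2, 3, 1), align_corners=True)
        return moved, flow
```

A training loss would typically compare `moved` against `fixed` and add a smoothness penalty on `flow`.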

## Get started (v0) (**currently recommended**)

1. To use our ECCV 2022 pre-trained parameters for fusion, you need to prepare your dataset in `$ROOT/data/$DATA`.

```
$DATA (dataset name, like: tno)
├── ir
├── vi
```

2. Enter the archive folder `cd archive`, and activate your conda environment `conda activate $CONDA_ENV`.

```shell
export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
python fuse.py --ir ../data/$DATA/ir --vi ../data/$DATA/vi --dst $SAVE_TO_WHERE
```

3. You will now find the fusion results in `$SAVE_TO_WHERE`; the output folder is created automatically.

## **Building:** ~~Get started (v1)~~

**Only recommended if you intend to train ReCo+ yourself.**

**Note: because the micro-register module may still change, we currently recommend training only the fusion part.**

1. To use the script to train ReCo+ yourself, you need to prepare your dataset in `$ROOT/data/$DATA`.

```
$DATA (dataset name, like: tno)
├── ir
├── vi
├── iqa (new for v1, optional)
│   ├── ir (information measurement for infrared images)
│   └── vi (information measurement for visible images)
└── meta (new for v1)
    ├── train.txt (which images are used for training)
    ├── val.txt (which images are used for validation)
    └── pred.txt (which images are used for prediction)
```

2. Activate your conda environment `conda activate $CONDA_ENV`.

```shell
# only train fuse part (ReCo) **current recommended**
python train.py --register x --data data/$DATA --ckpt $CHECKPOINT_PATH --lr 1e-3
# train registration and fuse (ReCo+)
python train.py --register m --data data/$DATA --ckpt $CHECKPOINT_PATH --lr 1e-3 --deform $DEFORM_LEVEL
```

The `$DEFORM_LEVEL` should be `easy`, `normal` or `hard`.
Empty file added archive/__init__.py
Empty file.
128 changes: 128 additions & 0 deletions archive/fuse.py
@@ -0,0 +1,128 @@
import argparse
import pathlib
import statistics
import time

import cv2
import kornia
import torch
import torch.backends.cudnn
from tqdm import tqdm

from archive.model import Fuser


class Fuse:
    """
    Fuse images with given args.
    """

    def __init__(self, checkpoint: pathlib.Path, loop_num: int = 3, dim: int = 64):
        """
        Init model and load pre-trained parameters.
        :param checkpoint: pre-trained model checkpoint
        :param loop_num: AFuse recurrent loop number, default: 3
        :param dim: AFuse feather number, default: 64
        """

        # device [cuda or cpu]
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        # load pre-trained network
        net = Fuser(loop_num=loop_num, feather_num=dim)
        net.load_state_dict(torch.load(str(checkpoint), map_location='cpu'))
        net.to(device)
        net.eval()
        self.net = net

    @torch.no_grad()
    def __call__(self, ir_path: pathlib.Path, vi_path: pathlib.Path, dst: pathlib.Path):
        """
        Fuse image with infrared folder, visible folder and destination path.
        :param ir_path: infrared folder path
        :param vi_path: visible folder path
        :param dst: fused images destination path
        """

        # src list
        ir_list = [x for x in ir_path.glob('*') if x.suffix in ['.bmp', '.jpg', '.png']]
        vi_list = [x for x in vi_path.glob('*') if x.suffix in ['.bmp', '.jpg', '.png']]

        # time record
        fuse_time = []

        # fuse images
        src = tqdm(zip(ir_list, vi_list))
        for ir_path, vi_path in src:
            # fuse one pair with src image path

            # judge image pair
            assert ir_path.name == vi_path.name
            src.set_description(f'fuse {ir_path.name}')

            # read image with Tensor
            ir = self._imread(ir_path).unsqueeze(0)
            vi = self._imread(vi_path).unsqueeze(0)
            ir = ir.to(self.device)
            vi = vi.to(self.device)

            # network flow
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            im_f, _, _ = self.net([ir, vi])
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            end = time.time()
            fuse_time.append(end - start)

            # save fusion image
            self._imsave(dst / ir_path.name, im_f[-1])

        # analyze fuse time
        std = statistics.stdev(fuse_time[1:])
        avg = statistics.mean(fuse_time[1:])
        print(f'fuse std time: {std:.4f}(s)')
        print(f'fuse avg time: {avg:.4f}(s)')
        print('fps (equivalence): {:.4f}'.format(1. / avg))

    @staticmethod
    def _imread(path: pathlib.Path, flags=cv2.IMREAD_GRAYSCALE) -> torch.Tensor:
        im_cv = cv2.imread(str(path), flags)
        im_ts = kornia.utils.image_to_tensor(im_cv / 255.0).type(torch.FloatTensor)
        return im_ts

    @staticmethod
    def _imsave(path: pathlib.Path, image: torch.Tensor):
        im_ts = image.squeeze().cpu()
        path.parent.mkdir(parents=True, exist_ok=True)
        im_cv = kornia.utils.tensor_to_image(im_ts) * 255.
        cv2.imwrite(str(path), im_cv)


def hyper_args():
    """
    get hyper parameters from args
    """

    parser = argparse.ArgumentParser(description='ReCo(v0) fuse process')

    # dataset
    parser.add_argument('--ir', default='../data/tno/ir', help='infrared image folder')
    parser.add_argument('--vi', default='../data/tno/vi', help='visible image folder')
    parser.add_argument('--dst', default='../runs/archive', help='fuse image save folder')
    # checkpoint
    parser.add_argument('--cp', default='params.pth', help='weight checkpoint')
    # fuse network
    parser.add_argument('--loop', default=3, type=int, help='fuse loop time')
    parser.add_argument('--dim', default=64, type=int, help='fuse feather dim')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # hyper parameters
    args = hyper_args()

    f = Fuse(checkpoint=pathlib.Path(args.cp), loop_num=args.loop, dim=args.dim)
    f(pathlib.Path(args.ir), pathlib.Path(args.vi), pathlib.Path(args.dst))
86 changes: 86 additions & 0 deletions archive/model.py
@@ -0,0 +1,86 @@
import torch
import torch.nn as nn


class Fuser(nn.Module):
    """
    Fuse the two input images.
    """

    def __init__(self, loop_num=3, feather_num=64, fine_tune=False):
        super().__init__()
        self.loop_num = loop_num
        self.fine_tune = fine_tune

        # attention layer
        self.att_a_conv = nn.Conv2d(2, 1, 3, padding=1, bias=False)
        self.att_b_conv = nn.Conv2d(2, 1, 3, padding=1, bias=False)

        # dilation conv layer
        self.dil_conv_1 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 1, 1), nn.BatchNorm2d(feather_num), nn.ReLU())
        self.dil_conv_2 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 2, 2), nn.BatchNorm2d(feather_num), nn.ReLU())
        self.dil_conv_3 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 3, 3), nn.BatchNorm2d(feather_num), nn.ReLU())

        # fuse conv layer
        self.fus_conv = nn.Sequential(nn.Conv2d(3 * feather_num, 1, 3, padding=1), nn.BatchNorm2d(1), nn.Tanh())

    def forward(self, im_p):
        """
        :param im_p: image pair
        """

        # unpack im_p
        im_a, im_b = im_p

        # recurrent sub network
        # generate f_0 with manual function
        im_f = [torch.max(im_a, im_b)]  # init im_f_0
        att_a = []
        att_b = []

        # loop in sub network
        for e in range(self.loop_num):
            im_f_x, att_a_x, att_b_x = self._sub_forward(im_a, im_b, im_f[-1])
            im_f.append(im_f_x)
            att_a.append(att_a_x)
            att_b.append(att_b_x)

        # return im_f, att list
        return im_f, att_a, att_b

    def _sub_forward(self, im_a, im_b, im_f):
        # attention
        att_a = self._attention(self.att_a_conv, im_a, im_f)
        att_b = self._attention(self.att_b_conv, im_b, im_f)
        att_a = att_a.detach() if self.fine_tune else att_a
        att_b = att_b.detach() if self.fine_tune else att_b

        # focus on attention
        im_a_att = im_a * att_a
        im_b_att = im_b * att_b

        # image concat
        im_cat = torch.cat([im_a_att, im_f, im_b_att], dim=1)
        im_cat = im_cat.detach() if self.fine_tune else im_cat

        # dilation
        dil_1 = self.dil_conv_1(im_cat)
        dil_2 = self.dil_conv_2(im_cat)
        dil_3 = self.dil_conv_3(im_cat)

        # feather concat
        f_cat = torch.cat([dil_1, dil_2, dil_3], dim=1)

        # fuse
        im_f_n = self.fus_conv(f_cat)

        return im_f_n, att_a, att_b

    @staticmethod
    def _attention(att_conv, im_x, im_f):
        x = torch.cat([im_x, im_f], dim=1)
        x_max, _ = torch.max(x, dim=1, keepdim=True)
        x_avg = torch.mean(x, dim=1, keepdim=True)
        x = torch.cat([x_max, x_avg], dim=1)
        x = att_conv(x)
        return torch.sigmoid(x)
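
For reference, a minimal usage sketch of the `Fuser` above (input shapes assumed to be single-channel image batches, as produced by `fuse.py`; the random tensors are placeholders):

```python
import torch
from archive.model import Fuser

fuser = Fuser(loop_num=3, feather_num=64).eval()
ir = torch.rand(1, 1, 256, 256)  # infrared batch (B, 1, H, W)
vi = torch.rand(1, 1, 256, 256)  # visible batch (B, 1, H, W)
with torch.no_grad():
    im_f, att_a, att_b = fuser([ir, vi])  # lists of intermediate fusions and attentions
print(im_f[-1].shape)  # final fused image: torch.Size([1, 1, 256, 256])
```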
Binary file added archive/params.pth
Binary file not shown.
Empty file added exp/__init__.py
Empty file.
Empty file added exp/find_adjust/__init__.py
Empty file.
(remaining changed files not shown)
