forked from haku-huang/ReCoNet
Commit 301cddf (root commit, 0 parents): 36 changed files with 1,517 additions and 0 deletions.
New workflow file, `Mirror to DUT DIMT` (+28 lines), presumably stored under `.github/workflows/`:

```yaml
name: Mirror to DUT DIMT

on: [push, delete, create]

jobs:
  git-mirror:
    runs-on: ubuntu-latest
    steps:
      - name: Configure Private Key
        env:
          SSH_PRIVATE_KEY: ${{ secrets.PRIVATE_KEY }}
        run: |
          mkdir -p ~/.ssh
          echo "$SSH_PRIVATE_KEY" > ~/.ssh/id_rsa
          chmod 600 ~/.ssh/id_rsa
          echo "StrictHostKeyChecking no" >> ~/.ssh/config
      - name: Push Mirror
        env:
          SOURCE_REPO: 'https://github.com/MisakiCoca/ReCoNet.git'
          DESTINATION_REPO: '[email protected]:dlut-dimt/ReCoNet.git'
        run: |
          git clone --mirror "$SOURCE_REPO" && cd `basename "$SOURCE_REPO"`
          git remote set-url --push origin "$DESTINATION_REPO"
          git fetch -p origin
          git for-each-ref --format 'delete %(refname)' refs/pull | git update-ref --stdin
          git push --mirror
```
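For reference, the `PRIVATE_KEY` secret the first step reads can be populated with the GitHub CLI. This is a sketch, assuming `gh` is authenticated against this repository and a dedicated deploy key already exists (the key filename is illustrative):

```shell
# store the mirror deploy key as the repository secret the workflow expects
gh secret set PRIVATE_KEY < ~/.ssh/id_rsa_mirror
```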
New file, `.gitignore` (+8 lines):

```gitignore
# JetBrains
.idea/*

# macOS
.DS_*/*

# data
data/*
```
New file, `LICENSE` (+21 lines):

```text
MIT License

Copyright (c) 2022 Zhanbo Huang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```
New file, `README.md` (+121 lines):
# ReCoNet



Zhanbo Huang, Jinyuan Liu, Xin Fan*, Risheng Liu, Wei Zhong, Zhongxuan Luo.
**"Recurrent Correction Network for Fast and Efficient Multi-modality Image Fusion"**, European Conference on Computer
Vision **(ECCV)**, 2022.

## Milestone

In the near future, we will publish the following materials.

* v0 [ECCV]: Fuse network (ReCo) with pre-trained parameters for generating the results in the paper. **Finished**
* v1: A new script & architecture of ReCo+ for fast training & prediction. **Building**
* v1: Highly robust pre-trained parameters for ReCo+ based on realistic-scene training. (We are collecting realistic
  scene data.)

## Update

[2022-07-13] A preview of the micro-register is available!

[2022-07-12] ReCo (v0) is available!

## Requirements

* Python 3.10
* PyTorch 1.12
* TorchVision 0.13.0
* PyTorch Lightning 0.8.5
* Kornia 0.6.5

## Extended Experiments

### Generate fake visible images

To generate fake visible images as described in our paper, you can refer to another repository of
mine, [complex-deformation](https://github.com/MisakiCoca/complex-deformation), which is a component of this work.

It shows how we can deform an image and generate a restored field that **approximates** the ground truth.

### Have a quick preview of our micro-register

For a quick preview of our micro-register module, you can try training & prediction on
the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset.

Activate your conda environment and enter the folder `exp/test_register`.

1. To train the register yourself, just run:

```shell
export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
python test_register.py --backbone $BACKBONE --dst $DST
```

`$RECO_ROOT` is the root path of the ReCo repository (e.g. `~/lab/reco`), and `$BACKBONE` selects the architecture:
`m` (micro) or `u` (unet). A concrete invocation is sketched below.
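For example, training the micro variant and writing predictions to a local folder might look like this (the paths here are illustrative, not prescribed by the repo):

```shell
export PYTHONPATH="${PYTHONPATH}:$HOME/lab/reco"
python test_register.py --backbone m --dst runs/register_preview
```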
We will automatically download the MNIST dataset, train the register, and save predictions in `$DST`.

2. If you just want to test the performance, we offer pre-trained parameters for both the `micro`- and `unet`-based
   registers.

```shell
export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
python test_register.py --backbone $BACKBONE --dst $DST --only_pred
```

The prediction results will be saved in `$DST`; the patches from left to right are `moving`, `fixed` and `moved`,
respectively.

## Getting started (v0) (**currently recommended**)

1. To use our pre-trained ECCV-22 parameters for fusion, prepare your dataset in `$ROOT/data/$NAME`, laid out as
   below (a shell sketch for creating this layout follows the tree).

```
$DATA (dataset name, like: tno)
├── ir
├── vi
```
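A minimal sketch for laying out the expected folders, assuming a dataset named `tno` (the images themselves come from your own data; `fuse.py` below requires matching filenames in `ir/` and `vi/`):

```shell
mkdir -p data/tno/ir data/tno/vi
# copy paired infrared/visible images with matching filenames into ir/ and vi/
```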
2. Enter the archive folder `cd archive`, and activate your conda environment `conda activate $CONDA_ENV`.

```shell
export PYTHONPATH="${PYTHONPATH}:$RECO_ROOT"
python fuse.py --ir ../data/$DATA/ir --vi ../data/$DATA/vi --dst $SAVE_TO_WHERE
```

3. You will now find the fusion results in `$SAVE_TO_WHERE`; the output folder is created automatically.

## **Building:** ~~Getting started (v1)~~

**Only recommended if you intend to train ReCo+ yourself.**

**Note: since the micro-register module is still unstable, we recommend training only the fusion part for now.**

1. To train ReCo+ yourself with this script, prepare your dataset in `$ROOT/data/$NAME` (a hypothetical helper for
   generating the `meta` split files is sketched after the tree).

```
$DATA (dataset name, like: tno)
├── ir
├── vi
├── iqa (new for v1, optional)
|   ├── ir (information measurement for infrared images)
|   ├── vi (information measurement for visible images)
├── meta (new for v1)
|   ├── train.txt (which images are used for training)
|   ├── val.txt (which images are used for validation)
|   ├── pred.txt (which images are used for prediction)
```
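The commit does not show how the `meta/*.txt` files are produced. Below is a hypothetical helper (not part of the repo) that writes the three split files from the `ir` folder listing, assuming one image name per line and an illustrative 90/10 train/val split:

```python
import pathlib
import random

data = pathlib.Path('data/tno')
names = sorted(p.name for p in (data / 'ir').glob('*') if p.suffix in {'.bmp', '.jpg', '.png'})
random.seed(0)
random.shuffle(names)

n_val = max(1, len(names) // 10)  # hypothetical split ratio
splits = {'train': names[n_val:], 'val': names[:n_val], 'pred': names}

meta = data / 'meta'
meta.mkdir(parents=True, exist_ok=True)
for split, items in splits.items():
    (meta / f'{split}.txt').write_text('\n'.join(items) + '\n')
```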
2. Activate your conda environment `conda activate $CONDA_ENV`.

```shell
# only train the fuse part (ReCo), **currently recommended**
python train.py --register x --data data/$DATA --ckpt $CHECKPOINT_PATH --lr 1e-3
# train registration and fusion (ReCo+)
python train.py --register m --data data/$DATA --ckpt $CHECKPOINT_PATH --lr 1e-3 --deform $DEFORM_LEVEL
```

`$DEFORM_LEVEL` should be `easy`, `normal` or `hard`.
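For instance, a registration-and-fusion run on the `tno` layout above might be invoked as follows (the checkpoint path is illustrative):

```shell
python train.py --register m --data data/tno --ckpt runs/reco_plus --lr 1e-3 --deform normal
```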
New file, `archive/fuse.py` (+128 lines):
```python
import argparse
import pathlib
import statistics
import time

import cv2
import kornia
import torch
import torch.backends.cudnn
from tqdm import tqdm

from archive.model import Fuser


class Fuse:
    """
    Fuse images with given args.
    """

    def __init__(self, checkpoint: pathlib.Path, loop_num: int = 3, dim: int = 64):
        """
        Init model and load pre-trained parameters.
        :param checkpoint: pre-trained model checkpoint
        :param loop_num: AFuse recurrent loop number, default: 3
        :param dim: AFuse feature number, default: 64
        """

        # device [cuda or cpu]
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.device = device

        # load pre-trained network
        net = Fuser(loop_num=loop_num, feather_num=dim)
        net.load_state_dict(torch.load(str(checkpoint), map_location='cpu'))
        net.to(device)
        net.eval()
        self.net = net

    @torch.no_grad()
    def __call__(self, ir_path: pathlib.Path, vi_path: pathlib.Path, dst: pathlib.Path):
        """
        Fuse image with infrared folder, visible folder and destination path.
        :param ir_path: infrared folder path
        :param vi_path: visible folder path
        :param dst: fused images destination path
        """

        # src lists, sorted so infrared and visible images pair up by name
        ir_list = sorted(x for x in ir_path.glob('*') if x.suffix in ['.bmp', '.jpg', '.png'])
        vi_list = sorted(x for x in vi_path.glob('*') if x.suffix in ['.bmp', '.jpg', '.png'])

        # time record
        fuse_time = []

        # fuse one pair at a time
        src = tqdm(zip(ir_list, vi_list))
        for ir_path, vi_path in src:
            # judge image pair
            assert ir_path.name == vi_path.name
            src.set_description(f'fuse {ir_path.name}')

            # read images as tensors
            ir = self._imread(ir_path).unsqueeze(0)
            vi = self._imread(vi_path).unsqueeze(0)
            ir = ir.to(self.device)
            vi = vi.to(self.device)

            # network flow (synchronize around the call for accurate GPU timing)
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            start = time.time()
            im_f, _, _ = self.net([ir, vi])
            if self.device.type == 'cuda':
                torch.cuda.synchronize()
            end = time.time()
            fuse_time.append(end - start)

            # save fusion image (the last element of im_f is the final output)
            self._imsave(dst / ir_path.name, im_f[-1])

        # analyze fuse time, dropping the first run (warm-up cost)
        std = statistics.stdev(fuse_time[1:])
        avg = statistics.mean(fuse_time[1:])
        print(f'fuse std time: {std:.4f}(s)')
        print(f'fuse avg time: {avg:.4f}(s)')
        print(f'fps (equivalence): {1. / avg:.4f}')

    @staticmethod
    def _imread(path: pathlib.Path, flags=cv2.IMREAD_GRAYSCALE) -> torch.Tensor:
        im_cv = cv2.imread(str(path), flags)
        im_ts = kornia.utils.image_to_tensor(im_cv / 255.0).type(torch.FloatTensor)
        return im_ts

    @staticmethod
    def _imsave(path: pathlib.Path, image: torch.Tensor):
        im_ts = image.squeeze().cpu()
        path.parent.mkdir(parents=True, exist_ok=True)
        im_cv = kornia.utils.tensor_to_image(im_ts) * 255.
        cv2.imwrite(str(path), im_cv)


def hyper_args():
    """
    Get hyperparameters from args.
    """

    parser = argparse.ArgumentParser(description='ReCo(v0) fuse process')

    # dataset
    parser.add_argument('--ir', default='../data/tno/ir', help='infrared image folder')
    parser.add_argument('--vi', default='../data/tno/vi', help='visible image folder')
    parser.add_argument('--dst', default='../runs/archive', help='fuse image save folder')
    # checkpoint
    parser.add_argument('--cp', default='params.pth', help='weight checkpoint')
    # fuse network
    parser.add_argument('--loop', default=3, type=int, help='fuse loop time')
    parser.add_argument('--dim', default=64, type=int, help='fuse feature dim')

    args = parser.parse_args()
    return args


if __name__ == '__main__':
    # hyper parameters
    args = hyper_args()

    f = Fuse(checkpoint=pathlib.Path(args.cp), loop_num=args.loop, dim=args.dim)
    f(pathlib.Path(args.ir), pathlib.Path(args.vi), pathlib.Path(args.dst))
```
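With the argparse defaults above, a run from inside `archive/` (per the README layout) reduces to:

```shell
python fuse.py --ir ../data/tno/ir --vi ../data/tno/vi --dst ../runs/archive --cp params.pth
```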
New file, `archive/model.py` (+86 lines):
```python
import torch
import torch.nn as nn


class Fuser(nn.Module):
    """
    Fuse the two input images.
    """

    def __init__(self, loop_num=3, feather_num=64, fine_tune=False):
        super().__init__()
        self.loop_num = loop_num
        self.fine_tune = fine_tune

        # attention layers, one per modality
        self.att_a_conv = nn.Conv2d(2, 1, 3, padding=1, bias=False)
        self.att_b_conv = nn.Conv2d(2, 1, 3, padding=1, bias=False)

        # dilated conv layers (dilation rates 1, 2, 3)
        self.dil_conv_1 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 1, 1), nn.BatchNorm2d(feather_num), nn.ReLU())
        self.dil_conv_2 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 2, 2), nn.BatchNorm2d(feather_num), nn.ReLU())
        self.dil_conv_3 = nn.Sequential(nn.Conv2d(3, feather_num, 3, 1, 3, 3), nn.BatchNorm2d(feather_num), nn.ReLU())

        # fuse conv layer
        self.fus_conv = nn.Sequential(nn.Conv2d(3 * feather_num, 1, 3, padding=1), nn.BatchNorm2d(1), nn.Tanh())

    def forward(self, im_p):
        """
        :param im_p: image pair
        """

        # unpack im_p
        im_a, im_b = im_p

        # recurrent sub-network
        # generate f_0 with a manual function (element-wise max)
        im_f = [torch.max(im_a, im_b)]  # init im_f_0
        att_a = []
        att_b = []

        # loop in sub-network
        for _ in range(self.loop_num):
            im_f_x, att_a_x, att_b_x = self._sub_forward(im_a, im_b, im_f[-1])
            im_f.append(im_f_x)
            att_a.append(att_a_x)
            att_b.append(att_b_x)

        # return im_f and attention lists
        return im_f, att_a, att_b

    def _sub_forward(self, im_a, im_b, im_f):
        # attention
        att_a = self._attention(self.att_a_conv, im_a, im_f)
        att_b = self._attention(self.att_b_conv, im_b, im_f)
        att_a = att_a.detach() if self.fine_tune else att_a
        att_b = att_b.detach() if self.fine_tune else att_b

        # focus on attention
        im_a_att = im_a * att_a
        im_b_att = im_b * att_b

        # image concat
        im_cat = torch.cat([im_a_att, im_f, im_b_att], dim=1)
        im_cat = im_cat.detach() if self.fine_tune else im_cat

        # dilation
        dil_1 = self.dil_conv_1(im_cat)
        dil_2 = self.dil_conv_2(im_cat)
        dil_3 = self.dil_conv_3(im_cat)

        # feature concat
        f_cat = torch.cat([dil_1, dil_2, dil_3], dim=1)

        # fuse
        im_f_n = self.fus_conv(f_cat)

        return im_f_n, att_a, att_b

    @staticmethod
    def _attention(att_conv, im_x, im_f):
        # spatial attention: channel-wise max & mean, then conv + sigmoid
        x = torch.cat([im_x, im_f], dim=1)
        x_max, _ = torch.max(x, dim=1, keepdim=True)
        x_avg = torch.mean(x, dim=1, keepdim=True)
        x = torch.cat([x_max, x_avg], dim=1)
        x = att_conv(x)
        return torch.sigmoid(x)
```
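As a quick shape check, the following hypothetical snippet (not part of the commit) runs a random single-channel image pair through `Fuser`; with `loop_num=3`, `im_f` holds the initial max-fusion plus three refined outputs:

```python
import torch
from archive.model import Fuser

net = Fuser(loop_num=3, feather_num=64).eval()
ir = torch.rand(1, 1, 256, 256)  # single-channel infrared batch
vi = torch.rand(1, 1, 256, 256)  # single-channel visible batch
with torch.no_grad():
    im_f, att_a, att_b = net([ir, vi])
print(len(im_f), im_f[-1].shape)  # 4, torch.Size([1, 1, 256, 256])
```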
The commit also includes a binary file (not shown) and two empty files.