forked from antabangun/coex
Commit edac882 (parent 0c29bdf)
Showing 10 changed files with 1,165 additions and 0 deletions.
@@ -0,0 +1,121 @@
# CoEx

PyTorch implementation of our paper:

**Correlate-and-Excite: Real-Time Stereo Matching via Guided Cost Volume Excitation**
*Authors: Antyanta Bangunharcana, Jae Won Cho, Seokju Lee, In So Kweon, Kyung-Soo Kim, Soohyun Kim*
IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS), 2021

\[[Project page](https://antabangun.github.io/projects/CoEx/)\]

We propose Guided Cost volume Excitation (GCE) and top-k soft-argmax disparity regression for real-time and accurate stereo matching.
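As a rough illustration of the regression step (not the code used in this repository; the function name, `k`, and tensor shapes below are only for the sketch): the k best matching scores are kept along the disparity dimension, the softmax is taken over those candidates only, and the disparity is the probability-weighted sum of their indices. Restricting the softmax to the top-k candidates keeps the regression concentrated around the best match instead of blending all disparity hypotheses.

```python
import torch
import torch.nn.functional as F

def topk_soft_argmax(cost: torch.Tensor, k: int = 2) -> torch.Tensor:
    """Top-k soft-argmax disparity regression (illustrative sketch).

    cost: matching scores of shape (B, D, H, W); higher means a better match.
    Returns a disparity map of shape (B, H, W).
    """
    # Keep only the k best candidates per pixel along the disparity axis.
    topk_vals, topk_idx = torch.topk(cost, k, dim=1)   # (B, k, H, W)
    # Softmax over the surviving candidates only.
    prob = F.softmax(topk_vals, dim=1)                 # (B, k, H, W)
    # Expected disparity over the top-k candidate indices.
    return torch.sum(prob * topk_idx.float(), dim=1)   # (B, H, W)

# Example: random cost volume with 192 disparity hypotheses.
cost = torch.randn(1, 192, 64, 128)
print(topk_soft_argmax(cost, k=2).shape)  # torch.Size([1, 64, 128])
```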
## Contents
- [Installation](#installation)
- [Datasets](#datasets)
  - [Data for demo](#data-for-demo)
  - [If you want to re-train the models](#if-you-want-to-re-train-the-models)
  - [Data directories](#data-directories)
- [Demo on KITTI raw data](#demo-on-kitti-raw-data)
- [Model zoo](#model-zoo)
- [Re-training the model](#re-training-the-model)
## Installation

We recommend using [conda](https://www.anaconda.com/distribution/) for installation:
```bash
conda env create -f environment.yml
conda activate coex
```
## Datasets

### Data for demo

For a demo of our code on the KITTI dataset, download the "\[synced+rectified data\]" from the [raw KITTI data](http://www.cvlibs.net/datasets/kitti/raw_data.php) page. Unzip and place the extracted folders following the directory tree below.

### If you want to re-train the models

**Sceneflow dataset**
Download the *finalpass* data of the [Sceneflow dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) as well as the *Disparity* data.

**KITTI 2015**
Download the [KITTI 2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) dataset, unzip data_scene_flow.zip, rename the folder to kitti15, and move it into the SceneFlow directory as shown in the tree below.

**KITTI 2012**
Download the [KITTI 2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo) dataset, unzip data_stereo_flow.zip, rename the folder to kitti12, and move it into the SceneFlow directory as shown in the tree below.

Make sure the directory names match the tree below so that the dataloaders can locate the files; a quick sanity check is sketched after the tree.
### Data directories

In our setup, the dataset is organized as follows:
```
../../data
└── datasets
    ├── KITTI_raw
    |   ├── 2011_09_26
    |   │   ├── 2011_09_26_drive_0001_sync
    |   │   ├── 2011_09_26_drive_0002_sync
    |   |   :
    |   |
    |   ├── 2011_09_28
    |   │   ├── 2011_09_28_drive_0001_sync
    |   │   └── 2011_09_28_drive_0002_sync
    |   |   :
    |   |
    └── SceneFlow
        ├── driving
        │   ├── disparity
        │   └── frames_finalpass
        ├── flyingthings3d_final
        │   ├── disparity
        │   └── frames_finalpass
        ├── monkaa
        │   ├── disparity
        │   └── frames_finalpass
        ├── kitti12
        │   ├── testing
        │   └── training
        └── kitti15
            ├── testing
            └── training
```
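If the dataloaders fail to find your data, a quick check of the directory names can save time. This is an illustrative snippet, not part of the repository; adjust `data_root` to wherever your datasets live:

```python
from pathlib import Path

# Root that holds the datasets folder; adjust to your setup.
data_root = Path("../../data/datasets")

expected = [
    "KITTI_raw",
    "SceneFlow/driving/frames_finalpass",
    "SceneFlow/driving/disparity",
    "SceneFlow/flyingthings3d_final/frames_finalpass",
    "SceneFlow/flyingthings3d_final/disparity",
    "SceneFlow/monkaa/frames_finalpass",
    "SceneFlow/monkaa/disparity",
    "SceneFlow/kitti12/training",
    "SceneFlow/kitti15/training",
]

for rel in expected:
    path = data_root / rel
    print(f"{'OK     ' if path.is_dir() else 'MISSING'} {path}")
```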
## Demo on KITTI raw data
The pretrained KITTI model is already included in './logs'.
Run
```bash
python demo.py
```
to perform stereo matching on a raw KITTI sequence. Here is an example result on our system with an RTX 2080 Ti on Ubuntu 18.04.

<p align="center">
  <img width="422" height="223" src="./imgs/coex_compress.gif" data-zoomable>
</p>
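The drive used by the demo is selected by a few module-level variables near the top of the demo script included in this commit; edit them to run on a different sequence from your KITTI_raw folder (the values shown are the script's defaults, the comments are explanatory):

```python
# Near the top of the demo script
config = 'cfg_coex.yaml'   # stereo config under ./configs/stereo/
version = 0                # checkpoint version under ./logs/stereo/<model name>/
vid_date = "2011_09_26"    # KITTI raw recording date
vid_num = '0093'           # drive number within that date
half_precision = True      # run inference under AMP (FP16) autocast
```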
For more demo results, check out our [Project](https://antabangun.github.io/projects/CoEx/#demo) page.

## Re-training the model
To re-train the model, configure the YAML file in './configs/stereo/' (e.g., batch_size, paths, number of devices, precision). Then run
```bash
python stereo.py
```
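For reference, the configs are plain YAML files, and the demo script merges the backbone config into the stereo config before building the model. A minimal sketch of reading a config the same way (shown with the demo's 'cfg_coex.yaml'; the training entry point may use a different file under the same directory):

```python
from ruamel.yaml import YAML

def load_configs(path):
    # Load the stereo config, then merge in the backbone config it points to.
    cfg = YAML().load(open(path, 'r'))
    backbone_cfg = YAML().load(
        open(cfg['model']['stereo']['backbone']['cfg_path'], 'r'))
    cfg['model']['stereo']['backbone'].update(backbone_cfg)
    return cfg

cfg = load_configs('./configs/stereo/cfg_coex.yaml')
print(cfg['model']['name'])  # model name used for the ./logs/stereo/<name> directory
```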
## Citation

If you find our work useful in your research, please consider citing our paper:

    @inproceedings{bangunharcana2021correlate,
      title={Correlate-and-Excite: Real-Time Stereo Matching via Guided Cost Volume Excitation},
      author={Bangunharcana, Antyanta and Cho, Jae Won and Lee, Seokju and Kweon, In So and Kim, Kyung-Soo and Kim, Soohyun},
      booktitle={2021 IEEE/RSJ International Conference on Intelligent Robots and Systems (IROS)},
      pages={3542--3548},
      year={2021},
      organization={IEEE}
    }

## Acknowledgements

Part of the code is adapted from previous works: [PSMNet](https://github.com/JiaRenChang/PSMNet), [AANet](https://github.com/haofeixu/aanet), [GANet](https://github.com/feihuzhang/GANet), [SpixelFCN](https://github.com/fuy34/superpixel_fcn)
@@ -0,0 +1,106 @@
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from ruamel.yaml import YAML

from dataloaders import KITTIRawLoader as KRL

from stereo import Stereo

torch.backends.cudnn.benchmark = True


torch.set_grad_enabled(False)

config = 'cfg_coex.yaml'
version = 0  # CoEx

vid_date = "2011_09_26"
vid_num = '0093'
half_precision = True


def load_configs(path):
    # Load the stereo config and merge in the backbone config it references.
    cfg = YAML().load(open(path, 'r'))
    backbone_cfg = YAML().load(
        open(cfg['model']['stereo']['backbone']['cfg_path'], 'r'))
    cfg['model']['stereo']['backbone'].update(backbone_cfg)
    return cfg


if __name__ == '__main__':
    cfg = load_configs(
        './configs/stereo/{}'.format(config))

    # Restore the pretrained stereo model from the Lightning checkpoint.
    ckpt = '{}/{}/version_{}/checkpoints/last.ckpt'.format(
        'logs/stereo', cfg['model']['name'], version)
    cfg['stereo_ckpt'] = ckpt
    pose_ssstereo = Stereo.load_from_checkpoint(cfg['stereo_ckpt'],
                                                strict=False,
                                                cfg=cfg).cuda()

    # Build the dataloader for the selected raw KITTI drive.
    left_cam, right_cam = KRL.listfiles(
        cfg,
        vid_date,
        vid_num,
        True)
    cfg['training']['th'] = 0
    cfg['training']['tw'] = 0
    kitti_train = KRL.ImageLoader(
        left_cam, right_cam, cfg, training=True, demo=True)
    kitti_train = DataLoader(
        kitti_train, batch_size=1,
        num_workers=4, shuffle=False, drop_last=False)

    fps_list = np.array([])

    pose_ssstereo.eval()
    for i, batch in enumerate(kitti_train):
        # Time data preparation (host-to-device transfer) with CUDA events.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

        imgL, imgR = batch['imgL'].cuda(), batch['imgR'].cuda()
        imgLRaw = batch['imgLRaw']
        imgLRaw = imgLRaw.cuda()

        end.record()
        torch.cuda.synchronize()
        runtime = start.elapsed_time(end)
        print('Data Preparation: {:.3f}'.format(runtime))

        # Time the stereo forward pass.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=half_precision):
                # Left and right images are stacked along the batch dimension.
                img = torch.cat([imgL, imgR], 0)
                disp = pose_ssstereo(img, training=False)
        end.record()
        torch.cuda.synchronize()
        runtime = start.elapsed_time(end)
        # print('Stereo runtime: {:.3f}'.format(runtime))

        # Report runtime averaged over the last 5 frames.
        fps = 1000/runtime
        fps_list = np.append(fps_list, fps)
        if len(fps_list) > 5:
            fps_list = fps_list[-5:]
        avg_fps = np.mean(fps_list)
        print('Stereo runtime: {:.3f}'.format(1000/avg_fps))

        # Colorize the disparity map and stack it below the input image.
        disp_np = (2*disp[0]).data.cpu().numpy().astype(np.uint8)
        disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_MAGMA)

        image_np = (imgLRaw[0].permute(1, 2, 0).data.cpu().numpy()).astype(np.uint8)

        out_img = np.concatenate((image_np, disp_np), 0)
        cv2.putText(
            out_img,
            "%.1f fps" % (avg_fps),
            (10, image_np.shape[0]+30),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.imshow('img', out_img)
        cv2.waitKey(1)
@@ -0,0 +1,97 @@
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from ruamel.yaml import YAML

from dataloaders import KITTIRawLoader as KRL

torch.backends.cudnn.benchmark = True


torch.set_grad_enabled(False)

config = 'cfg_coex.yaml'

vid_date = "2011_09_26"
vid_num = '0093'
half_precision = True


def load_configs(path):
    # Load the stereo config and merge in the backbone config it references.
    cfg = YAML().load(open(path, 'r'))
    backbone_cfg = YAML().load(
        open(cfg['model']['stereo']['backbone']['cfg_path'], 'r'))
    cfg['model']['stereo']['backbone'].update(backbone_cfg)
    return cfg


if __name__ == '__main__':
    cfg = load_configs(
        './configs/stereo/{}'.format(config))
    # Load the exported TorchScript model instead of a Lightning checkpoint.
    stereo = torch.jit.load('zoo/torchscript/CoEx.pt')

    # Build the dataloader for the selected raw KITTI drive.
    left_cam, right_cam = KRL.listfiles(
        cfg,
        vid_date,
        vid_num,
        True)
    cfg['training']['th'] = 0
    cfg['training']['tw'] = 0
    kitti_train = KRL.ImageLoader(
        left_cam, right_cam, cfg, training=True, demo=True)
    kitti_train = DataLoader(
        kitti_train, batch_size=1,
        num_workers=4, shuffle=False, drop_last=False)

    fps_list = np.array([])

    stereo.eval()
    for i, batch in enumerate(kitti_train):
        # Time data preparation (host-to-device transfer) with CUDA events.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

        imgL, imgR = batch['imgL'].cuda(), batch['imgR'].cuda()
        imgLRaw = batch['imgLRaw']
        imgLRaw = imgLRaw.cuda()

        end.record()
        torch.cuda.synchronize()
        runtime = start.elapsed_time(end)
        print('Data Preparation: {:.3f}'.format(runtime))

        # Time the stereo forward pass.
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()

        # The TorchScript model takes the left/right pair concatenated along dim 1.
        img = torch.cat((imgL, imgR), 1)
        disp = stereo(img)

        end.record()
        torch.cuda.synchronize()
        runtime = start.elapsed_time(end)
        # print('Stereo runtime: {:.3f}'.format(runtime))

        # Report runtime averaged over the last 5 frames.
        fps = 1000/runtime
        fps_list = np.append(fps_list, fps)
        if len(fps_list) > 5:
            fps_list = fps_list[-5:]
        avg_fps = np.mean(fps_list)
        print('Stereo runtime: {:.3f}'.format(1000/avg_fps))

        # Colorize the disparity map and stack it below the input image.
        disp_np = (2*disp[0]).data.cpu().numpy().astype(np.uint8)
        disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_PLASMA)

        image_np = (imgLRaw[0].permute(1, 2, 0).data.cpu().numpy()).astype(np.uint8)

        out_img = np.concatenate((image_np, disp_np), 0)
        cv2.putText(
            out_img,
            "%.1f fps" % (avg_fps),
            (10, image_np.shape[0]+30),
            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.imshow('img', out_img)
        cv2.waitKey(1)
@@ -0,0 +1,25 @@
name: coex
channels:
  - pytorch
  - nvidia
  - conda-forge
  - anaconda
  - defaults
dependencies:
  - python=3.7
  - pip
  - numpy
  - pytorch==1.11.0
  - torchvision==0.12.0
  - cudatoolkit=11.3
  - ruamel.yaml
  - pillow
  - scikit-image
  - pip:
    - pytorch-lightning==1.6.5
    - opencv-contrib-python
    - albumentations
    - timm==0.6.5
    - test-tube
    # - --find-links https://github.com/pytorch/TensorRT/releases
    # - torch-tensorrt