Skip to content

Commit

Permalink
Torch-TensorRT attempt
Browse files Browse the repository at this point in the history
  • Loading branch information
antabangun committed Jul 27, 2022
1 parent 4f224a8 commit 08baf16
Show file tree
Hide file tree
Showing 35 changed files with 737 additions and 31 deletions.
Binary file modified dataloaders/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/KITTILoader.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/KITTIRawLoader.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/KITTIloader2012.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/KITTIloader2015.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/SceneFlowLoader.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/listflowfile.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/preprocess.cpython-37.pyc
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/readpfm.cpython-37.pyc
Binary file not shown.
Binary file not shown.
Binary file modified dataloaders/stereo/__pycache__/transforms.cpython-37.pyc
Binary file not shown.
5 changes: 3 additions & 2 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,8 @@ def load_configs(path):
start.record()
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=half_precision):
disp = pose_ssstereo(imgL, imgR, False)
img = torch.cat([imgL, imgR], 0)
disp = pose_ssstereo(img, training=False)
end.record()
torch.cuda.synchronize()
runtime = start.elapsed_time(end)
Expand All @@ -91,7 +92,7 @@ def load_configs(path):
print('Stereo runtime: {:.3f}'.format(1000/avg_fps))

disp_np = (2*disp[0]).data.cpu().numpy().astype(np.uint8)
disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_PLASMA)
disp_np = cv2.applyColorMap(disp_np, cv2.COLORMAP_MAGMA)

image_np = (imgLRaw[0].permute(1, 2, 0).data.cpu().numpy()).astype(np.uint8)

Expand Down
30 changes: 30 additions & 0 deletions demo_tensorrt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import cv2
import numpy as np

import torch
from torch.utils.data import DataLoader

from ruamel.yaml import YAML

from dataloaders import KITTIRawLoader as KRL

import torch_tensorrt


config = 'cfg_coex.yaml'
vid_date = "2011_09_26"
vid_num = '0093'


def load_configs(path):
cfg = YAML().load(open(path, 'r'))
backbone_cfg = YAML().load(
open(cfg['model']['stereo']['backbone']['cfg_path'], 'r'))
cfg['model']['stereo']['backbone'].update(backbone_cfg)
return cfg


if __name__ == '__main__':
cfg = load_configs(
'./configs/stereo/{}'.format(config))
stereo = torch.jit.load('zoo/tensorrt/trt_ts_module.ts')
14 changes: 8 additions & 6 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,17 @@ dependencies:
- python=3.7
- pip
- numpy
- pytorch=1.9.0
- torchvision=0.10.0
- cudatoolkit=11.1
- pytorch-lightning=1.4.0
- pytorch==1.11.0
- torchvision==0.12.0
- cudatoolkit=11.3
- ruamel.yaml
- pillow
- scikit-image
- pip:
- pytorch-lightning
- opencv-contrib-python
- albumentations
- timm
- test-tube
- timm==0.6.5
- test-tube
- --find-links https://github.com/pytorch/TensorRT/releases
- torch-tensorrt
Binary file modified models/__pycache__/__init__.cpython-37.pyc
Binary file not shown.
33 changes: 19 additions & 14 deletions models/stereo/CoEx.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,21 +82,26 @@ def __init__(self, cfg):
nn.BatchNorm2d(chans[1]), nn.ReLU()
)

def forward(self, imL, imR, u0=None, v0=None, training=False):
assert imL.shape == imR.shape

def forward(self, imL, imR=None, u0=None, v0=None, training=False):
if imR is not None:
assert imL.shape == imR.shape
imL = torch.cat([imL, imR], 0)

b, c, h, w = imL.shape

# # Matching comp
x2, x = self.feature(imL)
y2, y = self.feature(imR)

x, y = self.up(x, y)

stem_2x = self.stem_2(imL)
stem_4x = self.stem_4(stem_2x)
stem_2y = self.stem_2(imR)
stem_4y = self.stem_4(stem_2y)
v2, v = self.feature(imL)
x2, y2 = v2.split(dim=0, split_size=b//2)

v = self.up(v)
x, y = [], []
for v_ in v:
x_, y_ = v_.split(dim=0, split_size=b//2)
x.append(x_)
y.append(y_)

stem_2v = self.stem_2(imL)
stem_4v = self.stem_4(stem_2v)
stem_2x, stem_2y = stem_2v.split(dim=0, split_size=b//2)
stem_4x, stem_4y = stem_4v.split(dim=0, split_size=b//2)

x[0] = torch.cat((x[0], stem_4x), 1)
y[0] = torch.cat((y[0], stem_4y), 1)
Expand Down
Loading

0 comments on commit 08baf16

Please sign in to comment.