[Feature] Support inference of DSVT in projects (open-mmlab#2606)
* support inference

* align inference precision

* add readme

* polish docs

* polish docs
JingweiZhang12 authored Jun 16, 2023
1 parent 456b740 commit 8fb2cf6
Showing 18 changed files with 2,549 additions and 1 deletion.
10 changes: 10 additions & 0 deletions mmdet3d/datasets/transforms/loading.py
@@ -579,6 +579,8 @@ class LoadPointsFromFile(BaseTransform):
use_color (bool): Whether to use color features. Defaults to False.
norm_intensity (bool): Whether to normalize the intensity. Defaults to
False.
norm_elongation (bool): Whether to normalize the elongation. This is
usually used in the Waymo dataset. Defaults to False.
backend_args (dict, optional): Arguments to instantiate the
corresponding backend. Defaults to None.
"""
@@ -590,6 +592,7 @@ def __init__(self,
shift_height: bool = False,
use_color: bool = False,
norm_intensity: bool = False,
norm_elongation: bool = False,
backend_args: Optional[dict] = None) -> None:
self.shift_height = shift_height
self.use_color = use_color
@@ -603,6 +606,7 @@ def __init__(self,
self.load_dim = load_dim
self.use_dim = use_dim
self.norm_intensity = norm_intensity
self.norm_elongation = norm_elongation
self.backend_args = backend_args

def _load_points(self, pts_filename: str) -> np.ndarray:
@@ -646,6 +650,10 @@ def transform(self, results: dict) -> dict:
assert len(self.use_dim) >= 4, \
f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501
points[:, 3] = np.tanh(points[:, 3])
if self.norm_elongation:
assert len(self.use_dim) >= 5, \
f'When using elongation norm, expect used dimensions >= 5, got {len(self.use_dim)}' # noqa: E501
points[:, 4] = np.tanh(points[:, 4])
attribute_dims = None

if self.shift_height:
@@ -682,6 +690,8 @@ def __repr__(self) -> str:
repr_str += f'backend_args={self.backend_args}, '
repr_str += f'load_dim={self.load_dim}, '
repr_str += f'use_dim={self.use_dim}, '
repr_str += f'norm_intensity={self.norm_intensity}, '
repr_str += f'norm_elongation={self.norm_elongation})'
return repr_str


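As a quick illustration of the two flags above, here is a minimal sketch (the point values are made up; Waymo points in kitti_format are assumed to load as (N, 6) arrays of x, y, z, intensity, elongation, timestamp):

```python
import numpy as np

# Hypothetical (N, 6) Waymo-style point array:
# x, y, z, intensity, elongation, timestamp.
points = np.array([[12.3, -4.1, 0.8, 8.5, 1.7, 0.0]], dtype=np.float32)

# norm_intensity=True squashes the raw intensity into (-1, 1) with tanh:
points[:, 3] = np.tanh(points[:, 3])
# norm_elongation=True applies the same squashing to the elongation column:
points[:, 4] = np.tanh(points[:, 4])
```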
3 changes: 2 additions & 1 deletion mmdet3d/models/necks/second_fpn.py
@@ -74,7 +74,8 @@ def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): 4D Tensor in (N, C, H, W) shape.
x (List[torch.Tensor]): Multi-level features with 4D Tensor in
(N, C, H, W) shape.
Returns:
list[torch.Tensor]: Multi-level feature maps.
81 changes: 81 additions & 0 deletions projects/DSVT/README.md
@@ -0,0 +1,81 @@
# DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets

> [DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets](https://arxiv.org/abs/2301.06051)
<!-- [ALGORITHM] -->

## Abstract

Designing an efficient yet deployment-friendly 3D backbone to handle sparse point clouds is a fundamental problem in 3D perception. Compared with the customized sparse convolution, the attention mechanism in Transformers is more appropriate for flexibly modeling long-range relationships and is easier to deploy in real-world applications. However, due to the sparse characteristics of point clouds, it is non-trivial to apply a standard transformer on sparse points. In this paper, we present Dynamic Sparse Voxel Transformer (DSVT), a single-stride window-based voxel Transformer backbone for outdoor 3D perception. In order to efficiently process sparse points in parallel, we propose Dynamic Sparse Window Attention, which partitions a series of local regions in each window according to its sparsity and then computes the features of all regions in a fully parallel manner. To allow the cross-set connection, we design a rotated set partitioning strategy that alternates between two partitioning configurations in consecutive self-attention layers. To support effective downsampling and better encode geometric information, we also propose an attention-style 3D pooling module on sparse points, which is powerful and deployment-friendly without utilizing any customized CUDA operations. Our model achieves state-of-the-art performance with a broad range of 3D perception tasks. More importantly, DSVT can be easily deployed by TensorRT with real-time inference speed (27Hz). Code will be available at https://github.com/Haiyang-W/DSVT.

<div align=center>
<img src="https://github-production-user-asset-6210df.s3.amazonaws.com/34888372/245692705-e61be20c-2a7d-4ab9-85e3-b36f662c1bdf.png" width="800"/>
</div>

## Introduction

We implement DSVT and provide the results on the Waymo dataset.

## Usage

<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->

### Installation

```shell
pip install torch_scatter==2.0.9
python projects/DSVT/setup.py develop # compile `ingroup_inds_op` cuda operation
```

### Testing commands

In MMDetection3D's root directory, run the following command to test the model:

```bash
python tools/test.py projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py ${CHECKPOINT_PATH}
```
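Multi-GPU testing should also work through the standard distributed script (a sketch, assuming the usual MMDetection3D tooling and 8 GPUs):

```bash
bash tools/dist_test.sh projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py ${CHECKPOINT_PATH} 8
```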

### Training commands

Support for training DSVT is on the way.

## Results and models

### Waymo

| Middle Encoder | Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :------------------------------------------------------------------------------------: | :-----------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :------: |
| [DSVT](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | [ResSECOND](./configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py) | 5 | voxel (0.32) | ✔ | × | | | 75.2 | 72.2 | 68.9 | 66.1 | |

**Note** that `ResSECOND` denotes a SECOND variant whose basic blocks contain residual layers.

## Citation

```latex
@inproceedings{wang2023dsvt,
title={DSVT: Dynamic Sparse Voxel Transformer with Rotated Sets},
author={Wang, Haiyang and Shi, Chen and Shi, Shaoshuai and Lei, Meng and Wang, Sen and He, Di and Schiele, Bernt and Wang, Liwei},
booktitle={CVPR},
year={2023}
}
```
239 changes: 239 additions & 0 deletions projects/DSVT/configs/dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py
@@ -0,0 +1,239 @@
_base_ = ['../../../configs/_base_/default_runtime.py']
custom_imports = dict(
imports=['projects.DSVT.dsvt'], allow_failed_imports=False)

voxel_size = [0.32, 0.32, 6]
grid_size = [468, 468, 1]
point_cloud_range = [-74.88, -74.88, -2, 74.88, 74.88, 4.0]
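# Note (added): the grid size follows from the range and voxel size:
# (74.88 - (-74.88)) / 0.32 = 468 voxels along x and y, and
# (4.0 - (-2.0)) / 6.0 = 1 voxel along z (pillar-style, a single z layer).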
data_root = 'data/waymo/kitti_format/'
class_names = ['Car', 'Pedestrian', 'Cyclist']
metainfo = dict(classes=class_names)
input_modality = dict(use_lidar=True, use_camera=False)
backend_args = None

model = dict(
type='DSVT',
data_preprocessor=dict(type='Det3DDataPreprocessor', voxel=False),
voxel_encoder=dict(
type='DynamicPillarVFE3D',
with_distance=False,
use_absolute_xyz=True,
use_norm=True,
num_filters=[192, 192],
num_point_features=5,
voxel_size=voxel_size,
grid_size=grid_size,
point_cloud_range=point_cloud_range),
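# Note (added): num_point_features=5 matches use_dim=5 in the pipelines
# below, i.e. (x, y, z, intensity, elongation).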
middle_encoder=dict(
type='DSVTMiddleEncoder',
input_layer=dict(
sparse_shape=grid_size,
downsample_stride=[],
dim_model=[192],
set_info=[[36, 4]],
window_shape=[[12, 12, 1]],
hybrid_factor=[2, 2, 1], # x, y, z
shift_list=[[[0, 0, 0], [6, 6, 0]]],
normalize_pos=False),
set_info=[[36, 4]],
dim_model=[192],
dim_feedforward=[384],
stage_num=1,
nhead=[8],
conv_out_channel=192,
output_shape=[468, 468],
dropout=0.,
activation='gelu'),
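# Note (added, hedged reading per the DSVT paper): set_info=[[36, 4]] is
# taken as 36 voxels per dynamic set and 4 attention blocks in the single
# stage; window_shape [12, 12, 1] with shift_list alternating [0, 0, 0]
# and [6, 6, 0] shifts windows between consecutive layers, complementing
# the rotated set partitioning.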
map2bev=dict(
type='PointPillarsScatter3D',
output_shape=grid_size,
num_bev_feats=192),
backbone=dict(
type='ResSECOND',
in_channels=192,
out_channels=[128, 128, 256],
blocks_nums=[1, 2, 2],
layer_strides=[1, 2, 2]),
neck=dict(
type='SECONDFPN',
in_channels=[128, 128, 256],
out_channels=[128, 128, 128],
upsample_strides=[1, 2, 4],
norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01),
upsample_cfg=dict(type='deconv', bias=False),
use_conv_for_no_stride=False),
bbox_head=dict(
type='DSVTCenterHead',
in_channels=sum([128, 128, 128]),
tasks=[dict(num_class=3, class_names=class_names)],
common_heads=dict(
reg=(2, 2), height=(1, 2), dim=(3, 2), rot=(2, 2), iou=(1, 2)),
share_conv_channel=64,
conv_cfg=dict(type='Conv2d'),
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01),
bbox_coder=dict(
type='DSVTBBoxCoder',
pc_range=point_cloud_range,
max_num=500,
post_center_range=[-80, -80, -10.0, 80, 80, 10.0],
score_threshold=0.1,
out_size_factor=1,
voxel_size=voxel_size[:2],
code_size=7),
separate_head=dict(
type='SeparateHead',
init_bias=-2.19,
final_kernel=3,
norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.01)),
loss_cls=dict(
type='mmdet.GaussianFocalLoss', reduction='mean', loss_weight=1.0),
loss_bbox=dict(type='mmdet.L1Loss', reduction='mean', loss_weight=2.0),
norm_bbox=True),
# model training and testing settings
train_cfg=dict(
pts=dict(
grid_size=grid_size,
voxel_size=voxel_size,
out_size_factor=4,
dense_reg=1,
gaussian_overlap=0.1,
max_objs=500,
min_radius=2,
code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0])),
test_cfg=dict(
max_per_img=500,
max_pool_nms=False,
min_radius=[4, 12, 10, 1, 0.85, 0.175],
iou_rectifier=[[0.68, 0.71, 0.65]],
pc_range=[-80, -80],
out_size_factor=4,
voxel_size=voxel_size[:2],
nms_type='rotate',
multi_class_nms=True,
pre_max_size=[[4096, 4096, 4096]],
post_max_size=[[500, 500, 500]],
nms_thr=[[0.7, 0.6, 0.55]]))
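# Note (added, assumed per-class in class_names order): with
# multi_class_nms=True, nms_thr [[0.7, 0.6, 0.55]] and iou_rectifier
# [[0.68, 0.71, 0.65]] appear to apply separately to Car, Pedestrian and
# Cyclist.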

db_sampler = dict(
data_root=data_root,
info_path=data_root + 'waymo_dbinfos_train.pkl',
rate=1.0,
prepare=dict(
filter_by_difficulty=[-1],
filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
classes=class_names,
sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10),
points_loader=dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=[0, 1, 2, 3, 4],
backend_args=backend_args),
backend_args=backend_args)
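# Note (added): `ObjectSample` in the train pipeline below pastes extra
# ground-truth objects drawn from this database into each scene (GT-sampling
# augmentation), subject to the per-class minimum-point filters above.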

train_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True,
backend_args=backend_args),
# Add this if using `MultiFrameDeformableDecoderRPN`
# dict(
# type='LoadPointsFromMultiSweeps',
# sweeps_num=9,
# load_dim=6,
# use_dim=[0, 1, 2, 3, 4],
# pad_empty_sweeps=True,
# remove_close=True),
dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
dict(type='ObjectSample', db_sampler=db_sampler),
dict(
type='GlobalRotScaleTrans',
rot_range=[-0.78539816, 0.78539816],
scale_ratio_range=[0.95, 1.05],
translation_std=[0.5, 0.5, 0]),
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
dict(type='ObjectNameFilter', classes=class_names),
dict(type='PointShuffle'),
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]

test_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=6,
use_dim=5,
norm_intensity=True,
norm_elongation=True,
backend_args=backend_args),
dict(
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
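# Note (added, assumed from the Waymo kitti_format convention): load_dim=6
# corresponds to (x, y, z, intensity, elongation, timestamp); use_dim=5 keeps
# the first five dimensions, so the tanh normalization touches columns 3
# (intensity) and 4 (elongation).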

dataset_type = 'WaymoDataset'
val_dataloader = dict(
batch_size=4,
num_workers=4,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type=dataset_type,
data_root=data_root,
data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'),
ann_file='waymo_infos_val.pkl',
pipeline=test_pipeline,
modality=input_modality,
test_mode=True,
metainfo=metainfo,
box_type_3d='LiDAR',
backend_args=backend_args))
test_dataloader = val_dataloader

val_evaluator = dict(
type='WaymoMetric',
ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl',
waymo_bin_file='./data/waymo/waymo_format/gt.bin',
data_root='./data/waymo/waymo_format',
backend_args=backend_args,
convert_kitti_format=False,
idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl')
test_evaluator = val_evaluator

vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

# runtime settings
val_cfg = dict()
test_cfg = dict()

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (1 sample per GPU).
# auto_scale_lr = dict(enable=False, base_batch_size=8)

default_hooks = dict(
logger=dict(type='LoggerHook', interval=50),
checkpoint=dict(type='CheckpointHook', interval=5))
12 changes: 12 additions & 0 deletions projects/DSVT/dsvt/__init__.py
@@ -0,0 +1,12 @@
from .dsvt import DSVT
from .dsvt_head import DSVTCenterHead
from .dsvt_transformer import DSVTMiddleEncoder
from .dynamic_pillar_vfe import DynamicPillarVFE3D
from .map2bev import PointPillarsScatter3D
from .res_second import ResSECOND
from .utils import DSVTBBoxCoder

__all__ = [
'DSVTCenterHead', 'DSVT', 'DSVTMiddleEncoder', 'DynamicPillarVFE3D',
'PointPillarsScatter3D', 'ResSECOND', 'DSVTBBoxCoder'
]
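As a usage sketch (assuming the standard MMEngine/MMDetection3D registry flow; nothing here is specific to this commit), the modules exported above can be built straight from the project config:

```python
from mmengine.config import Config

from mmdet3d.registry import MODELS
from mmdet3d.utils import register_all_modules

# Register mmdet3d's own modules; Config.fromfile also executes the config's
# `custom_imports`, which imports `projects.DSVT.dsvt` and thereby registers
# the DSVT classes exported above.
register_all_modules()
cfg = Config.fromfile(
    'projects/DSVT/configs/'
    'dsvt_voxel032_res-second_secfpn_8xb1-cyclic-12e_waymoD5-3d-3class.py')
model = MODELS.build(cfg.model)
```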