Commit 50f2213 (initial commit, 0 parents): 137 changed files with 16,769 additions and 0 deletions.
`.gitignore`:

build/
checkpoint/
*.egg-info/
output/
checkpoints/
__pycache__
*.so
`README.md`:

# 3D Cascade RCNN

This is the implementation of **3D Cascade RCNN: High Quality Object Detection in Point Clouds**.
<p align="center">
  <img src="figures/framework.png" width="80%" height="80%">
</p>

We design a 3D object detection model for point clouds by:
* Presenting a simple yet effective 3D cascade architecture
* Analyzing the sparsity of point clouds and using a point completeness score to re-weight the training samples

Below are detection results on the Waymo Open Dataset.
<p align="center">
  <img src="figures/waymo_scene_1.gif" width="20%" height="20%">
  <img src="figures/waymo_scene_2.gif" width="20%" height="20%">
  <img src="figures/waymo_scene_3.gif" width="20%" height="20%">
</p>
<p align="center">
  <img src="figures/waymo_scene_4.gif" width="20%" height="20%">
  <img src="figures/waymo_scene_5.gif" width="20%" height="20%">
  <img src="figures/waymo_scene_6.gif" width="20%" height="20%">
</p>

## Results on KITTI

|       | Easy Car | Moderate Car | Hard Car |
| ----- | :------: | :----------: | :------: |
| AP 11 |  90.05   |    86.02     |  79.27   |
| AP 40 |  93.20   |    86.19     |  83.48   |

## Results on Waymo

|               | Overall Vehicle | 0-30m Vehicle | 30-50m Vehicle | 50m-Inf Vehicle |
| ------------- | :-------------: | :-----------: | :------------: | :-------------: |
| *LEVEL_1 mAP* |      76.27      |     92.66     |     74.99      |      54.49      |
| *LEVEL_2 mAP* |      67.12      |     91.95     |     68.96      |      41.82      |

## Installation
1. Requirements.
The code is tested in the following environment:
* Ubuntu 16.04 with 4 V100 GPUs
* Python 3.7
* PyTorch 1.7
* CUDA 10.1
* spconv 1.2.1

2. Build the extensions:
```
python setup.py develop
```
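If the build succeeds, the package should import cleanly; a quick check (the version string comes from `pcdet/__init__.py`, shown later in this commit):
```
python -c "import pcdet; print(pcdet.__version__)"
```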

## Getting Started

### Prepare the data

Please download the official [KITTI](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=3d) dataset and generate the data infos with the following command:
```
python -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/kitti_dataset.yaml
```
The folder structure should look like:
```
data
├── kitti
│   ├── ImageSets
│   ├── training
│   │   ├── calib & velodyne & label_2 & image_2
│   ├── testing
│   │   ├── calib & velodyne & image_2
│   ├── kitti_dbinfos_train.pkl
│   ├── kitti_infos_train.pkl
│   ├── kitti_infos_val.pkl
```
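As a quick sanity check (a minimal sketch; the info-dict layout can vary across OpenPCDet versions), the generated pickles can be inspected directly:
```
python -c "import pickle; print(len(pickle.load(open('data/kitti/kitti_infos_train.pkl', 'rb'))), 'train infos')"
```
The standard KITTI split has 3712 training and 3769 validation samples.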

### Training and evaluation

The configuration file is tools/cfgs/3d_cascade_rcnn.yaml, and the training script is in tools/scripts:
```
cd tools
sh scripts/3d-cascade-rcnn.sh
```

### Test a pre-trained model

The pre-trained KITTI model is available at: [model](https://drive.google.com/file/d/1IEDjt02hUSKJy49yCofqjyA2aDcPsX8v/view?usp=sharing). Run:
```
cd tools
sh scripts/3d-cascade-rcnn_test.sh
```
The evaluation results should look like:
```
2021-08-10 14:06:14,608   INFO  Car AP@0.70, 0.70, 0.70:
bbox AP:97.9644, 90.1199, 89.7076
bev  AP:90.6405, 89.0829, 88.4391
3d   AP:90.0468, 86.0168, 79.2661
aos  AP:97.91, 90.00, 89.48
Car AP_R40@0.70, 0.70, 0.70:
bbox AP:99.1663, 95.8055, 93.3149
bev  AP:96.3107, 92.4128, 89.9473
3d   AP:93.1961, 86.1857, 83.4783
aos  AP:99.13, 95.65, 93.03
Car AP@0.70, 0.50, 0.50:
bbox AP:97.9644, 90.1199, 89.7076
bev  AP:98.0539, 97.1877, 89.7716
3d   AP:97.9921, 90.1001, 89.7393
aos  AP:97.91, 90.00, 89.48
Car AP_R40@0.70, 0.50, 0.50:
bbox AP:99.1663, 95.8055, 93.3149
bev  AP:99.1943, 97.8180, 95.5420
3d   AP:99.1717, 95.8046, 95.4500
aos  AP:99.13, 95.65, 93.03
```

## Acknowledgements
The code is built on [`OpenPCDet`](https://github.com/open-mmlab/OpenPCDet) and [`Voxel R-CNN`](https://github.com/djiajunustc/Voxel-R-CNN).
`pcdet/__init__.py` (path inferred from the package-level version handling):
import subprocess
from pathlib import Path

from .version import __version__

__all__ = [
    '__version__'
]


def get_git_commit_number():
    # Fall back to a dummy hash when not running from a git checkout.
    if not (Path(__file__).parent / '../.git').exists():
        return '0000000'

    cmd_out = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE)
    git_commit_number = cmd_out.stdout.decode('utf-8')[:7]
    return git_commit_number


script_version = get_git_commit_number()

# Tag the package version with the local commit so builds are traceable.
if script_version not in __version__:
    __version__ = __version__ + '+py%s' % script_version
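For example (a hypothetical session; the base version number is illustrative, while `50f2213` is this commit's hash), importing the package from a git checkout yields a commit-tagged version:
```
>>> import pcdet
>>> pcdet.__version__
'0.1.0+py50f2213'
```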
`pcdet/config.py` (path inferred from its role as the global config module):
from pathlib import Path

import yaml
from easydict import EasyDict


def log_config_to_file(cfg, pre='cfg', logger=None):
    for key, val in cfg.items():
        if isinstance(cfg[key], EasyDict):
            logger.info('\n%s.%s = edict()' % (pre, key))
            log_config_to_file(cfg[key], pre=pre + '.' + key, logger=logger)
            continue
        logger.info('%s.%s: %s' % (pre, key, val))


def cfg_from_list(cfg_list, config):
    """Set config keys via list (e.g., from command line)."""
    from ast import literal_eval
    assert len(cfg_list) % 2 == 0
    for k, v in zip(cfg_list[0::2], cfg_list[1::2]):
        key_list = k.split('.')
        d = config
        for subkey in key_list[:-1]:
            assert subkey in d, 'NotFoundKey: %s' % subkey
            d = d[subkey]
        subkey = key_list[-1]
        assert subkey in d, 'NotFoundKey: %s' % subkey
        try:
            value = literal_eval(v)
        except (ValueError, SyntaxError):
            # Not a Python literal: keep the raw string.
            value = v

        if type(value) != type(d[subkey]) and isinstance(d[subkey], EasyDict):
            # 'a:1,b:2'-style strings update sub-dicts key by key.
            key_val_list = value.split(',')
            for src in key_val_list:
                cur_key, cur_val = src.split(':')
                val_type = type(d[subkey][cur_key])
                cur_val = val_type(cur_val)
                d[subkey][cur_key] = cur_val
        elif type(value) != type(d[subkey]) and isinstance(d[subkey], list):
            # Comma-separated strings update lists element-wise.
            val_list = value.split(',')
            for i, x in enumerate(val_list):  # avoid shadowing the outer loop's `k`
                val_list[i] = type(d[subkey][0])(x)
            d[subkey] = val_list
        else:
            assert type(value) == type(d[subkey]), \
                'type {} does not match original type {}'.format(type(value), type(d[subkey]))
            d[subkey] = value
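A minimal usage sketch (the config keys here are illustrative; `cfg_from_list` only updates keys that already exist):
```
from easydict import EasyDict

cfg = EasyDict(OPTIMIZATION=EasyDict(LR=0.01, BATCH_SIZE_PER_GPU=4))
# Pairs of dotted key and new value, e.g. as parsed from a --set CLI flag.
cfg_from_list(['OPTIMIZATION.LR', '0.003'], cfg)
assert cfg.OPTIMIZATION.LR == 0.003
```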


def merge_new_config(config, new_config):
    if '_BASE_CONFIG_' in new_config:
        with open(new_config['_BASE_CONFIG_'], 'r') as f:
            try:
                yaml_config = yaml.load(f, Loader=yaml.FullLoader)
            except AttributeError:
                # Older PyYAML versions have no FullLoader.
                yaml_config = yaml.load(f)
        config.update(EasyDict(yaml_config))

    for key, val in new_config.items():
        if not isinstance(val, dict):
            config[key] = val
            continue
        if key not in config:
            config[key] = EasyDict()
        merge_new_config(config[key], val)

    return config


def cfg_from_yaml_file(cfg_file, config):
    with open(cfg_file, 'r') as f:
        try:
            new_config = yaml.load(f, Loader=yaml.FullLoader)
        except AttributeError:
            # Older PyYAML versions have no FullLoader.
            new_config = yaml.load(f)

    merge_new_config(config=config, new_config=new_config)

    return config


cfg = EasyDict()
cfg.ROOT_DIR = (Path(__file__).resolve().parent / '../').resolve()
cfg.LOCAL_RANK = 0
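Typical use, assuming this module lives at pcdet/config.py as in OpenPCDet (the YAML path is the one named in the README):
```
from pcdet.config import cfg, cfg_from_yaml_file

# Populates the global cfg in place, resolving any _BASE_CONFIG_ chain first.
cfg_from_yaml_file('tools/cfgs/3d_cascade_rcnn.yaml', cfg)
print(cfg.ROOT_DIR)
```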
`pcdet/datasets/__init__.py` (path inferred from the relative imports):
import torch
from torch.utils.data import DataLoader
from torch.utils.data import DistributedSampler as _DistributedSampler

from pcdet.utils import common_utils

from .dataset import DatasetTemplate
from .kitti.kitti_dataset import KittiDataset

__all__ = {
    'DatasetTemplate': DatasetTemplate,
    'KittiDataset': KittiDataset,
}


class DistributedSampler(_DistributedSampler):

    def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True):
        super().__init__(dataset, num_replicas=num_replicas, rank=rank)
        self.shuffle = shuffle

    def __iter__(self):
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
        else:
            indices = torch.arange(len(self.dataset)).tolist()

        # Pad so every rank sees the same number of samples, then take an
        # interleaved shard: rank, rank + num_replicas, rank + 2 * num_replicas, ...
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)
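The padded, strided slice hands each rank a disjoint interleaved shard. A toy check with hypothetical sizes:
```
indices = list(range(8))                 # total_size = 8, num_replicas = 2
assert indices[0:8:2] == [0, 2, 4, 6]    # rank 0
assert indices[1:8:2] == [1, 3, 5, 7]    # rank 1
```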


def build_dataloader(dataset_cfg, class_names, batch_size, dist, root_path=None, workers=4,
                     logger=None, training=True, merge_all_iters_to_one_epoch=False, total_epochs=0, shuffle=None):
    dataset = __all__[dataset_cfg.DATASET](
        dataset_cfg=dataset_cfg,
        class_names=class_names,
        root_path=root_path,
        training=training,
        logger=logger,
    )

    if merge_all_iters_to_one_epoch:
        assert hasattr(dataset, 'merge_all_iters_to_one_epoch')
        dataset.merge_all_iters_to_one_epoch(merge=True, epochs=total_epochs)

    if dist:
        if training:
            sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        else:
            # Evaluation: keep the original sample order across ranks.
            rank, world_size = common_utils.get_dist_info()
            sampler = DistributedSampler(dataset, world_size, rank, shuffle=False)
    else:
        sampler = None
    if shuffle is None:
        shuffle = (sampler is None) and training
    dataloader = DataLoader(
        dataset, batch_size=batch_size, pin_memory=True, num_workers=workers,
        shuffle=shuffle, collate_fn=dataset.collate_batch,
        drop_last=False, sampler=sampler, timeout=0
    )

    return dataset, dataloader, sampler
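A single-process usage sketch (assuming a config loaded as in pcdet/config.py; DATA_CONFIG, CLASS_NAMES, and the collate behavior follow OpenPCDet conventions):
```
from pcdet.config import cfg, cfg_from_yaml_file
from pcdet.datasets import build_dataloader

cfg_from_yaml_file('tools/cfgs/3d_cascade_rcnn.yaml', cfg)
dataset, dataloader, sampler = build_dataloader(
    dataset_cfg=cfg.DATA_CONFIG,
    class_names=cfg.CLASS_NAMES,
    batch_size=2,
    dist=False,      # no distributed sampler; shuffling follows `training`
    training=True,
)
batch = next(iter(dataloader))   # one collated batch dict
```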
`pcdet/datasets/augmentor/augmentor_utils.py` (path inferred from the relative imports):
import numpy as np

from ...utils import common_utils


def random_flip_along_x(gt_boxes, points):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C)
    Returns:
        gt_boxes, points after a random flip over the x axis
    """
    enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
    if enable:
        # Mirroring over the x axis negates y and the heading angle.
        gt_boxes[:, 1] = -gt_boxes[:, 1]
        gt_boxes[:, 6] = -gt_boxes[:, 6]
        points[:, 1] = -points[:, 1]

        if gt_boxes.shape[1] > 7:  # flip the y velocity as well
            gt_boxes[:, 8] = -gt_boxes[:, 8]

    return gt_boxes, points


def random_flip_along_y(gt_boxes, points):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C)
    Returns:
        gt_boxes, points after a random flip over the y axis
    """
    enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5])
    if enable:
        # Mirroring over the y axis negates x and reflects the heading.
        gt_boxes[:, 0] = -gt_boxes[:, 0]
        gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi)
        points[:, 0] = -points[:, 0]

        if gt_boxes.shape[1] > 7:  # flip the x velocity as well
            gt_boxes[:, 7] = -gt_boxes[:, 7]

    return gt_boxes, points


def global_rotation(gt_boxes, points, rot_range):
    """
    Args:
        gt_boxes: (N, 7 + C), [x, y, z, dx, dy, dz, heading, [vx], [vy]]
        points: (M, 3 + C)
        rot_range: [min, max]
    Returns:
        gt_boxes, points after a random rotation around the z axis
    """
    noise_rotation = np.random.uniform(rot_range[0], rot_range[1])
    points = common_utils.rotate_points_along_z(points[np.newaxis, :, :], np.array([noise_rotation]))[0]
    gt_boxes[:, 0:3] = common_utils.rotate_points_along_z(gt_boxes[np.newaxis, :, 0:3], np.array([noise_rotation]))[0]
    gt_boxes[:, 6] += noise_rotation
    if gt_boxes.shape[1] > 7:
        # Rotate the (vx, vy) velocity vector by embedding it in 3D with z = 0.
        gt_boxes[:, 7:9] = common_utils.rotate_points_along_z(
            np.hstack((gt_boxes[:, 7:9], np.zeros((gt_boxes.shape[0], 1))))[np.newaxis, :, :],
            np.array([noise_rotation])
        )[0][:, 0:2]

    return gt_boxes, points


def global_scaling(gt_boxes, points, scale_range):
    """
    Args:
        gt_boxes: (N, 7), [x, y, z, dx, dy, dz, heading]
        points: (M, 3 + C)
        scale_range: [min, max]
    Returns:
        gt_boxes, points after a random global scaling
    """
    if scale_range[1] - scale_range[0] < 1e-3:
        return gt_boxes, points
    noise_scale = np.random.uniform(scale_range[0], scale_range[1])
    points[:, :3] *= noise_scale
    gt_boxes[:, :6] *= noise_scale
    return gt_boxes, points
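A toy end-to-end sketch of the augmentation chain on random data (shapes follow the docstrings above; `rotate_points_along_z` is assumed to rotate only the first three channels, as in OpenPCDet):
```
import numpy as np

# 5 boxes [x, y, z, dx, dy, dz, heading]; 100 points [x, y, z, intensity]
gt_boxes = np.random.rand(5, 7).astype(np.float32)
points = np.random.rand(100, 4).astype(np.float32)

gt_boxes, points = random_flip_along_x(gt_boxes, points)
gt_boxes, points = global_rotation(gt_boxes, points, rot_range=[-np.pi / 4, np.pi / 4])
gt_boxes, points = global_scaling(gt_boxes, points, scale_range=[0.95, 1.05])
```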