Skip to content

Commit

Permalink
[Feature] Fully support DiffusionDet in projects (open-mmlab#9768)
Browse files Browse the repository at this point in the history
  • Loading branch information
BIGWangYuDong authored Feb 21, 2023
1 parent c3513a1 commit 659f7a6
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 31 deletions.
37 changes: 28 additions & 9 deletions projects/DiffusionDet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ We give a table to compare the inference results on `ResNet50-500-proposals` bet
| [MMDetection](configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py) (random seed) | 1 | 45.6~45.8 |
| [DiffusionDet](https://github.com/ShoufaChen/DiffusionDet/blob/main/configs/diffdet.coco.res50.yaml) (released results) | 4 | 46.1 |
| [DiffusionDet](https://github.com/ShoufaChen/DiffusionDet/blob/main/configs/diffdet.coco.res50.yaml) (seed=0) | 4 | 46.38 |
| [MMDetection](configs/diffusiondet_r50_fpn_500-proposals_4-steps_crop-ms-480-800-450k_coco.py) (seed=0) | 4 | 46.4 |
| [MMDetection](configs/diffusiondet_r50_fpn_500-proposals_4-steps_crop-ms-480-800-450k_coco.py) (random seed) | 4 | 46.2~46.4 |
| [MMDetection](configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py) (seed=0) | 4 | 46.4 |
| [MMDetection](configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py) (random seed) | 4 | 46.2~46.4 |

- `seed=0` means hard set seed before generating random boxes.
```python
Expand All @@ -60,25 +60,44 @@ We give a table to compare the inference results on `ResNet50-500-proposals` bet

### Training commands

MMDetection currently does not fully support training DiffusionDet.
In MMDetection's root directory, run the following command to train the model:

```bash
python tools/train.py projects/DiffusionDet/configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py
```

For multi-gpu training, run:

```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/DiffusionDet/configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py
```

### Testing commands

In MMDetection's root directory, run the following command to test the model:

```bash
python tools/test.py projects/DiffusionDet/configs/${CONFIG_PATH} ${CHECKPOINT_PATH}
# for 1 step inference
# test command
python tools/test.py projects/DiffusionDet/configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py ${CHECKPOINT_PATH}

# for 4 steps inference

# test command
python tools/test.py projects/DiffusionDet/configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py ${CHECKPOINT_PATH} --cfg-options model.bbox_head.sampling_timesteps=4
```

**Note:** There is no difference between 1 step or 4 steps (or other multi-step) during training. Users can set different steps during inference through `--cfg-options model.bbox_head.sampling_timesteps=${STEPS}`, but larger `sampling_timesteps` will affect the inference time.

## Results

Here we provide the baseline version of DiffusionDet with ResNet50 backbone.

To find more variants, please visit the [official model zoo](https://github.com/ShoufaChen/DiffusionDet#models).

| Backbone | Style | Lr schd | Mem (GB) | FPS | AP | Config | Download |
| :------: | :-----: | :-----: | :------: | :-: | :-: | :----------: | :----------------------: |
| R-50 | PyTorch | | | | | [config](<>) | [model](<>) \| [log](<>) |
| Backbone | Style | Lr schd | AP (Step=1) | AP (Step=4) | Config | Download |
| :------: | :-----: | :-----: | :---------: | :---------: | :----------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| R-50 | PyTorch | 450k | 44.5 | 46.2 | [config](./configs/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/diffusiondet/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco_20230215_090925-7d6ed504.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/diffusiondet/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco/diffusiondet_r50_fpn_500-proposals_1-step_crop-ms-480-800-450k_coco_20230215_090925.log.json) |

## License

Expand Down Expand Up @@ -122,9 +141,9 @@ A project does not necessarily have to be finished in a single PR, but it's esse

<!-- As this template does. -->

- [ ] Milestone 2: Indicates a successful model implementation.
- [x] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness
- [x] Training-time correctness

<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

Expand Down

This file was deleted.

14 changes: 11 additions & 3 deletions projects/DiffusionDet/diffusiondet/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
batch_img_metas) = prepare_outputs

batch_diff_bboxes = torch.stack([
pred_instances.diff_bboxes
pred_instances.diff_bboxes_abs
for pred_instances in batch_pred_instances
])
batch_time = torch.stack(
Expand All @@ -339,6 +339,13 @@ def loss(self, x: Tuple[Tensor], batch_data_samples: SampleList) -> dict:
return losses

def prepare_training_targets(self, batch_data_samples):
# hard-setting seed to keep results same (if necessary)
# random.seed(0)
# torch.manual_seed(0)
# torch.cuda.manual_seed_all(0)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark = False

batch_gt_instances = []
batch_pred_instances = []
batch_gt_instances_ignore = []
Expand All @@ -357,7 +364,7 @@ def prepare_training_targets(self, batch_data_samples):
image_size)

gt_instances.set_metainfo(dict(image_size=image_size))
gt_instances.norm_bboxes = norm_gt_bboxes
gt_instances.norm_bboxes_cxcywh = norm_gt_bboxes_cxcywh

batch_gt_instances.append(gt_instances)
batch_pred_instances.append(pred_instances)
Expand Down Expand Up @@ -399,11 +406,12 @@ def prepare_diffusion(self, gt_boxes, image_size):

diff_bboxes = bbox_cxcywh_to_xyxy(x)
# convert to abs bboxes
diff_bboxes = diff_bboxes * image_size
diff_bboxes_abs = diff_bboxes * image_size

metainfo = dict(time=time.squeeze(-1))
pred_instances = InstanceData(metainfo=metainfo)
pred_instances.diff_bboxes = diff_bboxes
pred_instances.diff_bboxes_abs = diff_bboxes_abs
pred_instances.noise = noise
return pred_instances

Expand Down
45 changes: 30 additions & 15 deletions projects/DiffusionDet/diffusiondet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,12 +68,12 @@ def forward(self, outputs, batch_gt_instances, batch_img_metas):
if self.deep_supervision:
assert 'aux_outputs' in outputs
for i, aux_outputs in enumerate(outputs['aux_outputs']):
batch_indices = self.assigner(outputs, batch_gt_instances,
batch_indices = self.assigner(aux_outputs, batch_gt_instances,
batch_img_metas)
loss_cls = self.loss_classification(outputs,
loss_cls = self.loss_classification(aux_outputs,
batch_gt_instances,
batch_indices)
loss_bbox, loss_giou = self.loss_boxes(outputs,
loss_bbox, loss_giou = self.loss_boxes(aux_outputs,
batch_gt_instances,
batch_indices)
tmp_losses = dict(
Expand Down Expand Up @@ -101,8 +101,7 @@ def loss_classification(self, outputs, batch_gt_instances, indices):
src_logits = src_logits.flatten(0, 1)
target_classes = target_classes.flatten(0, 1)
# comp focal loss.
num_instances = torch.cat(target_classes_list).shape[0] if \
len(target_classes_list) != 0 else 1
num_instances = max(torch.cat(target_classes_list).shape[0], 1)
loss_cls = self.loss_cls(
src_logits,
target_classes,
Expand All @@ -114,7 +113,7 @@ def loss_boxes(self, outputs, batch_gt_instances, indices):
pred_boxes = outputs['pred_boxes']

target_bboxes_norm_list = [
gt.norm_bboxes[J]
gt.norm_bboxes_cxcywh[J]
for gt, (_, J) in zip(batch_gt_instances, indices)
]
target_bboxes_list = [
Expand All @@ -137,8 +136,9 @@ def loss_boxes(self, outputs, batch_gt_instances, indices):
if len(pred_boxes_cat) > 0:
num_instances = pred_boxes_cat.shape[0]

loss_bbox = self.loss_bbox(pred_boxes_norm_cat,
target_bboxes_norm_cat) / num_instances
loss_bbox = self.loss_bbox(
pred_boxes_norm_cat,
bbox_cxcywh_to_xyxy(target_bboxes_norm_cat)) / num_instances
loss_giou = self.loss_giou(pred_boxes_cat,
target_bboxes_cat) / num_instances
else:
Expand Down Expand Up @@ -222,6 +222,12 @@ def single_assigner(self, pred_instances, gt_instances, img_meta):
dtype=torch.long)
return valid_mask, matched_gt_inds

valid_mask, is_in_boxes_and_center = \
self.get_in_gt_and_in_center_info(
bbox_xyxy_to_cxcywh(pred_bboxes),
bbox_xyxy_to_cxcywh(gt_bboxes)
)

cost_list = []
for match_cost in self.match_costs:
cost = match_cost(
Expand All @@ -230,12 +236,6 @@ def single_assigner(self, pred_instances, gt_instances, img_meta):
img_meta=img_meta)
cost_list.append(cost)

valid_mask, is_in_boxes_and_center = \
self.get_in_gt_and_in_center_info(
bbox_xyxy_to_cxcywh(pred_bboxes),
bbox_xyxy_to_cxcywh(gt_bboxes)
)

pairwise_ious = self.iou_calculator(pred_bboxes, gt_bboxes)

cost_list.append((~is_in_boxes_and_center) * 100.0)
Expand Down Expand Up @@ -301,7 +301,7 @@ def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
num_gt: int) -> Tuple[Tensor, Tensor]:
"""Use IoU and matching cost to calculate the dynamic top-k positive
targets."""
matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
matching_matrix = torch.zeros_like(cost)
# select candidate topk ious for dynamic-k calculation
candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
Expand All @@ -319,6 +319,21 @@ def dynamic_k_matching(self, cost: Tensor, pairwise_ious: Tensor,
_, cost_argmin = torch.min(cost[prior_match_gt_mask, :], dim=1)
matching_matrix[prior_match_gt_mask, :] *= 0
matching_matrix[prior_match_gt_mask, cost_argmin] = 1

while (matching_matrix.sum(0) == 0).any():
matched_query_id = matching_matrix.sum(1) > 0
cost[matched_query_id] += 100000.0
unmatch_id = torch.nonzero(
matching_matrix.sum(0) == 0, as_tuple=False).squeeze(1)
for gt_idx in unmatch_id:
pos_idx = torch.argmin(cost[:, gt_idx])
matching_matrix[:, gt_idx][pos_idx] = 1.0
if (matching_matrix.sum(1) > 1).sum() > 0:
_, cost_argmin = torch.min(cost[prior_match_gt_mask], dim=1)
matching_matrix[prior_match_gt_mask] *= 0
matching_matrix[prior_match_gt_mask, cost_argmin, ] = 1

assert not (matching_matrix.sum(0) == 0).any()
# get foreground mask inside box and center prior
fg_mask_inboxes = matching_matrix.sum(1) > 0
matched_gt_inds = matching_matrix[fg_mask_inboxes, :].argmax(1)
Expand Down

0 comments on commit 659f7a6

Please sign in to comment.