Skip to content

Commit

Permalink
[Fix]: Remove sampling hardcode. (open-mmlab#6317)
Browse files Browse the repository at this point in the history
* [Fix]: Remove sampling hardcode.

* add doc
  • Loading branch information
RangiLyu authored Oct 25, 2021
1 parent 71a1cf2 commit e43df7c
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 23 deletions.
37 changes: 29 additions & 8 deletions docs/tutorials/customize_losses.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,38 @@ This tutorial first elaborate the computation pipeline of losses, then give some

Given the input prediction and target, as well as the weights, a loss function maps the input tensor to the final loss scalar. The mapping can be divided into four steps:

1. Get **element-wise** or **sample-wise** loss by the loss kernel function.
1. Set the sampling method to sample positive and negative samples.

2. Weighting the loss with a weight tensor **element-wisely**.
2. Get **element-wise** or **sample-wise** loss by the loss kernel function.

3. Reduce the loss tensor to a **scalar**.
3. Weighting the loss with a weight tensor **element-wisely**.

4. Weighting the loss with a **scalar**.
4. Reduce the loss tensor to a **scalar**.

5. Weighting the loss with a **scalar**.

## Set sampling method (step 1)

For some loss functions, sampling strategies are needed to avoid imbalance between positive and negative samples.

For example, when using `CrossEntropyLoss` in RPN head, we need to set `RandomSampler` in `train_cfg`

```python
train_cfg=dict(
rpn=dict(
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False))
```

For some other losses which have positive and negative sample balance mechanism such as Focal Loss, GHMC, and QualityFocalLoss, the sampler is no more necessary.

## Tweaking loss

Tweaking a loss is more related with step 1, 3, 4, and most modifications can be specified in the config.
Tweaking a loss is more related with step 2, 4, 5, and most modifications can be specified in the config.
Here we take [Focal Loss (FL)](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/focal_loss.py) as an example.
The following code sniper are the construction method and config of FL respectively, they are actually one to one correspondence.

Expand All @@ -43,7 +64,7 @@ loss_cls=dict(
loss_weight=1.0)
```

### Tweaking hyper-parameters (step 1)
### Tweaking hyper-parameters (step 2)

`gamma` and `beta` are two hyper-parameters in the Focal Loss. Say if we want to change the value of `gamma` to be 1.5 and `alpha` to be 0.5, then we can specify them in the config as follows:

Expand All @@ -70,7 +91,7 @@ loss_cls=dict(
reduction='sum')
```

### Tweaking loss weight (step 4)
### Tweaking loss weight (step 5)

The loss weight here is a scalar which controls the weight of different losses in multi-task learning, e.g. classification loss and regression loss. Say if we want to change to loss weight of classification loss to be 0.5, we can specify it in the config as follows:

Expand All @@ -83,7 +104,7 @@ loss_cls=dict(
loss_weight=0.5)
```

## Weighting loss (step 2)
## Weighting loss (step 3)

Weighting loss means we re-weight the loss element-wisely. To be more specific, we multiply the loss tensor with a weight tensor which has the same shape. As a result, different entries of the loss can be scaled differently, and so called element-wisely.
The loss weight varies across different models and highly context related, but overall there are two kinds of loss weights, `label_weights` for classification loss and `bbox_weights` for bbox regression loss. You can find them in the `get_target` method of the corresponding head. Here we take [ATSSHead](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/atss_head.py#L530) as an example, which inherit [AnchorHead](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/anchor_head.py) but overwrite its `get_targets` method which yields different `label_weights` and `bbox_weights`.
Expand Down
37 changes: 28 additions & 9 deletions docs_zh-CN/tutorials/customize_losses.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,38 @@ MMDetection 为用户提供了不同的损失函数。但是默认的配置可
## 一个损失的计算过程

给定输入(包括预测和目标,以及权重),损失函数会把输入的张量映射到最后的损失标量。映射过程可以分为下面四个步骤:
1. 设置采样方法为对正负样本进行采样。

1. 通过损失核函数获取**元素**或者**样本**损失。
2. 通过损失核函数获取**元素**或者**样本**损失。

2. 通过权重张量来给损失**逐元素**权重。
3. 通过权重张量来给损失**逐元素**权重。

3. 把损失张量归纳为一个**标量**
4. 把损失张量归纳为一个**标量**

4. 用一个**张量**给当前损失一个权重。
5. 用一个**张量**给当前损失一个权重。

## 设置采样方法(步骤 1)

对于一些损失函数,需要采样策略来避免正负样本之间的不平衡。

例如,在RPN head中使用`CrossEntropyLoss`时,我们需要在`train_cfg`中设置`RandomSampler`

```python
train_cfg=dict(
rpn=dict(
sampler=dict(
type='RandomSampler',
num=256,
pos_fraction=0.5,
neg_pos_ub=-1,
add_gt_as_proposals=False))
```

对于其他一些具有正负样本平衡机制的损失,例如 Focal Loss、GHMC 和 QualityFocalLoss,不再需要进行采样。

## 微调损失

微调一个损失主要与步骤 1,3,4 有关,大部分的修改可以在配置文件中指定。这里我们用 [Focal Loss (FL)](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/focal_loss.py) 作为例子。
微调一个损失主要与步骤 245 有关,大部分的修改可以在配置文件中指定。这里我们用 [Focal Loss (FL)](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/focal_loss.py) 作为例子。
下面的代码分别是构建 FL 的方法和它的配置文件,他们是一一对应的。

```python
Expand All @@ -44,7 +63,7 @@ loss_cls=dict(
loss_weight=1.0)
```

### 微调超参数(步骤1
### 微调超参数(步骤2

`gamma``beta` 是 Focal Loss 中的两个超参数。如果我们想把 `gamma` 的值设为 1.5,把 `alpha` 的值设为 0.5,我们可以在配置文件中按照如下指定:

Expand All @@ -57,7 +76,7 @@ loss_cls=dict(
loss_weight=1.0)
```

### 微调归纳方式(步骤3
### 微调归纳方式(步骤4

Focal Loss 默认的归纳方式是 `mean`。如果我们想把归纳方式从 `mean` 改成 `sum`,我们可以在配置文件中按照如下指定:

Expand All @@ -71,7 +90,7 @@ loss_cls=dict(
reduction='sum')
```

### 微调损失权重(步骤4
### 微调损失权重(步骤5

这里的损失权重是一个标量,他用来控制多任务学习中不同损失的重要程度,例如,分类损失和回归损失。如果我们想把分类损失的权重设为 0.5,我们可以在配置文件中如下指定:

Expand All @@ -84,7 +103,7 @@ loss_cls=dict(
loss_weight=0.5)
```

## 加权损失(步骤2
## 加权损失(步骤3

加权损失就是我们逐元素修改损失权重。更具体来说,我们给损失张量乘以一个与他有相同形状的权重张量。所以,损失中不同的元素可以被赋予不同的比例,所以这里叫做逐元素。损失的权重在不同模型中变化很大,而且与上下文相关,但是总的来说主要有两种损失权重:分类损失的 `label_weights` 和边界框的 `bbox_weights`。你可以在相应的头中的 `get_target` 方法中找到他们。这里我们使用 [ATSSHead](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/atss_head.py#L530) 作为一个例子。它继承了 [AnchorHead](https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/anchor_head.py),但是我们重写它的
`get_targets` 方法来产生不同的 `label_weights``bbox_weights`
Expand Down
24 changes: 18 additions & 6 deletions mmdet/models/dense_heads/anchor_head.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch
import torch.nn as nn
from mmcv.runner import force_fp32
Expand Down Expand Up @@ -63,10 +65,6 @@ def __init__(self,
self.num_classes = num_classes
self.feat_channels = feat_channels
self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
# TODO better way to determine whether sample or not
self.sampling = loss_cls['type'] not in [
'FocalLoss', 'GHMC', 'QualityFocalLoss'
]
if self.use_sigmoid_cls:
self.cls_out_channels = num_classes
else:
Expand All @@ -83,10 +81,24 @@ def __init__(self,
self.test_cfg = test_cfg
if self.train_cfg:
self.assigner = build_assigner(self.train_cfg.assigner)
# use PseudoSampler when sampling is False
if self.sampling and hasattr(self.train_cfg, 'sampler'):
if hasattr(self.train_cfg,
'sampler') and self.train_cfg.sampler.type.split(
'.')[-1] != 'PseudoSampler':
self.sampling = True
sampler_cfg = self.train_cfg.sampler
# avoid BC-breaking
if loss_cls['type'] in [
'FocalLoss', 'GHMC', 'QualityFocalLoss'
]:
warnings.warn(
'DeprecationWarning: Determining whether to sampling'
'by loss type is deprecated, please delete sampler in'
'your config when using `FocalLoss`, `GHMC`, '
'`QualityFocalLoss` or other FocalLoss variant.')
self.sampling = False
sampler_cfg = dict(type='PseudoSampler')
else:
self.sampling = False
sampler_cfg = dict(type='PseudoSampler')
self.sampler = build_sampler(sampler_cfg, context=self)
self.fp16_enabled = False
Expand Down

0 comments on commit e43df7c

Please sign in to comment.