Skip to content

Commit

Permalink
focal_loss
Browse files Browse the repository at this point in the history
  • Loading branch information
zingp committed Jan 27, 2021
1 parent f1f7bce commit 730ffee
Showing 1 changed file with 91 additions and 90 deletions.
181 changes: 91 additions & 90 deletions PyTorchCS/focal_loss.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## FocalLoss"
"## FocalLoss\n",
"- $\\mathrm{FL}\\left(p_{\\mathrm{t}}\\right)=-\\left(1-p_{\\mathrm{t}}\\right)^{\\gamma} \\log \\left(p_{\\mathrm{t}}\\right)$"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -21,13 +22,13 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"class FocalLoss(nn.Module):\n",
"\n",
" def __init__(self, weight=None, reduction='mean', gamma=0, eps=1e-7):\n",
" def __init__(self, weight=None, reduction='mean', gamma=2, eps=1e-7):\n",
" super(FocalLoss, self).__init__()\n",
" self.gamma = gamma\n",
" self.eps = eps\n",
Expand All @@ -42,7 +43,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -54,7 +55,7 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 32,
"metadata": {},
"outputs": [
{
Expand All @@ -63,7 +64,7 @@
"tensor(0.5415)"
]
},
"execution_count": 16,
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -74,191 +75,191 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 33,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.0947)"
"tensor(0.1015)"
]
},
"execution_count": 19,
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"FocalLoss(weight=torch.Tensor([2, 2, 2]), gamma=2)(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))"
"FocalLoss(weight=torch.Tensor([0.25, 0.75, 0.75]), gamma=2)(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))"
]
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"?nn.LSTM"
"class focal_loss(nn.Module):\n",
" \"\"\"https://github.com/yatengLG/Focal-Loss-Pytorch/blob/master/Focal_Loss.py\"\"\"\n",
" def __init__(self, alpha=0.25, gamma=2, num_classes=3, size_average=True):\n",
" \"\"\"\n",
" focal_loss损失函数, -α(1-yi)**γ *ce_loss(xi,yi)\n",
" 步骤详细的实现了 focal_loss损失函数.\n",
" :param alpha: 阿尔法α, 类别权重. 当α是列表时,为各类别权重,当α为常数时,类别权重为[α, 1-α, 1-α, ....],常用于 目标检测算法中抑制背景类 , retainnet中设置为0.25\n",
" :param gamma: 伽马γ, 难易分样本调节参数. retainnet中设置为2\n",
" :param num_classes: 类别数量\n",
" :param size_average: 损失计算方式,默认取均值\n",
" \"\"\"\n",
" super(focal_loss,self).__init__()\n",
" self.size_average = size_average\n",
" # α可以以list方式输入,size:[num_classes] 用于对不同类别精细地赋予权重\n",
" if isinstance(alpha, list):\n",
" assert len(alpha) == num_classes\n",
" self.alpha = torch.Tensor(alpha)\n",
" else:\n",
" assert alpha < 1 # 如果α为一个常数,则降低第一类的影响\n",
" self.alpha = torch.zeros(num_classes)\n",
" self.alpha[0] += alpha\n",
" # α 最终为 [ α, 1-α, 1-α, 1-α, 1-α, ...] size:[num_classes]\n",
" self.alpha[1:] += (1-alpha) \n",
" self.gamma = gamma\n",
"\n",
" def forward(self, preds, labels):\n",
" \"\"\"\n",
" focal_loss损失计算\n",
" :param preds: 预测类别. size:[B,N,C] or [B,C] 分别对应与检测与分类任务, B 批次, N检测框数, C类别数\n",
" :param labels: 实际类别. size:[B,N] or [B]\n",
" :return:\n",
" \"\"\"\n",
" # assert preds.dim()==2 and labels.dim()==1\n",
" preds = preds.view(-1, preds.size(-1))\n",
" self.alpha = self.alpha.to(preds.device)\n",
" preds_logsoft = F.log_softmax(preds, dim=1) # log_softmax\n",
" preds_softmax = torch.exp(preds_logsoft) # softmax\n",
" # 这部分实现nll_loss ( crossempty = log_softmax + nll )\n",
" preds_softmax = preds_softmax.gather(1, labels.view(-1,1)) \n",
" preds_logsoft = preds_logsoft.gather(1, labels.view(-1,1))\n",
" self.alpha = self.alpha.gather(0, labels.view(-1))\n",
" print(\"alpha\", self.alpha)\n",
" # torch.pow((1-preds_softmax), self.gamma) 为focal loss中 (1-pt)**γ\n",
" loss = -torch.mul(torch.pow((1 - preds_softmax), self.gamma), preds_logsoft) \n",
" print(\"loss:\", loss)\n",
" loss = torch.mul(self.alpha, loss.t())\n",
" if self.size_average:\n",
" loss = loss.mean()\n",
" else:\n",
" loss = loss.sum()\n",
" return loss"
]
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
"?nn.GRU"
"fl = focal_loss()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stderr",
"name": "stdout",
"output_type": "stream",
"text": [
"/usr/local/anaconda2/envs/pt-tf-env/lib/python3.6/site-packages/ipykernel_launcher.py:1: UserWarning: Using a target size (torch.Size([4])) that is different to the input size (torch.Size([4, 3])) is deprecated. Please ensure they have the same size.\n",
" \"\"\"Entry point for launching an IPython kernel.\n"
"alpha tensor([0.7500, 0.2500, 0.7500, 0.7500])\n",
"loss: tensor([[2.5347e-02],\n",
" [6.4078e-02],\n",
" [7.6386e-01],\n",
" [3.8641e-08]])\n"
]
},
{
"ename": "ValueError",
"evalue": "Target and input must have the same number of elements. target nelement (4) != input nelement (12)",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-22-0296e81c4ed6>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbinary_cross_entropy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my_logits\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mint64\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/anaconda2/envs/pt-tf-env/lib/python3.6/site-packages/torch/nn/functional.py\u001b[0m in \u001b[0;36mbinary_cross_entropy\u001b[0;34m(input, target, weight, size_average, reduce, reduction)\u001b[0m\n\u001b[1;32m 2104\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mtarget\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2105\u001b[0m raise ValueError(\"Target and input must have the same number of elements. target nelement ({}) \"\n\u001b[0;32m-> 2106\u001b[0;31m \"!= input nelement ({})\".format(target.numel(), input.numel()))\n\u001b[0m\u001b[1;32m 2107\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2108\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mweight\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mValueError\u001b[0m: Target and input must have the same number of elements. target nelement (4) != input nelement (12)"
]
}
],
"source": [
"F.binary_cross_entropy(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.5983, 1.0153, 0.4514], requires_grad=True)"
"tensor(0.1520)"
]
},
"execution_count": 23,
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m = nn.Sigmoid()\n",
"loss = nn.BCELoss()\n",
"input = torch.randn(3, requires_grad=True)\n",
"input"
"fl(torch.tensor(y_logits), torch.tensor(y,dtype=torch.int64))"
]
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 37,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1.])"
"tensor([2, 0, 1, 1])"
]
},
"execution_count": 24,
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"target = torch.empty(3).random_(2)\n",
"target"
"labels = torch.tensor(y,dtype=torch.int64)\n",
"labels"
]
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.6453, 0.7341, 0.6110], grad_fn=<SigmoidBackward>)"
"tensor([0.2500, 0.7500, 0.7500])"
]
},
"execution_count": 25,
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"m(input)"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"output = loss(m(input), target)"
"a = torch.tensor([0.2500, 0.7500, 0.75])\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 28,
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(0.4133, grad_fn=<BinaryCrossEntropyBackward>)"
"tensor([0.7500, 0.2500, 0.7500, 0.7500])"
]
},
"execution_count": 28,
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"output"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"output.backward()"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"?nn.BCELoss"
"a.gather(0, labels.view(-1))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def "
]
"source": []
}
],
"metadata": {
Expand Down

0 comments on commit 730ffee

Please sign in to comment.