Skip to content

Commit

Permalink
Layer CAM
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobgil committed Jul 9, 2021
1 parent c6e820d commit 5a5043c
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 8 deletions.
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
| ScoreCAM | Perbutate the image by the scaled activations and measure how the output drops |
| EigenCAM | Takes the first principle component of the 2D Activations (no class discrimination, but seems to give great results)|
| EigenGradCAM | Like EigenCAM but with class discrimination: First principle component of Activations*Grad. Looks like GradCAM, but cleaner|
| LayerCAM | Spatially weight the activations by positive gradients. Works better especially in lower layers |


### What makes the network think the image label is 'pug, pug-dog' and 'tabby, tabby cat':
Expand Down Expand Up @@ -142,7 +143,7 @@ To use with CUDA:

You can choose between:

`GradCAM` , `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` and `EigenCAM`.
`GradCAM` , `ScoreCAM`, `GradCAMPlusPlus`, `AblationCAM`, `XGradCAM` , `LayerCAM` and `EigenCAM`.

Some methods like ScoreCAM and AblationCAM require a large number of forward passes,
and have a batched implementation.
Expand Down Expand Up @@ -269,3 +270,7 @@ Ruigang Fu, Qingyong Hu, Xiaohu Dong, Yulan Guo, Yinghui Gao, Biao Li`
https://arxiv.org/abs/2008.00299 <br>
`Eigen-CAM: Class Activation Map using Principal Components
Mohammed Bany Muhammad, Mohammed Yeasin`

https://arxiv.org/abs/2008.00299 <br>
`LayerCAM: Exploring Hierarchical Class Activation Maps for Localization
Peng-Tao Jiang; Chang-Bin Zhang; Qibin Hou; Ming-Ming Cheng; Yunchao Wei`
16 changes: 10 additions & 6 deletions cam.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
AblationCAM, \
XGradCAM, \
EigenCAM, \
EigenGradCAM
EigenGradCAM, \
LayerCAM

from pytorch_grad_cam import GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
Expand All @@ -30,10 +31,12 @@ def get_args():
help='Reduce noise by taking the first principle componenet'
'of cam_weights*activations')
parser.add_argument('--method', type=str, default='gradcam',
choices=['gradcam', 'gradcam++', 'scorecam', 'xgradcam',
'ablationcam', 'eigencam', 'eigengradcam'],
choices=['gradcam', 'gradcam++',
'scorecam', 'xgradcam',
'ablationcam', 'eigencam',
'eigengradcam', 'layercam'],
help='Can be gradcam/gradcam++/scorecam/xgradcam'
'/ablationcam/eigencam/eigengradcam')
'/ablationcam/eigencam/eigengradcam/layercam')

args = parser.parse_args()
args.use_cuda = args.use_cuda and torch.cuda.is_available()
Expand Down Expand Up @@ -61,7 +64,8 @@ def get_args():
"ablationcam": AblationCAM,
"xgradcam": XGradCAM,
"eigencam": EigenCAM,
"eigengradcam": EigenGradCAM}
"eigengradcam": EigenGradCAM,
"layercam": LayerCAM}

model = models.resnet50(pretrained=True)

Expand Down Expand Up @@ -113,4 +117,4 @@ def get_args():

cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
cv2.imwrite(f'{args.method}_gb.jpg', gb)
cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb)
cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb)
1 change: 1 addition & 0 deletions pytorch_grad_cam/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pytorch_grad_cam.xgrad_cam import XGradCAM
from pytorch_grad_cam.grad_cam_plusplus import GradCAMPlusPlus
from pytorch_grad_cam.score_cam import ScoreCAM
from pytorch_grad_cam.layer_cam import LayerCAM
from pytorch_grad_cam.eigen_cam import EigenCAM
from pytorch_grad_cam.eigen_grad_cam import EigenGradCAM
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
24 changes: 24 additions & 0 deletions pytorch_grad_cam/layer_cam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import cv2
import numpy as np
import torch
from pytorch_grad_cam.base_cam import BaseCAM
from pytorch_grad_cam.utils.svd_on_activations import get_2d_projection

# https://ieeexplore.ieee.org/document/9462463
class LayerCAM(BaseCAM):
def __init__(self, model, target_layer, use_cuda=False, reshape_transform=None):
super(LayerCAM, self).__init__(model, target_layer, use_cuda, reshape_transform)

def get_cam_image(self,
input_tensor,
target_category,
activations,
grads,
eigen_smooth):
spatial_weighted_activations = np.maximum(grads, 0) * activations

if eigen_smooth:
cam = get_2d_projection(spatial_weighted_activations)
else:
cam = spatial_weighted_activations.sum(axis=1)
return cam
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name='grad-cam',
version='1.2.8',
version='1.3.0',
author='Jacob Gildenblat',
author_email='[email protected]',
description='Many Class Activation Map methods implemented in Pytorch. Including Grad-CAM, Grad-CAM++, Score-CAM, Ablation-CAM and XGrad-CAM',
Expand Down

0 comments on commit 5a5043c

Please sign in to comment.