Skip to content

Commit

Permalink
Add Score-CAM!
Browse files Browse the repository at this point in the history
  • Loading branch information
jacobgil committed Apr 3, 2021
1 parent 91191d1 commit fcd11c9
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 53 deletions.
42 changes: 24 additions & 18 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
## Grad-CAM and Grad-CAM++ implementation in Pytorch ##
## Grad-CAM, Grad-CAM++ and ScoreCAM implementation in Pytorch ##

### What makes the network think the image label is 'pug, pug-dog' and 'tabby, tabby cat':
![Dog](https://github.com/jacobgil/pytorch-grad-cam/blob/master/examples/dog.jpg?raw=true) ![Cat](https://github.com/jacobgil/pytorch-grad-cam/blob/master/examples/cat.jpg?raw=true)
Expand Down Expand Up @@ -31,43 +31,45 @@ Some common choices can be:
`pip install pytorch-grad-cam`

```python
from pytorch_grad_cam import GradCam
from pytorch_grad_cam import CAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torchvision.models import resnet50
model = resnet50(pretrained=True)
target_layer = model.layer4[-1]
input_tensor = # Create an input tensor image for your model..
grad_cam = GradCam(model=model,
target_layer=target_layer,
plusplus=False,
use_cuda=args.use_cuda)
grayscale_cam = grad_cam(input_tensor=input_tensor,
target_category=1)
cam = CAM(model=model,
target_layer=target_layer,
use_cuda=args.use_cuda)
grayscale_cam = cam(input_tensor=input_tensor,
target_category=1,
method="gradcam")
cam = show_cam_on_image(rgb_img, grayscale_cam)
```

----------

# Using GradCAM++
# Using GradCAM++ or Score-CAM:

To use GradCAM++, pass
`plusplus=True` to GradCam.
You can choose between:
- `method='gradcam'`
- `method='gradcam++'`
- `method='scorecam'`


It seems that it's almost the same as GradCAM, in
It seems that GradCAM++ is almost the same as GradCAM, in
most networks except VGG where the advantage is larger.

| Network | Image | GradCAM | GradCAM++ |
| ---------|-------|----------|------------|
| VGG16 | ![](examples/dogs.png) | ![](examples/dogs_gradcam_vgg16.jpg) | ![](examples/dogs_gradcam++_vgg16.jpg) |
| Resnet50 | ![](examples/dogs.png) | ![](examples/dogs_gradcam_resnet50.jpg) | ![](examples/dogs_gradcam++_resnet50.jpg)|
| Network | Image | GradCAM | GradCAM++ | Score-CAM |
| ---------|-------|----------|------------|------------|
| VGG16 | ![](examples/dogs.png) | ![](examples/dogs_gradcam_vgg16.jpg) | ![](examples/dogs_gradcam++_vgg16.jpg) |![](examples/dogs_scorecam_vgg16.jpg) |
| Resnet50 | ![](examples/dogs.png) | ![](examples/dogs_gradcam_resnet50.jpg) | ![](examples/dogs_gradcam++_resnet50.jpg)| ![](examples/dogs_scorecam_resnet50.jpg) |


----------

# Running the example script:

Usage: `python gradcam.py --image-path <path_to_image>`
Usage: `python gradcam.py --image-path <path_to_image> --method <method>`

To use with CUDA:
`python gradcam.py --image-path <path_to_image> --use-cuda`
Expand All @@ -82,4 +84,8 @@ Ramprasaath R. Selvaraju, Michael Cogswell, Abhishek Das, Ramakrishna Vedantam,

https://arxiv.org/abs/1710.11063
`Grad-CAM++: Improved Visual Explanations for Deep Convolutional Networks
Aditya Chattopadhyay, Anirban Sarkar, Prantik Howlader, Vineeth N Balasubramanian`
Aditya Chattopadhyay, Anirban Sarkar, Prantik Howlader, Vineeth N Balasubramanian`

https://arxiv.org/abs/1910.01279
`Score-CAM: Score-Weighted Visual Explanations for Convolutional Neural Networks
Haofan Wang, Zifan Wang, Mengnan Du, Fan Yang, Zijian Zhang, Sirui Ding, Piotr Mardziel, Xia Hu`
Binary file added examples/dogs_scorecam_cam_resnet50.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/dogs_scorecam_vgg16.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
27 changes: 15 additions & 12 deletions gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torchvision import models

from pytorch_grad_cam import GradCam, GuidedBackpropReLUModel
from pytorch_grad_cam import CAM, GuidedBackpropReLUModel
from pytorch_grad_cam.utils.image import show_cam_on_image, \
deprocess_image, \
preprocess_image
Expand All @@ -15,6 +15,9 @@ def get_args():
help='Use NVIDIA GPU acceleration')
parser.add_argument('--image-path', type=str, default='./examples/both.png',
help='Input image path')
parser.add_argument('--method', type=str, default='scorecam',
help='Can be gradcam/gradcam++/scorecam')

args = parser.parse_args()
args.use_cuda = args.use_cuda and torch.cuda.is_available()
if args.use_cuda:
Expand Down Expand Up @@ -43,12 +46,10 @@ def get_args():
# mnasnet1_0: model.layers[-1]
# You can print the model to help choose the layer
target_layer = model.layer4[-1]


grad_cam = GradCam(model=model,
target_layer=target_layer,
use_cuda=args.use_cuda,
plusplus=False)
cam = CAM(model=model,
target_layer=target_layer,
use_cuda=args.use_cuda)

rgb_img = cv2.imread(args.image_path, 1)[:, :, ::-1]
rgb_img = np.float32(rgb_img) / 255
Expand All @@ -58,9 +59,11 @@ def get_args():
# If None, returns the map for the highest scoring category.
# Otherwise, targets the requested category.
target_category = None
grayscale_cam = grad_cam(input_tensor=input_tensor,
target_category=target_category)
cam = show_cam_on_image(rgb_img, grayscale_cam)
grayscale_cam = cam(input_tensor=input_tensor,
method=args.method,
target_category=target_category)

cam_image = show_cam_on_image(rgb_img, grayscale_cam)

gb_model = GuidedBackpropReLUModel(model=model, use_cuda=args.use_cuda)
gb = gb_model(input_tensor, target_category=target_category)
Expand All @@ -69,6 +72,6 @@ def get_args():
cam_gb = deprocess_image(cam_mask * gb)
gb = deprocess_image(gb)

cv2.imwrite('cam.jpg', cam)
cv2.imwrite('gb.jpg', gb)
cv2.imwrite('cam_gb.jpg', cam_gb)
cv2.imwrite(f'{args.method}_cam.jpg', cam_image)
cv2.imwrite(f'{args.method}_gb.jpg', gb)
cv2.imwrite(f'{args.method}_cam_gb.jpg', cam_gb)
2 changes: 1 addition & 1 deletion pytorch_grad_cam/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from pytorch_grad_cam.gradcam import GradCam
from pytorch_grad_cam.gradcam import CAM
from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel
83 changes: 61 additions & 22 deletions pytorch_grad_cam/gradcam.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,10 @@
import torch
from pytorch_grad_cam.activations_and_gradients import ActivationsAndGradients

class GradCam:
def __init__(self, model, target_layer, plusplus=False, use_cuda=False):
self.model = model
self.model.eval()
class CAM:
def __init__(self, model, target_layer, use_cuda=False):
self.model = model.eval()
self.cuda = use_cuda
self.plusplus = plusplus
if self.cuda:
self.model = model.cuda()

Expand All @@ -17,7 +15,53 @@ def __init__(self, model, target_layer, plusplus=False, use_cuda=False):
def forward(self, input_img):
return self.model(input_img)

def __call__(self, input_tensor, target_category=None):
def gradcampp(self, activations, grads):
    """Return the Grad-CAM++ weight of every channel.

    Implements Equation 19 of https://arxiv.org/abs/1710.11063.

    Args:
        activations: numpy array that is summed over axes (1, 2), i.e.
            indexed as (channel, row, column).
        grads: numpy array of gradients, combined element-wise with the
            per-channel activation sums (presumably the same shape as
            ``activations`` - confirm against the caller).

    Returns:
        numpy array holding one weight per channel (the first axis).
    """
    epsilon = 1e-8  # keeps the denominator away from an exact zero
    squared = np.square(grads)
    cubed = squared * grads

    # Per-channel activation total, broadcast back over the spatial axes.
    channel_totals = activations.sum(axis=(1, 2))[:, None, None]

    denominator = 2 * squared + channel_totals * cubed + epsilon
    alpha = squared / denominator

    # Positions whose gradient is exactly zero get a zero coefficient.
    alpha = np.where(grads == 0, 0, alpha)

    # ReLU on the gradients (eq. 7 in the paper), scaled by the
    # coefficients and accumulated over the spatial axes.
    positive_grads = np.maximum(grads, 0)
    return (positive_grads * alpha).sum(axis=(1, 2))

def scorecam(self,
             input_tensor,
             activations,
             target_category,
             original_score):
    """Return the Score-CAM weight of every activation channel.

    Each channel is upsampled to the spatial size of ``input_tensor``,
    min-max normalised and used as a multiplicative mask on the input.
    The model is run on every masked input, and a softmax over the
    resulting ``target_category`` scores gives the channel weights
    (https://arxiv.org/abs/1910.01279).

    Args:
        input_tensor: torch tensor fed to the model; its trailing two
            dimensions give the upsampling size. Presumably a batch of
            one image, shape (1, C, H, W) - confirm against the caller.
        activations: numpy array indexed as (channel, row, column).
        target_category: column index into the model output to score.
        original_score: score of the unmasked input, subtracted from
            every masked score before the softmax.

    Returns:
        numpy array with one softmax weight per activation channel.
    """
    eps = 1e-8
    batch_size = 16  # masked inputs are scored in chunks to bound memory

    with torch.no_grad():
        upsample = torch.nn.UpsamplingBilinear2d(size=input_tensor.shape[2:])
        activation_tensor = torch.from_numpy(activations).unsqueeze(0)
        if self.cuda:
            activation_tensor = activation_tensor.cuda()

        # (1, channels, H, W) -> (channels, H, W)
        upsampled = upsample(activation_tensor)[0]

        flat = upsampled.view(upsampled.size(0), -1)
        maxs = flat.max(dim=-1)[0][:, None, None]
        mins = flat.min(dim=-1)[0][:, None, None]
        # eps guards constant channels (e.g. all-zero after a ReLU), where
        # max == min; without it 0/0 yields NaN and poisons every weight.
        upsampled = (upsampled - mins) / (maxs - mins + eps)

        # One masked copy of the input per activation channel.
        input_tensors = input_tensor * upsampled[:, None, :, :]

        scores = []
        for start in range(0, input_tensors.size(0), batch_size):
            batch = input_tensors[start:start + batch_size]
            outputs = self.model(batch).cpu().numpy()[:, target_category]
            scores.append(outputs)

        scores = torch.from_numpy(np.concatenate(scores))
        weights = torch.nn.Softmax(dim=-1)(scores - original_score).numpy()
    return weights

def __call__(self, input_tensor, method="gradcam", target_category=None):
if self.cuda:
input_tensor = input_tensor.cuda()

Expand All @@ -40,23 +84,18 @@ def __call__(self, input_tensor, target_category=None):
grads = self.activations_and_grads.gradients[-1].cpu().data.numpy()[0, :]
cam = np.zeros(activations.shape[1:], dtype=np.float32)

if self.plusplus:
grads_power_2 = grads**2
grads_power_3 = grads_power_2*grads
# Equation 19 in https://arxiv.org/abs/1710.11063
sum_activations = np.sum(activations, axis=(1, 2))
eps = 0.00000001
aij = grads_power_2 / (2*grads_power_2 + sum_activations[:, None, None]*grads_power_3 + eps)

# Now bring back the ReLU from eq.7 in the paper,
# And zero out aijs where the activations are 0
aij = np.where(grads != 0, aij, 0)

weights = np.maximum(grads, 0)*aij
weights = np.sum(weights, axis=(1, 2))
else:
# Regular grad cam
if method == "gradcam++":
weights = self.gradcampp(activations, grads)
elif method == "gradcam":
weights = np.mean(grads, axis=(1, 2))
elif method == "scorecam":
original_score = original_score=output[0, target_category].cpu()
weights = self.scorecam(input_tensor,
activations,
target_category,
original_score=original_score)
else:
raise "Method not supported"

for i, w in enumerate(weights):
cam += w * activations[i, :, :]
Expand Down

0 comments on commit fcd11c9

Please sign in to comment.