GuidedBackprop with DataParallel
jacobgil committed Sep 13, 2021
1 parent a37d8ac commit 3fdf7d4
Showing 2 changed files with 24 additions and 4 deletions.
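
The commit title carries the motivation: the previous code patched GuidedBackpropReLU.apply, a bare function, into each parent's _modules dict. A bare function is not an nn.Module, so it is lost when torch.nn.DataParallel builds its per-device replicas; wrapping the Function in a small nn.Module (GuidedBackpropReLUasModule, below) makes the replacement a real submodule that replication preserves. A minimal usage sketch, assuming a torchvision ResNet-50 as the wrapped network (the model choice and input shape are illustrative; GuidedBackpropReLUModel and its use_cuda argument come from the file changed here):

    import torch
    from torchvision.models import resnet50
    from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUModel

    # DataParallel previously broke the ReLU swap; with this commit the
    # replacement layers are genuine nn.Modules that replication carries along.
    model = torch.nn.DataParallel(resnet50(pretrained=True))
    gb_model = GuidedBackpropReLUModel(model=model, use_cuda=torch.cuda.is_available())

    input_img = torch.rand(1, 3, 224, 224)
    gb = gb_model(input_img, target_category=None)  # H x W x C numpy array of guided gradients
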
pytorch_grad_cam/guided_backprop.py (17 additions, 4 deletions)
@@ -1,6 +1,7 @@
 import numpy as np
 import torch
 from torch.autograd import Function
+from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive


 class GuidedBackpropReLU(Function):
@@ -34,6 +35,14 @@ def backward(self, grad_output):
         return grad_input


+class GuidedBackpropReLUasModule(torch.nn.Module):
+    def __init__(self):
+        super(GuidedBackpropReLUasModule, self).__init__()
+
+    def forward(self, input_img):
+        return GuidedBackpropReLU.apply(input_img)
+
+
 class GuidedBackpropReLUModel:
     def __init__(self, model, use_cuda):
         self.model = model
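
The wrapper adds nothing beyond routing forward() through the existing autograd Function, but that is exactly what makes it work with nn.DataParallel: a registered nn.Module survives replica construction, while the bare GuidedBackpropReLU.apply previously patched into _modules does not. A standalone sketch of its behavior, assuming the Function implements the usual guided-backprop rule (ReLU forward, gradients masked to positive inputs and positive upstream gradients); the test tensor is illustrative:

    import torch
    from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUasModule

    relu_like = GuidedBackpropReLUasModule()
    x = torch.randn(5, requires_grad=True)
    y = relu_like(x)               # forward pass matches torch.relu(x)
    y.backward(torch.ones_like(y))
    # Guided backprop: gradient passes only where both the input and the
    # incoming gradient are positive (here the incoming gradient is all ones).
    print(x.grad)
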
@@ -46,10 +55,12 @@ def forward(self, input_img):
         return self.model(input_img)

     def recursive_replace_relu_with_guidedrelu(self, module_top):
+
         for idx, module in module_top._modules.items():
             self.recursive_replace_relu_with_guidedrelu(module)
             if module.__class__.__name__ == 'ReLU':
                 module_top._modules[idx] = GuidedBackpropReLU.apply
+        print("b")

     def recursive_replace_guidedrelu_with_relu(self, module_top):
         try:
@@ -61,8 +72,9 @@ def recursive_replace_guidedrelu_with_relu(self, module_top):
             pass

     def __call__(self, input_img, target_category=None):
-        # replace ReLU with GuidedBackpropReLU
-        self.recursive_replace_relu_with_guidedrelu(self.model)
+        replace_all_layer_type_recursive(self.model,
+                                         torch.nn.ReLU,
+                                         GuidedBackpropReLUasModule())

         if self.cuda:
             input_img = input_img.cuda()
@@ -81,7 +93,8 @@ def __call__(self, input_img, target_category=None):
         output = output[0, :, :, :]
         output = output.transpose((1, 2, 0))

-        # replace GuidedBackpropReLU back with ReLU
-        self.recursive_replace_guidedrelu_with_relu(self.model)
+        replace_all_layer_type_recursive(self.model,
+                                         GuidedBackpropReLUasModule,
+                                         torch.nn.ReLU())

         return output
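
With the wrapper class in place, the restore step becomes symmetric with the swap: the second replace_all_layer_type_recursive call finds the patched sites by isinstance, where the old code had to recognize a bare function by comparing __class__.__name__ strings. A round-trip check on a toy model (the Sequential model is illustrative):

    import torch
    from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive
    from pytorch_grad_cam.guided_backprop import GuidedBackpropReLUasModule

    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
    replace_all_layer_type_recursive(model, torch.nn.ReLU, GuidedBackpropReLUasModule())
    assert isinstance(model[1], GuidedBackpropReLUasModule)   # swapped in
    replace_all_layer_type_recursive(model, GuidedBackpropReLUasModule, torch.nn.ReLU())
    assert isinstance(model[1], torch.nn.ReLU)                # restored
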
pytorch_grad_cam/utils/find_layers.py (7 additions, 0 deletions)
@@ -8,6 +8,13 @@ def replace_layer_recursive(model, old_layer, new_layer):
     return False


+def replace_all_layer_type_recursive(model, old_layer_type, new_layer):
+    for name, layer in model._modules.items():
+        if isinstance(layer, old_layer_type):
+            model._modules[name] = new_layer
+        replace_all_layer_type_recursive(layer, old_layer_type, new_layer)
+
+
 def find_layer_types_recursive(model, layer_types):
     def predicate(layer):
         return type(layer) in layer_types
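
The new helper recurses through _modules and replaces every layer matching old_layer_type at any depth. Note that it installs the single new_layer instance at every site; that sharing is harmless for stateless activations like ReLU. A short sketch of the depth and sharing behavior (the nested model is illustrative, with LeakyReLU standing in as an arbitrary replacement):

    import torch
    from pytorch_grad_cam.utils.find_layers import replace_all_layer_type_recursive

    # Nested model with ReLUs at two depths.
    model = torch.nn.Sequential(
        torch.nn.ReLU(),
        torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()),
    )
    replace_all_layer_type_recursive(model, torch.nn.ReLU, torch.nn.LeakyReLU())
    # Both ReLUs are replaced, and both sites hold the *same* LeakyReLU instance.
    assert model[0] is model[1][1]
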