Skip to content

Commit

Permalink
fix AblationLayerVit (jacobgil#292)
Browse files Browse the repository at this point in the history
AblationLayerVit only works with 3-dimensional activations (e.g. shape (1,50,768)). This fix makes it also work with 4-dimensional activations, such as those of the Swin-T model from Torchvision, which returns activations of shape (1,7,7,768).
  • Loading branch information
lassiraa authored Jul 25, 2022
1 parent 7a5786c commit fe106ad
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pytorch_grad_cam/ablation_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self):

def __call__(self, x):
output = self.activations
output = output.transpose(1, 2)
output = output.transpose(1, len(output.shape) - 1)
for i in range(output.size(0)):

# Commonly the minimum activation will be 0,
Expand All @@ -95,15 +95,16 @@ def __call__(self, x):
output[i, self.indices[i], :] = torch.min(
output) - ABLATION_VALUE

output = output.transpose(2, 1)
output = output.transpose(len(output.shape) - 1, 1)

return output

def set_next_batch(self, input_batch_index, activations, num_channels_to_ablate):
""" This creates the next batch of activations from the layer.
Just take corresponding batch member from activations, and repeat it num_channels_to_ablate times.
"""
self.activations = activations[input_batch_index, :, :].clone().unsqueeze(0).repeat(num_channels_to_ablate, 1, 1)
repeat_params = [num_channels_to_ablate] + len(activations.shape[:-1]) * [1]
self.activations = activations[input_batch_index, :, :].clone().unsqueeze(0).repeat(*repeat_params)



Expand Down

0 comments on commit fe106ad

Please sign in to comment.