
Commit

Merge pull request understandable-machine-intelligence-lab#210 from understandable-machine-intelligence-lab/feature/relative-stability-metric

Feature/relative stability metric
annahedstroem authored Dec 27, 2022
2 parents e60fd74 + a729e44 commit 5c6384a
Showing 20 changed files with 2,414 additions and 380 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -112,6 +112,9 @@ measures to what extent explanations are stable when subject to slight perturbations
<li><b>Avg-Sensitivity </b><a href="https://arxiv.org/pdf/1901.09392.pdf">(Yeh et al., 2019)</a>: measures the average sensitivity of an explanation using a Monte Carlo sampling-based approximation
<li><b>Continuity </b><a href="https://arxiv.org/pdf/1706.07979.pdf">(Montavon et al., 2018)</a>: captures the strongest variation in explanation of an input and its perturbed version
<li><b>Consistency </b><a href="https://arxiv.org/abs/2202.00734">(Dasgupta et al., 2022)</a>: measures the probability that the inputs with the same explanation have the same prediction label
<li><b>Relative Input Stability (RIS) </b><a href="https://arxiv.org/pdf/2203.06877.pdf">(Agarwal et al., 2022)</a>: measures the relative distance between explanations e_x and e_x' with respect to the distance between the two inputs x and x'
<li><b>Relative Representation Stability (RRS) </b><a href="https://arxiv.org/pdf/2203.06877.pdf">(Agarwal et al., 2022)</a>: measures the relative distance between explanations e_x and e_x' with respect to the distance between the internal model representations L_x and L_x' of x and x', respectively
<li><b>Relative Output Stability (ROS) </b><a href="https://arxiv.org/pdf/2203.06877.pdf">(Agarwal et al., 2022)</a>: measures the relative distance between explanations e_x and e_x' with respect to the distance between the output logits h(x) and h(x'), respectively; the shared form of all three metrics is sketched below
</ul>
</details>
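All three metrics instantiate one ratio from Agarwal et al. (2022): the relative change of the explanation in the numerator, and the change of the quantity named in the metric in the denominator. A sketch of the definitions as we read them from the paper (division inside the norms is elementwise, and ε_min guards against division by zero; exact norms and maximization domain follow the paper):

```latex
\mathrm{RIS}(x, x') = \max_{x'}
  \frac{\bigl\lVert (e_x - e_{x'}) / e_x \bigr\rVert_p}
       {\max\bigl( \lVert (x - x') / x \rVert_p,\ \epsilon_{\min} \bigr)}
\qquad
\mathrm{RRS}(x, x') = \max_{x'}
  \frac{\bigl\lVert (e_x - e_{x'}) / e_x \bigr\rVert_p}
       {\max\bigl( \lVert (L_x - L_{x'}) / L_x \rVert_p,\ \epsilon_{\min} \bigr)}
\qquad
\mathrm{ROS}(x, x') = \max_{x'}
  \frac{\bigl\lVert (e_x - e_{x'}) / e_x \bigr\rVert_p}
       {\max\bigl( \lVert h(x) - h(x') \rVert_p,\ \epsilon_{\min} \bigr)}
```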

7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.metrics.robustness.relative_input_stability.rst
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_input\_stability module
============================================================

.. automodule:: quantus.metrics.robustness.relative_input_stability
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.metrics.robustness.relative_ouput_stability.rst
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_ouput\_stability module
============================================================

.. automodule:: quantus.metrics.robustness.relative_ouput_stability
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.metrics.robustness.relative_output_stability.rst
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_output\_stability module
=============================================================

.. automodule:: quantus.metrics.robustness.relative_output_stability
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/source/docs_api/quantus.metrics.robustness.relative_representation_stability.rst
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_representation\_stability module
=====================================================================

.. automodule:: quantus.metrics.robustness.relative_representation_stability
:members:
:undoc-members:
:show-inheritance:
3 changes: 3 additions & 0 deletions docs/source/docs_api/quantus.metrics.robustness.rst
@@ -17,3 +17,6 @@ Submodules
quantus.metrics.robustness.continuity
quantus.metrics.robustness.local_lipschitz_estimate
quantus.metrics.robustness.max_sensitivity
quantus.metrics.robustness.relative_input_stability
quantus.metrics.robustness.relative_output_stability
quantus.metrics.robustness.relative_representation_stability
3 changes: 3 additions & 0 deletions quantus/helpers/constants.py
@@ -37,6 +37,9 @@
"Max-Sensitivity": MaxSensitivity,
"Avg-Sensitivity": AvgSensitivity,
"Consistency": Consistency,
"Relative Input Stability": RelativeInputStability,
"Relative Output Stability": RelativeOutputStability,
"Relative Representation Stability": RelativeRepresentationStability,
},
"Localisation": {
"Pointing Game": PointingGame,
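With this hunk, the three metrics become discoverable through the category registry in quantus/helpers/constants.py. A minimal lookup sketch, assuming the dict patched above is exposed as AVAILABLE_METRICS (the name is our assumption from this file):

```python
# Hypothetical registry lookup; AVAILABLE_METRICS is assumed to be the name
# of the category dict patched in quantus/helpers/constants.py above.
from quantus.helpers.constants import AVAILABLE_METRICS

metric_cls = AVAILABLE_METRICS["Robustness"]["Relative Input Stability"]
metric = metric_cls()  # instantiate with default hyperparameters
```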
35 changes: 34 additions & 1 deletion quantus/helpers/model/model_interface.py
@@ -7,7 +7,7 @@
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

from abc import ABC, abstractmethod
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, List

import numpy as np

@@ -105,3 +105,36 @@ def get_random_layer_generator(self):
set it to 'independent'. For bottom-up order, set it to 'bottom_up'.
"""
raise NotImplementedError


@abstractmethod
def get_hidden_representations(
self,
x: np.ndarray,
layer_names: Optional[List[str]] = None,
layer_indices: Optional[List[int]] = None,
) -> np.ndarray:
"""
Compute the model's internal representation of input x.
In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
As the exact definition of "internal model representation" is left out in the original paper (see: https://arxiv.org/pdf/2203.06877.pdf),
we make the implementation flexible.
It is up to the user whether all layers are used, or specific ones should be selected.
The user can therefore select a layer by providing 'layer_names' (exclusive) or 'layer_indices'.
Parameters
----------
x: np.ndarray
4D tensor, a batch of input datapoints
layer_names: List[str]
List with names of layers, from which output should be captured.
layer_indices: List[int]
List with indices of layers, from which output should be captured.
Intended to use in case, when layer names are not unique, or unknown.
Returns
-------
L: np.ndarray
2D tensor with shape (batch_size, None)
"""
raise NotImplementedError()
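A minimal sketch of the call contract defined above (model stands for any ModelInterface implementation; x_batch is a placeholder 4D np.ndarray; "test_conv" matches the layer named in the LeNetTF fixture further down):

```python
# All hidden layers (can be expensive), a single layer by negative index,
# and a layer selected by name; rows always align with the input batch.
L_all = model.get_hidden_representations(x_batch)
L_last = model.get_hidden_representations(x_batch, layer_indices=[-1])
L_conv = model.get_hidden_representations(x_batch, layer_names=["test_conv"])
assert L_last.shape[0] == x_batch.shape[0]
```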
2 changes: 1 addition & 1 deletion quantus/helpers/model/models.py
@@ -172,7 +172,7 @@ def LeNetTF() -> tf.keras.Model:
),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Conv2D(
-                filters=16, kernel_size=(3, 3), activation="relu"
+                filters=16, kernel_size=(3, 3), activation="relu", name="test_conv"
),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Flatten(),
94 changes: 93 additions & 1 deletion quantus/helpers/model/pytorch_model.py
@@ -8,7 +8,7 @@

from contextlib import suppress
from copy import deepcopy
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Optional, Tuple, List

import numpy as np
import torch
@@ -219,3 +219,95 @@ def sample(
"or 'additive' (string) when you sample the model."
)
return model_copy


def get_hidden_representations(
self,
x: np.ndarray,
layer_names: Optional[List[str]] = None,
layer_indices: Optional[List[int]] = None,
) -> np.ndarray:

"""
Compute the model's internal representation of input x.
In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
As the exact definition of "internal model representation" is left out in the original paper (see: https://arxiv.org/pdf/2203.06877.pdf),
we make the implementation flexible.
It is up to the user whether all layers are used, or specific ones should be selected.
The user can therefore select a layer by providing 'layer_names' (exclusive) or 'layer_indices'.
Parameters
----------
x: np.ndarray
4D tensor, a batch of input datapoints
layer_names: List[str]
List with names of layers, from which output should be captured.
layer_indices: List[int]
List with indices of layers, from which output should be captured.
Intended to use in case, when layer names are not unique, or unknown.
Returns
-------
L: np.ndarray
2D tensor with shape (batch_size, None)
"""

device = self.device if self.device is not None else "cpu"
all_layers = [*self.model.named_modules()]
num_layers = len(all_layers)

if layer_indices is None:
layer_indices = []

# E.g., the user can provide index -1 to get only the representations of the last layer.
# For 7 layers in total, this corresponds to positive index 6.
positive_layer_indices = [
i if i >= 0 else num_layers + i for i in layer_indices
]

if layer_names is None:
layer_names = []

def is_layer_of_interest(layer_index: int, layer_name: str):
if layer_names == [] and positive_layer_indices == []:
return True
return layer_index in positive_layer_indices or layer_name in layer_names

# Skip container modules defined via the subclassing API (the model class itself and torch.nn.Sequential).
hidden_layers = list( # type: ignore
filter(
lambda l: not isinstance(
l[1], (self.model.__class__, torch.nn.Sequential)
),
all_layers,
)
)

batch_size = x.shape[0]
hidden_outputs = []

# We register a forward hook on each layer of interest, which saves that layer's flattened output to a list.
# Then we execute a forward pass and stack the outputs into a 2D tensor.
def hook(module, module_in, module_out):
arr = module_out.cpu().numpy()
arr = arr.reshape((batch_size, -1))
hidden_outputs.append(arr)

new_hooks = []
# Save handles of registered hooks, so we can clean them up later.
for index, (name, layer) in enumerate(hidden_layers):
if is_layer_of_interest(index, name):
handle = layer.register_forward_hook(hook)
new_hooks.append(handle)

if len(new_hooks) == 0:
raise ValueError("No hidden representations were selected.")

# Execute forward pass.
with torch.no_grad():
self.model(torch.Tensor(x).to(device))

# Cleanup.
for handle in new_hooks:
    handle.remove()
return np.hstack(hidden_outputs)

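The PyTorch implementation above hinges on forward hooks. A self-contained sketch of the same capture pattern on a toy model (the toy sizes are arbitrary):

```python
import numpy as np
import torch

# Toy model: three submodules, so three hooked outputs per forward pass.
model = torch.nn.Sequential(
    torch.nn.Linear(8, 4), torch.nn.ReLU(), torch.nn.Linear(4, 2)
)
outputs = []

def hook(module, module_in, module_out):
    # Flatten each captured output to (batch_size, -1), as in the code above.
    outputs.append(module_out.detach().cpu().numpy().reshape(module_out.shape[0], -1))

handles = [layer.register_forward_hook(hook) for layer in model]
with torch.no_grad():
    model(torch.randn(3, 8))
for handle in handles:
    handle.remove()  # clean up so later forward passes are unaffected

print(np.hstack(outputs).shape)  # (3, 10) == (batch, 4 + 4 + 2)
```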
103 changes: 102 additions & 1 deletion quantus/helpers/model/tf_model.py
@@ -8,13 +8,14 @@

from __future__ import annotations

-from typing import Dict, Optional, Tuple
+from typing import Dict, Optional, Tuple, List
from keras.layers import Dense
from keras import activations
from keras import Model
from keras.models import clone_model
import numpy as np
import tensorflow as tf
from warnings import warn
from cachetools import cachedmethod, LRUCache
import operator

@@ -228,3 +229,103 @@ def get_random_layer_generator(self, order: str = "top_down", seed: int = 42):
np.random.seed(seed=seed + 1)
layer.set_weights([np.random.permutation(w) for w in weights])
yield layer.name, random_layer_model


@cachedmethod(operator.attrgetter("cache"))
def _build_hidden_representation_model(
self, layer_names: Tuple, layer_indices: Tuple
) -> Model:
"""
Build a keras model, which outputs the internal representation of layers,
specified in layer_names or layer_indices, default all.
This requires re-tracing the model, so we cache it to improve metric evaluation time.
"""
if layer_names == () and layer_indices == ():
warn(
"quantus.TensorFlowModel.get_hidden_layers_representations(...) received `layer_names`=None and "
"`layer_indices`=None. This will force creation of tensorflow.keras.Model with outputs of each layer"
" from original model. This can be very computationally expensive."
)

def is_layer_of_interest(index: int, name: str) -> bool:
if layer_names == () and layer_indices == ():
return True
return index in layer_indices or name in layer_names

outputs_of_interest = []
for i, layer in enumerate(self.model.layers):
if is_layer_of_interest(i, layer.name):
outputs_of_interest.append(layer.output)

if len(outputs_of_interest) == 0:
raise ValueError("No hidden representations were selected.")

hidden_representation_model = Model(self.model.input, outputs_of_interest)
return hidden_representation_model


def get_hidden_representations(
self,
x: np.ndarray,
layer_names: Optional[List[str]] = None,
layer_indices: Optional[List[int]] = None,
**kwargs,
) -> np.ndarray:

"""
Compute the model's internal representation of input x.
In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
As the exact definition of "internal model representation" is left out in the original paper (see: https://arxiv.org/pdf/2203.06877.pdf),
we make the implementation flexible.
It is up to the user whether all layers are used, or specific ones should be selected.
The user can therefore select a layer by providing 'layer_names' (exclusive) or 'layer_indices'.
Parameters
----------
x: np.ndarray
4D tensor, a batch of input datapoints
layer_names: List[str]
List with names of layers, from which output should be captured.
layer_indices: List[int]
List with indices of layers, from which output should be captured.
Intended to use in case, when layer names are not unique, or unknown.
Returns
-------
L: np.ndarray
2D tensor with shape (batch_size, None)
"""

num_layers = len(self.model.layers)

if layer_indices is None:
layer_indices = []

# E.g., the user can provide index -1 to get only the representations of the last layer.
# For 7 layers in total, this corresponds to positive index 6.
positive_layer_indices = [
i if i >= 0 else num_layers + i for i in layer_indices
]
if layer_names is None:
layer_names = []

# List is not hashable, so we pass names + indices as tuples.
hidden_representation_model = self._build_hidden_representation_model(
tuple(layer_names), tuple(positive_layer_indices)
)
predict_kwargs = self._get_predict_kwargs(**kwargs)
internal_representation = hidden_representation_model.predict(
x, **predict_kwargs
)
input_batch_size = x.shape[0]

# If outputs of only one layer were requested, Keras already returns an np.ndarray.
# Otherwise, Keras returns a list of np.ndarray's.
if isinstance(internal_representation, np.ndarray):
return internal_representation.reshape((input_batch_size, -1))

internal_representation = [
i.reshape((input_batch_size, -1)) for i in internal_representation
]
return np.hstack(internal_representation)

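Unlike the PyTorch path, the TensorFlow path builds a second keras Model whose outputs are the selected layers, and memoizes it with cachetools.cachedmethod keyed on the (hashable) tuples of names and indices. A usage sketch, assuming tf_model is a quantus TensorFlowModel wrapping a Keras network that contains a layer named "test_conv", with an assumed 28x28x1 input:

```python
import numpy as np

x_batch = np.random.rand(8, 28, 28, 1).astype(np.float32)  # assumed input shape

# The first call traces and caches the hidden-representation model;
# the second call with the same selection reuses the cached one.
L_first = tf_model.get_hidden_representations(x_batch, layer_names=["test_conv"])
L_again = tf_model.get_hidden_representations(x_batch, layer_names=["test_conv"])
assert L_first.shape[0] == x_batch.shape[0]
```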
5 changes: 5 additions & 0 deletions quantus/metrics/robustness/__init__.py
@@ -3,3 +3,8 @@
from quantus.metrics.robustness.continuity import Continuity
from quantus.metrics.robustness.local_lipschitz_estimate import LocalLipschitzEstimate
from quantus.metrics.robustness.max_sensitivity import MaxSensitivity
from quantus.metrics.robustness.relative_input_stability import RelativeInputStability
from quantus.metrics.robustness.relative_output_stability import RelativeOutputStability
from quantus.metrics.robustness.relative_representation_stability import (
RelativeRepresentationStability,
)
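Taken together, the new classes can be evaluated like any other Quantus robustness metric. An end-to-end sketch (the hyperparameter name and value are assumptions; model, x_batch, and y_batch are placeholders):

```python
import quantus

# `nr_samples` (number of perturbed instances per input) is our assumption
# for the constructor; consult the class docstring for exact parameters.
ris = quantus.RelativeInputStability(nr_samples=5)
scores = ris(
    model=model,                 # torch.nn.Module or tf.keras.Model
    x_batch=x_batch,             # np.ndarray batch of inputs
    y_batch=y_batch,             # np.ndarray batch of labels
    explain_func=quantus.explain,
    explain_func_kwargs={"method": "Saliency"},
)
```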
