Skip to content

Commit

Permalink
Merge branch 'main' into 200-fix-blur-at-indices
Browse files Browse the repository at this point in the history
  • Loading branch information
annahedstroem authored Jan 5, 2023
2 parents b840c73 + 455c5d9 commit b9c745c
Show file tree
Hide file tree
Showing 30 changed files with 2,599 additions and 517 deletions.
15 changes: 10 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@ _Quantus is currently under active development so carefully note the Quantus rel
- Offers more than **30+ metrics in 6 categories** for XAI evaluation
- Supports different data types (image, time-series, tabular, NLP next up!) and models (PyTorch and TensorFlow)
- Latest metrics additions:
- <b>Infidelity </b><a href="https://arxiv.org/abs/1901.09392">(Chih-Kuan, Yeh, et al., 2019)</a>
- <b>ROAD </b><a href="https://arxiv.org/abs/2202.00449">(Rong, Leemann, et al., 2022)</a>
- <b>Focus </b><a href="https://arxiv.org/abs/2109.15035">(Arias et al., 2022)</a>
- <b>Consistency </b><a href="https://arxiv.org/abs/2202.00734">(Dasgupta et al., 2022)</a>
- <b>Sufficiency </b><a href="https://arxiv.org/abs/2202.00734">(Dasgupta et al., 2022)</a>
- <b>Relative Input Stability</b><a href="https://arxiv.org/pdf/2203.06877.pdf"> (Chirag Agarwal, et. al., 2022)</a>
- <b>Relative Output Stability</b><a href="https://arxiv.org/pdf/2203.06877.pdf"> (Chirag Agarwal, et. al., 2022)</a>
- <b>Relative Representation Stability</b><a href="https://arxiv.org/pdf/2203.06877.pdf"> (Chirag Agarwal, et. al., 2022)</a>
- New optimisations to help speed up computation, see API reference [here](https://quantus.readthedocs.io/en/latest/docs_api/quantus.metrics.base_batched.html)!

## Citation
Expand Down Expand Up @@ -112,6 +112,9 @@ measures to what extent explanations are stable when subject to slight perturbat
<li><b>Avg-Sensitivity </b><a href="https://arxiv.org/pdf/1901.09392.pdf">(Yeh et al., 2019)</a>: measures the average sensitivity of an explanation using a Monte Carlo sampling-based approximation
<li><b>Continuity </b><a href="https://arxiv.org/pdf/1706.07979.pdf">(Montavon et al., 2018)</a>: captures the strongest variation in explanation of an input and its perturbed version
<li><b>Consistency </b><a href="https://arxiv.org/abs/2202.00734">(Dasgupta et al., 2022)</a>: measures the probability that the inputs with the same explanation have the same prediction label
<li><b>Relative Input Stability (RIS)</b><a href="https://arxiv.org/pdf/2203.06877.pdf"> (Chirag Agarwal, et. al., 2022)</a>: measures the relative distance between explanations e_x and e_x' with respect to the distance between the two inputs x and x'
<li><b>Relative Representation Stability (RRS)</b><a href="https://arxiv.org/pdf/2203.06877.pdf"> (Chirag Agarwal, et. al., 2022)</a>: measures the relative distance between explanations e_x and e_x' with respect to the distance between internal models representations L_x and L_x' for x and x' respectively
<li><b>Relative Output Stability (ROS)</b><a href="https://arxiv.org/pdf/2203.06877.pdf"> (Chirag Agarwal, et. al., 2022)</a>: measures the relative distance between explanations e_x and e_x' with respect to the distance between output logits h(x) and h(x') for x and x' respectively
</ul>
</details>

Expand Down Expand Up @@ -360,8 +363,10 @@ You can alternatively use your own customised explanation function
(assuming it returns an `np.ndarray` in a shape that matches the input `x_batch`). This is done as follows:

```python
def your_own_callable(model, x_batch, y_batch):
"""Logic goes here to compute the attributions in the same shape as x_batch."""
def your_own_callable(model, models, targets, **kwargs) -> np.ndarray
"""Logic goes here to compute the attributions and return an
explanation in the same shape as x_batch (np.array),
(flatten channels if necessary)."""
return explanation(model, x_batch, y_batch)

scores = metric(
Expand Down
2 changes: 1 addition & 1 deletion docs/developer_notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,6 @@ $ make html
```

#### Step 2. View edits and make changes accordingly.
http://localhost:63342/Projects/quantus/docs/build/html/index.html#
http://localhost:63342/Projects/quantus/docs/build/html/index.html

A copy is made of CONTRIBUTING.md to docs_dev/CONTRIBUTING.md. To avoid any inconsistencies, edit in CONTRIBUTING.md and overwrite in docs_dev/CONTRIBUTING.md.
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_input\_stability module
============================================================

.. automodule:: quantus.metrics.robustness.relative_input_stability
:members:
:undoc-members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_ouput\_stability module
============================================================

.. automodule:: quantus.metrics.robustness.relative_ouput_stability
:members:
:undoc-members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_output\_stability module
=============================================================

.. automodule:: quantus.metrics.robustness.relative_output_stability
:members:
:undoc-members:
:show-inheritance:
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
quantus.metrics.robustness.relative\_representation\_stability module
=====================================================================

.. automodule:: quantus.metrics.robustness.relative_representation_stability
:members:
:undoc-members:
:show-inheritance:
3 changes: 3 additions & 0 deletions docs/source/docs_api/quantus.metrics.robustness.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,6 @@ Submodules
quantus.metrics.robustness.continuity
quantus.metrics.robustness.local_lipschitz_estimate
quantus.metrics.robustness.max_sensitivity
quantus.metrics.robustness.relative_input_stability
quantus.metrics.robustness.relative_output_stability
quantus.metrics.robustness.relative_representation_stability
8 changes: 5 additions & 3 deletions docs/source/getting_started/getting_started_example.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,11 @@ You can alternatively use your own customised explanation function
(assuming it returns an `np.ndarray` in a shape that matches the input `x_batch`). This is done as follows:

```python
def your_own_callable(model, x_batch, y_batch):
"""Logic goes here to compute the attributions in the same shape as x_batch."""
return explanation(model, x_batch, y_batch)
def your_own_callable(model, inputs, targets, **kwargs) -> np.ndarray:
"""Logic goes here to compute the attributions and return an
explanation in the same shape as x_batch (np.array),
(flatten channels if necessary)."""
return explanation(model, inputs, targets)

scores = metric(
model=model,
Expand Down
24 changes: 17 additions & 7 deletions quantus/functions/explanation_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,8 +339,8 @@ def generate_captum_explanation(

reduce_axes = {"axis": tuple(kwargs.get("reduce_axes", [1])), "keepdims": True}

# For data with no channel dimensions, like tabular data, we want to prevent attribution summation.
if len(tuple(kwargs.get("reduce_axes", [1]))) == 0:
# Prevent attribution summation for 2D-data. Recreate np.sum behavior when passing reduce_axes=(), i.e. no change.
if (len(tuple(kwargs.get("reduce_axes", [1]))) == 0) | (inputs.ndim < 3):

def f_reduce_axes(a):
return a
Expand Down Expand Up @@ -411,11 +411,9 @@ def f_reduce_axes(a):
if isinstance(kwargs["gc_layer"], str):
kwargs["gc_layer"] = eval(kwargs["gc_layer"])

explanation = f_reduce_axes(
LayerGradCam(model, layer=kwargs["gc_layer"]).attribute(
inputs=inputs, target=targets
)
)
explanation = LayerGradCam(model, layer=kwargs["gc_layer"]).attribute(
inputs=inputs, target=targets)

if "interpolate" in kwargs:
if isinstance(kwargs["interpolate"], tuple):
if "interpolate_mode" in kwargs:
Expand All @@ -428,6 +426,18 @@ def f_reduce_axes(a):
explanation = LayerGradCam.interpolate(
explanation, kwargs["interpolate"]
)
else:
if explanation.shape[-1] != inputs.shape[-1]:
warnings.warn(
"Quantus requires GradCam attribution and input to correspond in "
"last dimensions, but got shapes {} and {}\n "
"Pass 'interpolate' argument to explanation function get matching dimensions.".format(
explanation.shape, inputs.shape
),
category=UserWarning,
)

explanation = f_reduce_axes(explanation)

elif method == "Control Var. Sobel Filter".lower():
explanation = torch.zeros(size=inputs.shape)
Expand Down
3 changes: 3 additions & 0 deletions quantus/helpers/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@
"Max-Sensitivity": MaxSensitivity,
"Avg-Sensitivity": AvgSensitivity,
"Consistency": Consistency,
"Relative Input Stability": RelativeInputStability,
"Relative Output Stability": RelativeOutputStability,
"Relative Representation Stability": RelativeRepresentationStability,
},
"Localisation": {
"Pointing Game": PointingGame,
Expand Down
35 changes: 34 additions & 1 deletion quantus/helpers/model/model_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# Quantus project URL: <https://github.com/understandable-machine-intelligence-lab/Quantus>.

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional, Tuple
from typing import Any, Dict, Optional, Tuple, List

import numpy as np

Expand Down Expand Up @@ -105,3 +105,36 @@ def get_random_layer_generator(self):
set it to 'independent'. For bottom-up order, set it to 'bottom_up'.
"""
raise NotImplementedError


@abstractmethod
def get_hidden_representations(
self,
x: np.ndarray,
layer_names: Optional[List[str]] = None,
layer_indices: Optional[List[int]] = None,
) -> np.ndarray:
"""
Compute the model's internal representation of input x.
In practice, this means, executing a forward pass and then, capturing the output of layers (of interest).
As the exact definition of "internal model representation" is left out in the original paper (see: https://arxiv.org/pdf/2203.06877.pdf),
we make the implementation flexible.
It is up to the user whether all layers are used, or specific ones should be selected.
The user can therefore select a layer by providing 'layer_names' (exclusive) or 'layer_indices'.
Parameters
----------
x: np.ndarray
4D tensor, a batch of input datapoints
layer_names: List[str]
List with names of layers, from which output should be captured.
layer_indices: List[int]
List with indices of layers, from which output should be captured.
Intended to use in case, when layer names are not unique, or unknown.
Returns
-------
L: np.ndarray
2D tensor with shape (batch_size, None)
"""
raise NotImplementedError()
81 changes: 40 additions & 41 deletions quantus/helpers/model/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,53 +155,52 @@ def forward(self, x):


if util.find_spec("tensorflow"):

import tensorflow as tf
from tensorflow.keras.models import Sequential

class LeNetTF(Sequential):
def LeNetTF() -> tf.keras.Model:
"""
A Tensorflow implementation of LeNet architecture.
Adapted from: https://www.tensorflow.org/datasets/keras_example.
A Tensorflow implementation of LeNet5 architecture.
"""

def __init__(self):
super().__init__(
[
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10),
]
)
self.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

class ConvNet1DTF(Sequential):
return tf.keras.Sequential(
[
tf.keras.layers.Conv2D(
filters=6,
kernel_size=(3, 3),
activation="relu",
input_shape=(28, 28, 1),
),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Conv2D(
filters=16, kernel_size=(3, 3), activation="relu", name="test_conv"
),
tf.keras.layers.AveragePooling2D(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=120, activation="relu"),
tf.keras.layers.Dense(units=84, activation="relu"),
tf.keras.layers.Dense(units=10),
],
name="LeNetTF",
)

def ConvNet1DTF(n_channels: int, seq_len: int, n_classes: int) -> tf.keras.Model:

"""
A Tensorflow implementation of 1D-convolutional architecture.
"""

def __init__(self, n_channels, seq_len, n_classes):
super().__init__(
[
tf.keras.layers.Input(shape=(seq_len, n_channels)),
tf.keras.layers.Conv1D(filters=6, kernel_size=5, strides=1),
tf.keras.layers.Activation("relu"),
tf.keras.layers.AveragePooling1D(pool_size=2, strides=2),
tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=1),
tf.keras.layers.Activation("relu"),
tf.keras.layers.AveragePooling1D(pool_size=2, strides=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(84, activation="relu"),
tf.keras.layers.Dense(n_classes),
]
)
self.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
return tf.keras.Sequential(
[
tf.keras.layers.Input(shape=(seq_len, n_channels)),
tf.keras.layers.Conv1D(filters=6, kernel_size=5, strides=1),
tf.keras.layers.Activation("relu"),
tf.keras.layers.AveragePooling1D(pool_size=2, strides=2),
tf.keras.layers.Conv1D(filters=16, kernel_size=5, strides=1),
tf.keras.layers.Activation("relu"),
tf.keras.layers.AveragePooling1D(pool_size=2, strides=2),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(84, activation="relu"),
tf.keras.layers.Dense(n_classes),
]
)
Loading

0 comments on commit b9c745c

Please sign in to comment.