Skip to content

Commit

Permalink
replace node indices by node ids in the parameter list of saliency maps.
Browse files Browse the repository at this point in the history
  • Loading branch information
sktzwhj committed Aug 6, 2019
1 parent f2f1728 commit 90e82e8
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 119 deletions.
194 changes: 84 additions & 110 deletions demos/interpretability/gat/node-link-importance-demo-gat.ipynb

Large diffs are not rendered by default.

15 changes: 10 additions & 5 deletions stellargraph/utils/saliency_maps/integrated_gradients_gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class IntegratedGradientsGAT(GradientSaliencyGAT):
A SaliencyMask class that implements the integrated gradients method.
"""

def __init__(self, model, generator):
def __init__(self, model, generator, node_list):
self.node_list = node_list
super().__init__(model, generator)

def get_integrated_node_masks(
Expand All @@ -62,6 +63,8 @@ def get_integrated_node_masks(
non_exist_feature (bool): Setting it to True allows to compute the importance of features that are 0.
return (Numpy array): Integrated gradients for the node features.
"""
node_idx = self.node_list.index(node_id)

X_val = self.X
if X_baseline is None:
if not non_exist_feature:
Expand All @@ -74,8 +77,8 @@ def get_integrated_node_masks(

for alpha in np.linspace(1.0 / steps, 1, steps):
X_step = X_baseline + alpha * X_diff
total_gradients += super(IntegratedGradientsGAT, self).get_node_masks(
node_id, class_of_interest, X_val=X_step
total_gradients += super().get_node_masks(
node_idx, class_of_interest, X_val=X_step
)
return np.squeeze(total_gradients * X_diff, 0)

Expand All @@ -96,6 +99,8 @@ def get_link_importance(
return (Numpy array): shape the same with A_val. Integrated gradients for the links.
"""
node_idx = self.node_list.index(node_id)

A_val = self.A
total_gradients = np.zeros(A_val.shape)
A_diff = (
Expand All @@ -106,8 +111,8 @@ def get_link_importance(
for alpha in np.linspace(1.0 / steps, 1.0, steps):
if self.is_sparse:
A_val = sp.lil_matrix(A_val)
tmp = super(IntegratedGradientsGAT, self).get_link_masks(
alpha, node_id, class_of_interest, int(non_exist_edge)
tmp = super().get_link_masks(
alpha, node_idx, class_of_interest, int(non_exist_edge)
)
if self.is_sparse:
tmp = sp.csr_matrix(
Expand Down
2 changes: 2 additions & 0 deletions stellargraph/utils/saliency_maps/saliency_gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class GradientSaliencyGAT(object):
def __init__(self, model, generator):
"""
Args:
model (Keras model object): The Keras GAT model.
generator (FullBatchNodeSequence object): The generator from which we extract the feature and adjacency matirx.
"""
# The placeholders for features and adjacency matrix (model input):
Expand Down Expand Up @@ -193,6 +194,7 @@ def get_node_importance(self, node_id, class_of_interest, X_val=None, A_val=None
X_val (Numpy array): The feature matrix, we do not directly get it from generator to support the integrated gradients computation.
A_val (Numpy array): The adjacency matrix, we do not directly get it from generator to support the integrated gradients computation. Returns:
"""

if X_val is None:
X_val = self.X
if A_val is None:
Expand Down
10 changes: 6 additions & 4 deletions tests/utils/test_saliency_maps_gat.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,13 @@ def test_ig_saliency_map():
if "ig_non_exist_edge" in var.name:
assert K.get_value(var) == 0.0

ig_saliency = IntegratedGradientsGAT(keras_model_gat, train_gen)
target_idx = 0
ig_saliency = IntegratedGradientsGAT(
keras_model_gat, train_gen, generator.node_list
)
target_id = 0
class_of_interest = 0
ig_link_importance = ig_saliency.get_link_importance(
target_idx, class_of_interest, steps=200
target_id, class_of_interest, steps=200
)
print(ig_link_importance)

Expand All @@ -131,7 +133,7 @@ def test_ig_saliency_map():
non_zero_edge_importance = np.sum(np.abs(ig_link_importance) > 1e-11)
assert 8 == non_zero_edge_importance
ig_node_importance = ig_saliency.get_node_importance(
target_idx, class_of_interest, steps=200
target_id, class_of_interest, steps=200
)
print(ig_node_importance)
assert pytest.approx(ig_node_importance, np.array([-13.06, -9.32, -7.46, -3.73, 0]))
Expand Down

0 comments on commit 90e82e8

Please sign in to comment.