Skip to content

Commit

Permalink
[NN] GNNExplainer (dmlc#3490)
Browse files Browse the repository at this point in the history
* Update

* Update

* Update

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* Fix

* Fix

* Update

* Update

* Update

* Update

* Update

* Update

* lint fix

* lint fix

* Fix lint

* Update

* Fix CI

* Fix CI

* Fix

* CI

* Fix

* Update

* Fix

* Fix

* Fix CI

* Fix CI
  • Loading branch information
mufeili authored Nov 10, 2021
1 parent 55f2e87 commit dfa32ae
Show file tree
Hide file tree
Showing 13 changed files with 590 additions and 51 deletions.
1 change: 1 addition & 0 deletions conda/dgl/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ requirements:
- scipy
- networkx
- requests
- tqdm

build:
script_env:
Expand Down
16 changes: 14 additions & 2 deletions docs/source/api/python/nn.pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ GATConv
.. autoclass:: dgl.nn.pytorch.conv.GATConv
:members: forward
:show-inheritance:

GATv2Conv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -163,7 +163,7 @@ TWIRLSUnfoldingAndAttention
.. autoclass:: dgl.nn.pytorch.conv.TWIRLSUnfoldingAndAttention
:members: forward
:show-inheritance:

GCN2Conv
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Expand Down Expand Up @@ -319,3 +319,15 @@ NodeEmbedding
.. autoclass:: dgl.nn.pytorch.sparse_emb.NodeEmbedding
:members:
:show-inheritance:

Explainability Models
----------------------------------------

.. automodule:: dgl.nn.pytorch.explain

GNNExplainer
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: dgl.nn.pytorch.explain.GNNExplainer
:members: explain_node, explain_graph
:show-inheritance:
7 changes: 6 additions & 1 deletion python/dgl/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,18 +1148,23 @@ def count_nonzero(input):
# DGL should contain all the operations on index, so this set of operators
# should be gradually removed.

def unique(input):
def unique(input, return_inverse=False):
"""Returns the unique scalar elements in a tensor.
Parameters
----------
input : Tensor
Must be a 1-D tensor.
return_inverse : bool, optional
Whether to also return the indices for where elements in the original
input ended up in the returned unique list.
Returns
-------
Tensor
A 1-D tensor containing unique elements.
Tensor
A 1-D tensor containing the new positions of the elements in the input.
"""
pass

Expand Down
12 changes: 9 additions & 3 deletions python/dgl/backend/mxnet/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,11 +347,17 @@ def count_nonzero(input):
tmp = input.asnumpy()
return np.count_nonzero(tmp)

def unique(input):
def unique(input, return_inverse=False):
# TODO: fallback to numpy is unfortunate
tmp = input.asnumpy()
tmp = np.unique(tmp)
return nd.array(tmp, ctx=input.context, dtype=input.dtype)
if return_inverse:
tmp, inv = np.unique(tmp, return_inverse=True)
tmp = nd.array(tmp, ctx=input.context, dtype=input.dtype)
inv = nd.array(inv, ctx=input.context)
return tmp, inv
else:
tmp = np.unique(tmp)
return nd.array(tmp, ctx=input.context, dtype=input.dtype)

def full_1d(length, fill_value, dtype, ctx):
return nd.full((length,), fill_value, dtype=dtype, ctx=ctx)
Expand Down
4 changes: 2 additions & 2 deletions python/dgl/backend/pytorch/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,10 +295,10 @@ def count_nonzero(input):
# TODO: fallback to numpy for backward compatibility
return np.count_nonzero(input)

def unique(input):
def unique(input, return_inverse=False):
if input.dtype == th.bool:
input = input.type(th.int8)
return th.unique(input)
return th.unique(input, return_inverse=return_inverse)

def full_1d(length, fill_value, dtype, ctx):
return th.full((length,), fill_value, dtype=dtype, device=ctx)
Expand Down
7 changes: 5 additions & 2 deletions python/dgl/backend/tensorflow/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,8 +413,11 @@ def count_nonzero(input):
return int(tf.math.count_nonzero(input))


def unique(input):
return tf.unique(input).y
def unique(input, return_inverse=False):
if return_inverse:
return tf.unique(input)
else:
return tf.unique(input).y


def full_1d(length, fill_value, dtype, ctx):
Expand Down
1 change: 1 addition & 0 deletions python/dgl/nn/pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Package for pytorch-specific NN modules."""
from .conv import *
from .explain import *
from .glob import *
from .softmax import *
from .factory import *
Expand Down
6 changes: 6 additions & 0 deletions python/dgl/nn/pytorch/explain/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Torch modules for explanation models."""
# pylint: disable= no-member, arguments-differ, invalid-name

from .gnnexplainer import GNNExplainer

__all__ = ['GNNExplainer']
Loading

0 comments on commit dfa32ae

Please sign in to comment.