Skip to content

Commit

Permalink
Unify GCN/GAT FullBatchNodeGenerator and mv GCN_Aadj_feats_op and rel…
Browse files Browse the repository at this point in the history
…ated funcs to utils
  • Loading branch information
wangzhen263 committed Jan 24, 2019
1 parent 0e4f130 commit fd4551c
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 150 deletions.
5 changes: 3 additions & 2 deletions demos/node-classification-gcn/gcn-cora-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

import stellargraph as sg
from stellargraph.layer import GCN, GraphConvolution
from stellargraph.mapper import FullBatchNodeGenerator, GCN_A_feats
from stellargraph.mapper import FullBatchNodeGenerator
from stellargraph.core.utils import GCN_Aadj_feats_op


def train(train_nodes,
Expand Down Expand Up @@ -95,7 +96,7 @@ def test(test_nodes, test_targets, generator, model_file, model):
test_nodes, test_targets, train_size=300, test_size=None, random_state=523214
)

generator = FullBatchNodeGenerator(G, func_opt=GCN_A_feats, filter='localpool')
generator = FullBatchNodeGenerator(G, func_A_feats=GCN_Aadj_feats_op, filter='localpool')

dropout=0.0
layer_sizes=[16, 7]
Expand Down
69 changes: 69 additions & 0 deletions stellargraph/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,72 @@ def is_real_iterable(x):
True if x is an iterable (but not a string) and False otherwise
"""
return isinstance(x, collections.Iterable) and not isinstance(x, (str, bytes))


def normalized_laplacian(adj, symmetric=True):
    """Return the normalized graph Laplacian ``I - norm(adj)``.

    Args:
        adj: Square adjacency matrix (scipy sparse matrix).
        symmetric: If True, normalise as ``D^-1/2 * adj * D^-1/2``;
            otherwise normalise rows as ``D^-1 * adj``. ``D`` is the
            diagonal matrix of row sums of ``adj``.

    Returns:
        Sparse matrix ``I - adj_normalized`` with the same shape as ``adj``.
    """
    # Imported locally: this function did not previously depend on numpy.
    import numpy as np

    # The normalisation is done inline because no module-level
    # ``normalize_adj`` is defined in this section of the file (the only
    # definition is a helper nested inside ``GCN_Aadj_feats_op``), so calling
    # it from here raised NameError.
    degrees = np.asarray(adj.sum(axis=1), dtype=float).flatten()
    exponent = -0.5 if symmetric else -1.0

    # Zero-degree (isolated) nodes would produce an infinite scale factor;
    # use 0 for them so their rows/columns simply stay empty.
    with np.errstate(divide="ignore"):
        scale = np.power(degrees, exponent)
    scale[np.isinf(scale)] = 0.0
    d = sp.diags(scale, 0)

    if symmetric:
        adj_normalized = adj.dot(d).transpose().dot(d).tocsr()
    else:
        adj_normalized = d.dot(adj).tocsr()

    return sp.eye(adj.shape[0]) - adj_normalized


def rescale_laplacian(laplacian):
    """Rescale a Laplacian by its largest-magnitude eigenvalue.

    Computes ``(2 / lambda_max) * laplacian - I``. ``lambda_max`` is
    obtained from ``eigsh`` (one eigenvalue, ``which="LM"``); if ARPACK
    reports no convergence, the value 2 is used instead.

    Args:
        laplacian: Square matrix accepted by ``eigsh``
            (presumably a scipy sparse matrix -- confirm against callers).

    Returns:
        The rescaled Laplacian.
    """
    print("Calculating largest eigenvalue of normalized graph Laplacian...")
    try:
        lambda_max = eigsh(laplacian, 1, which="LM", return_eigenvectors=False)[0]
    except ArpackNoConvergence:
        print(
            "Eigenvalue calculation did not converge! Using largest_eigval=2 instead."
        )
        lambda_max = 2

    scale = 2.0 / lambda_max
    identity = sp.eye(laplacian.shape[0])
    return scale * laplacian - identity


def chebyshev_polynomial(X, k):
    """Calculate Chebyshev polynomials up to order k. Return a list of sparse matrices.

    Uses the recurrence ``T_i = 2 * X * T_{i-1} - T_{i-2}`` with
    ``T_0 = I`` and ``T_1 = X``.

    Args:
        X: Square matrix to evaluate the polynomials at.
        k: Highest polynomial order to compute.

    Returns:
        List ``[T_0, T_1, ..., T_k]``. The list always contains at least
        ``T_0`` and ``T_1``, even when ``k`` is less than 1.
    """
    print("Calculating Chebyshev polynomials up to order {}...".format(k))

    # One CSR copy of X is enough for every step of the recurrence; it was
    # previously re-created on each loop iteration.
    X_csr = sp.csr_matrix(X, copy=True)

    T_k = [sp.eye(X.shape[0]).tocsr(), X]
    for _ in range(2, k + 1):
        T_k.append(2 * X_csr.dot(T_k[-1]) - T_k[-2])

    return T_k


def GCN_Aadj_feats_op(features, A, **kwargs):
    """Prepare node features and the adjacency matrix for a GCN model.

    The adjacency matrix is first made symmetric, then processed according
    to the ``filter`` keyword argument:

    * ``"localpool"`` (default): self-loops are added and the matrix is
      symmetrically normalised as ``D^-1/2 (A + I) D^-1/2``.
    * ``"chebyshev"``: Chebyshev polynomials up to order 2 of the rescaled
      normalized Laplacian are appended to the features, which become the
      list ``[features, T_0, T_1, T_2]``; ``A`` is only symmetrised.
    * any other value: features are returned unchanged and ``A`` is only
      symmetrised.

    Args:
        features: Node feature matrix.
        A: Adjacency matrix (scipy sparse matrix).
        **kwargs: Only ``filter`` is read.

    Returns:
        Tuple ``(features, A)`` after processing.
    """
    # Build a symmetric adjacency matrix: wherever A.T holds the larger
    # entry, that entry replaces the one in A.
    transpose_is_larger = A.T > A
    A = A + A.T.multiply(transpose_is_larger) - A.multiply(transpose_is_larger)

    filter_name = kwargs.get("filter", "localpool")

    if filter_name == "chebyshev":
        # Chebyshev polynomial basis filters (Defferard et al., NIPS 2016)
        print("Using Chebyshev polynomial basis filters...")
        scaled_laplacian = rescale_laplacian(normalized_laplacian(A))
        features = [features] + chebyshev_polynomial(scaled_laplacian, 2)
    elif filter_name == "localpool":
        # Local pooling filters (see 'renormalization trick' in
        # Kipf & Welling, arXiv 2016)
        print("Using local pooling filters...")
        with_self_loops = A + sp.eye(A.shape[0])
        inv_sqrt_degree = sp.diags(
            np.power(np.array(with_self_loops.sum(1)), -0.5).flatten(), 0
        )
        A = (
            with_self_loops.dot(inv_sqrt_degree)
            .transpose()
            .dot(inv_sqrt_degree)
            .tocsr()
        )

    return features, A
8 changes: 3 additions & 5 deletions stellargraph/layer/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from keras import backend as K
from keras import Input
from keras.layers import Lambda, Dropout, Reshape
from ..mapper import gcn_mappers as gm
from ..mapper.node_mappers import FullBatchNodeGenerator

from typing import List, Tuple, Callable, AnyStr

Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
normalize=regularizers.l2(5e-4),
):

if not isinstance(generator, gm.FullBatchNodeGenerator):
if not isinstance(generator, FullBatchNodeGenerator):
raise TypeError("Generator should be a instance of FullBatchNodeGenerator")

assert len(layer_sizes) == len(activations)
Expand Down Expand Up @@ -174,9 +174,7 @@ def node_model(self):
for _ in range(self.support)
]
else:
suppG = [
Input(shape=(None, None), batch_shape=(None, None), sparse=True)
]
suppG = [Input(shape=(None, None), batch_shape=(None, None), sparse=True)]

x_out = self([x_in] + suppG)
return [x_in] + suppG, x_out
Expand Down
1 change: 0 additions & 1 deletion stellargraph/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,3 @@
# Expose the mappers
from .node_mappers import *
from .link_mappers import *
from .gcn_mappers import *
130 changes: 0 additions & 130 deletions stellargraph/mapper/gcn_mappers.py

This file was deleted.

92 changes: 91 additions & 1 deletion stellargraph/mapper/node_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,20 @@
Mappers to provide input data for the graph models in layers.
"""
__all__ = ["NodeSequence", "GraphSAGENodeGenerator", "HinSAGENodeGenerator"]
__all__ = [
"NodeSequence",
"GraphSAGENodeGenerator",
"HinSAGENodeGenerator",
"FullBatchNodeGenerator",
]

import operator
from functools import reduce

import numpy as np
import itertools as it
from keras.utils import Sequence
import networkx as nx

from ..data.explorer import (
SampledBreadthFirstWalk,
Expand Down Expand Up @@ -448,3 +454,87 @@ def flow_from_dataframe(self, node_targets):
"""

return NodeSequence(self, node_targets.index, node_targets.values)


class FullBatchNodeSequence(Sequence):
    """Keras ``Sequence`` that serves the whole graph as one single batch.

    Args:
        features: Node features, returned as the first model input.
        A: Adjacency matrix, returned as the second model input.
        targets: Optional iterable of target values; ``None`` when there
            are no targets.
        sample_weight: Optional value returned as the third batch element.

    Raises:
        TypeError: If ``targets`` is neither ``None`` nor a real iterable.
    """

    def __init__(self, features, A, targets=None, sample_weight=None):
        # Targets are optional, but when given they must be a real iterable.
        # NOTE(review): the length of targets is not validated here.
        if targets is not None and not is_real_iterable(targets):
            raise TypeError("Targets must be None or an iterable or numpy array ")

        self.features, self.A = features, A
        self.targets, self.sample_weight = targets, sample_weight

    def __len__(self):
        # Full-batch: there is exactly one batch.
        return 1

    def __getitem__(self, index):
        # The index is ignored; every request returns the same full batch.
        inputs = [self.features, self.A]
        return inputs, self.targets, self.sample_weight


class FullBatchNodeGenerator:
    """Generator that feeds a whole graph to a model as a single batch.

    Args:
        G: A StellarGraph object with node features and a single node type.
        name: Optional name for the generator.
        func_opt: Optional callable applied once to the node features and
            the adjacency matrix. It is called as
            ``func_opt(features, Aadj, **kwargs)`` and must return the
            processed ``(features, Aadj)`` pair. Also accepted under the
            keyword ``func_A_feats``.
        **kwargs: Extra keyword arguments forwarded to ``func_opt``
            (for example ``filter="localpool"``).

    Raises:
        TypeError: If ``G`` is not a StellarGraph object or has more than
            one node type.
    """

    def __init__(self, G, name=None, func_opt=None, **kwargs):
        if not isinstance(G, StellarGraphBase):
            raise TypeError("Graph must be a StellarGraph object.")

        self.graph = G
        self.name = name

        # Check if the graph has features
        G.check_graph_for_ml()

        # Create sparse adjacency matrix
        self.node_list = list(G.nodes())
        self.Aadj = nx.adjacency_matrix(G, nodelist=self.node_list)

        # We need a schema to check compatibility with GraphSAGE, GAT, GCN
        self.schema = G.create_graph_schema(create_type_maps=True)

        # Check that there is only a single node type for GraphSAGE, or GAT, or GCN
        if len(self.schema.node_types) > 1:
            raise TypeError(
                "{}: node generator requires graph with single node type; "
                "a graph with multiple node types is passed. Stopping.".format(
                    type(self).__name__
                )
            )

        # Get the features for the nodes
        self.features = G.get_feature_for_nodes(self.node_list)

        # Accept ``func_A_feats`` as an alias of ``func_opt``. Without this
        # the keyword is swallowed by **kwargs and the preprocessing
        # function is silently never applied.
        func_alias = kwargs.pop("func_A_feats", None)
        if func_opt is None:
            func_opt = func_alias

        if callable(func_opt):
            # Pass features and adjacency positionally so the call does not
            # depend on how the callable names its parameters.
            self.features, self.Aadj = func_opt(self.features, self.Aadj, **kwargs)

        self.kwargs = kwargs

    def flow(self, node_ids, targets=None):
        """Create a full-batch sequence for the given target nodes.

        Args:
            node_ids: Iterable of IDs of the nodes to train or predict on.
            targets: Optional iterable of target values, one per entry of
                ``node_ids``, or ``None``.

        Returns:
            A ``FullBatchNodeSequence`` holding the features, the adjacency
            matrix, a target array with one row per graph node (zeros for
            nodes not in ``node_ids``) or ``None``, and a boolean mask that
            selects the nodes in ``node_ids``.

        Raises:
            TypeError: If ``targets`` is neither ``None`` nor an iterable.
            ValueError: If an entry of ``node_ids`` is not a node of the graph.
        """
        # Check targets is an iterable
        if targets is not None and not is_real_iterable(targets):
            raise TypeError("Targets must be an iterable or None")

        # Map each node to its first position in self.node_list, so looking
        # up a node is a dict access rather than a linear list search.
        position_of = {}
        for position, node in enumerate(self.node_list):
            position_of.setdefault(node, position)

        # The list of indices of the target nodes in self.node_list
        try:
            node_indices = np.array([position_of[n] for n in node_ids], dtype=int)
        except KeyError as missing:
            raise ValueError(
                "{!r} is not in list".format(missing.args[0])
            ) from None

        # Boolean mask over all nodes. It is built directly as a boolean
        # array, so an empty node_ids yields an all-False mask of full length.
        node_mask = np.zeros(len(self.node_list), dtype=bool)
        node_mask[node_indices] = True

        # Reshape targets to (number of nodes in self.graph, number of classes), and store in y
        if targets is not None:
            targets = np.array(targets)
            c = 1 if len(targets.shape) == 1 else targets.shape[1]

            n = self.Aadj.shape[0]
            y = np.zeros((n, c))
            for i, t in zip(node_indices, targets):
                y[i] = t
        else:
            y = None

        return FullBatchNodeSequence(self.features, self.Aadj, y, node_mask)
2 changes: 1 addition & 1 deletion tests/layer/test_gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"""

from stellargraph.layer.gcn import *
from stellargraph.mapper.gcn_mappers import *
from stellargraph.mapper.node_mappers import FullBatchNodeGenerator
from stellargraph.core.graph import StellarGraph

import networkx as nx
Expand Down
Loading

0 comments on commit fd4551c

Please sign in to comment.