Skip to content

Commit

Permalink
Updating RelationalFullBatchNodeGenerator's flow method (stellargraph…
Browse files Browse the repository at this point in the history
…#1871)

RelationalFullBatchNodeGenerator's flow method is extended to have a use_ilocs
parameter as FullBatchNodeGenerator's flow method has. This subject was
discussed in issues.

Fixes stellargraph#1870
  • Loading branch information
akinparkan authored Feb 18, 2021
1 parent 1e6120f commit 5ca1e59
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions stellargraph/mapper/full_batch_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def __init__(self, G, name=None, sparse=True, transform=None, weighted=False):
def num_batch_dims(self):
return 2

def flow(self, node_ids, targets=None):
def flow(self, node_ids, targets=None, use_ilocs=False):
"""
Creates a generator/sequence object for training or evaluation
with the supplied node ids and numeric targets.
Expand All @@ -503,7 +503,8 @@ def flow(self, node_ids, targets=None):
node_ids: and iterable of node ids for the nodes of interest
(e.g., training, validation, or test set nodes)
targets: a 2D array of numeric node targets with shape ``(len(node_ids), target_size)``
use_ilocs (bool): if True, node_ids are represented by ilocs,
otherwise node_ids need to be transformed into ilocs
Returns:
A NodeSequence object to use with RGCN models
in Keras methods :meth:`fit`, :meth:`evaluate`,
Expand All @@ -519,7 +520,10 @@ def flow(self, node_ids, targets=None):
if len(targets) != len(node_ids):
raise TypeError("Targets must be the same length as node_ids")

node_indices = self.graph.node_ids_to_ilocs(node_ids)
if use_ilocs:
node_indices = np.asarray(node_ids)
else:
node_indices = self.graph.node_ids_to_ilocs(node_ids)

return RelationalFullBatchNodeSequence(
self.features, self.As, self.use_sparse, targets, node_indices
Expand Down

0 comments on commit 5ca1e59

Please sign in to comment.