Skip to content

Commit

Permalink
Added tests for checking if in an epoch the samples generated are cor…
Browse files Browse the repository at this point in the history
…rect.
  • Loading branch information
habiba-h committed Feb 20, 2019
1 parent af792d5 commit eeea3d1
Showing 1 changed file with 67 additions and 3 deletions.
70 changes: 67 additions & 3 deletions tests/mapper/test_link_mappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,20 @@ def test_edge_consistency(shuffle):
# feature_size=2 * self.n_feat,
# )

def test_GraphSAGELinkGenerator_not_Stellargraph(self):
G = nx.Graph()
elist = [(1, 2), (2, 3), (1, 4), (3, 2)]
G.add_edges_from(elist)

# Add example features
for v in G.nodes():
G.node[v]["feature"] = np.ones(1)

with pytest.raises(TypeError):
GraphSAGELinkGenerator(
G, batch_size=self.batch_size, num_samples=self.num_samples
)

def test_GraphSAGELinkGenerator_zero_samples(self):

G = example_Graph_1(self.n_feat)
Expand Down Expand Up @@ -411,9 +425,9 @@ def test_GraphSAGELinkGenerator_isolates(self):
ne, nl = gen[0]
assert pytest.approx([1, 1, 2, 2, 4, 4]) == [x.shape[1] for x in ne]

def test_GraphSAGELinkGenerator_unsupervisedSampler(self):
def test_GraphSAGELinkGenerator_unsupervisedSampler_flow(self):
"""
This tests link generator's iterator for real time link generation i.e. there is no pregenerated list of samples provided to it.
This tests link generator's initialization for on demand link generation i.e. there is no pregenerated list of samples provided to it.
"""
n_feat = 4
n_batch = 2
Expand All @@ -432,12 +446,62 @@ def test_GraphSAGELinkGenerator_unsupervisedSampler(self):
unsupervisedSamples
)

# UnsupervisedSampler object or a list of samples is not passed
# The flow method is not passed UnsupervisedSampler object or a list of samples is not passed
with pytest.raises(TypeError):
gen = GraphSAGELinkGenerator(
G, batch_size=n_batch, num_samples=n_samples
).flow("not_a_list_of_samples_or_a_sample_generator")

# The flow method is not passed nothing
with pytest.raises(TypeError):
gen = GraphSAGELinkGenerator(
G, batch_size=n_batch, num_samples=n_samples
).flow()

def test_GraphSAGELinkGenerator_unsupervisedSampler_sample_generation(self):

G = example_Graph_2(self.n_feat)

rw = UniformRandomWalk(G)

unsupervisedSamples = UnsupervisedSampler(G, walker=rw)

mapper = GraphSAGELinkGenerator(
G, batch_size=self.batch_size, num_samples=self.num_samples
).flow(unsupervisedSamples)

assert mapper.data_size == 8
assert self.batch_size == 2
assert len(mapper) == 4

for batch in range(len(mapper)):
nf, nl = mapper[batch]

assert len(nf) == 3 * 2
assert len(set(mapper.head_node_types)) == 1

for ii in range(2):
assert nf[ii].shape == (
min(self.batch_size, mapper.data_size),
1,
self.n_feat,
)
assert nf[ii + 2].shape == (
min(self.batch_size, mapper.data_size),
2,
self.n_feat,
)
assert nf[ii + 2 * 2].shape == (
min(self.batch_size, mapper.data_size),
2 * 2,
self.n_feat,
)
assert len(nl) == min(self.batch_size, mapper.data_size)
assert sorted(nl) == [0, 1]

with pytest.raises(IndexError):
nf, nl = mapper[4]


class Test_HinSAGELinkGenerator(object):
"""
Expand Down

0 comments on commit eeea3d1

Please sign in to comment.