Skip to content

Commit

Permalink
Optimize import and format code
Browse files Browse the repository at this point in the history
  • Loading branch information
danielegrattarola committed Mar 11, 2022
1 parent 797868a commit 2b6ac04
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 7 deletions.
2 changes: 1 addition & 1 deletion spektral/layers/convolutional/xenet_conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable

import tensorflow as tf
from tensorflow.keras.layers import Concatenate, Dense, Multiply, PReLU, ReLU, Reshape
from tensorflow.keras.layers import Concatenate, Dense, Multiply, PReLU, ReLU
from tensorflow.python.ops import gen_sparse_ops

from spektral.layers.convolutional.conv import Conv
Expand Down
5 changes: 4 additions & 1 deletion spektral/layers/pooling/la_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class LaPool(SRCPool):
the papaer (can be expensive);
- `return_selection`: boolean, whether to return the selection matrix;
"""

def __init__(self, shortest_path_reg=True, return_selection=False, **kwargs):
super().__init__(return_selection=return_selection, **kwargs)

Expand Down Expand Up @@ -99,7 +100,9 @@ def shortest_path(a_):
s = beta * tf.sparse.to_dense(s)

# Leaders end up entirely in their own cluster
kronecker_delta = tf.boolean_mask(tf.eye(self.n_nodes, dtype=s.dtype), leader_mask, axis=1)
kronecker_delta = tf.boolean_mask(
tf.eye(self.n_nodes, dtype=s.dtype), leader_mask, axis=1
)

# Create clustering
s = tf.where(leader_mask[:, None], kronecker_delta, s)
Expand Down
4 changes: 2 additions & 2 deletions spektral/layers/pooling/sag_pool.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import tensorflow as tf
from spektral.layers import ops
from tensorflow.keras import backend as K

from spektral.layers import ops
from spektral.layers.pooling.topk_pool import TopKPool
from tensorflow.keras import backend as K


class SAGPool(TopKPool):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_layers/pooling/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ def _test_disjoint_mode(layer, sparse=False, **kwargs):

if "ratio" in kwargs.keys():
N_pool_expected = int(
np.ceil(kwargs["ratio"] * N1)
+ np.ceil(kwargs["ratio"] * N2)
+ np.ceil(kwargs["ratio"] * N3)
np.ceil(kwargs["ratio"] * N1)
+ np.ceil(kwargs["ratio"] * N2)
+ np.ceil(kwargs["ratio"] * N3)
)
elif "k" in kwargs.keys():
N_pool_expected = int(kwargs["k"])
Expand Down

0 comments on commit 2b6ac04

Please sign in to comment.