Skip to content

Commit

Permalink
update __call__ method in rbflayer
Browse files Browse the repository at this point in the history
  • Loading branch information
Shkev committed Dec 9, 2020
1 parent 2410dca commit be2a74d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions rbflayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,14 @@ def __init__(self, X):
self.X = X

def __call__(self, shape, dtype=None):
assert shape[1] == self.X.shape[1]
idx = np.random.randint(self.X.shape[0], size=shape[0])
return self.X[idx, :]
assert shape[1] == self.X.shape[1]
idx = np.random.randint(self.X.shape[0], size=shape[0])

# type checking to access elements of data correctly
if type(self.X) == np.ndarray:
return self.X[idx, :]
elif type(self.X) == pd.core.frame.DataFrame:
return self.X.iloc[idx, :]


class RBFLayer(Layer):
Expand Down

0 comments on commit be2a74d

Please sign in to comment.