Skip to content

Commit 1adc859

Browse files
committed
fix number of kernel values
1 parent df889c8 commit 1adc859

File tree

1 file changed

+32
-17
lines changed

1 file changed

+32
-17
lines changed

rasa/utils/tensorflow/layers.py

Lines changed: 32 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -83,47 +83,62 @@ def build(self, input_shape: tf.TensorShape) -> None:
8383

8484
dtype = dtypes.as_dtype(self.dtype or K.floatx())
8585
if not (dtype.is_floating or dtype.is_complex):
86-
raise TypeError('Unable to build `Dense` layer with non-floating point '
87-
'dtype %s' % (dtype,))
86+
raise TypeError(
87+
"Unable to build `Dense` layer with non-floating point "
88+
"dtype %s" % (dtype,)
89+
)
8890
input_shape = tensor_shape.TensorShape(input_shape)
8991
if tensor_shape.dimension_value(input_shape[-1]) is None:
90-
raise ValueError('The last dimension of the inputs to `Dense` '
91-
'should be defined. Found `None`.')
92+
raise ValueError(
93+
"The last dimension of the inputs to `Dense` "
94+
"should be defined. Found `None`."
95+
)
9296
last_dim = tensor_shape.dimension_value(input_shape[-1])
93-
self.input_spec = InputSpec(min_ndim=2,
94-
axes={-1: last_dim})
97+
self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})
9598

9699
self.kernel_shape = tensor_shape.TensorShape([last_dim, self.units])
97100
# create random mask to set some weights to 0
98101
kernel_mask = tf.random.uniform(self.kernel_shape, 0, 1)
99-
kernel_mask = tf.greater_equal(kernel_mask, self.sparsity)
100-
self.kernel_indices = tf.where(kernel_mask)
101-
size = tf.math.count_nonzero(kernel_mask).numpy()
102+
103+
size = int(last_dim * self.units * (1 - self.sparsity))
104+
# ties between the uniform random draws are negligibly unlikely, so thresholding at the size-th largest value keeps exactly `size` entries
105+
threshold = tf.sort(tf.reshape(kernel_mask, (-1,)), direction="DESCENDING")[
106+
size - 1
107+
]
108+
kernel_mask = tf.greater_equal(kernel_mask, threshold)
109+
110+
self.kernel_indices = tf.Variable(
111+
initial_value=tf.where(kernel_mask), trainable=False, name="kernel_indices"
112+
)
102113
self.kernel_values = self.add_weight(
103-
'kernel_values',
104-
shape=[size, ],
114+
"kernel_values",
115+
shape=[size,],
105116
initializer=self.kernel_initializer,
106117
regularizer=self.kernel_regularizer,
107118
constraint=self.kernel_constraint,
108119
dtype=self.dtype,
109-
trainable=True)
120+
trainable=True,
121+
)
110122

111123
if self.use_bias:
112124
self.bias = self.add_weight(
113-
'bias',
114-
shape=[self.units, ],
125+
"bias",
126+
shape=[self.units,],
115127
initializer=self.bias_initializer,
116128
regularizer=self.bias_regularizer,
117129
constraint=self.bias_constraint,
118130
dtype=self.dtype,
119-
trainable=True)
131+
trainable=True,
132+
)
120133
else:
121134
self.bias = None
122135
self.built = True
123136

124137
def call(self, inputs: tf.Tensor) -> tf.Tensor:
125-
# set some weights to 0 according to precomputed mask
126-
kernel = tf.scatter_nd(self.kernel_indices, self.kernel_values, self.kernel_shape)
138+
# create dense kernel
139+
kernel = tf.scatter_nd(
140+
self.kernel_indices, self.kernel_values, self.kernel_shape
141+
)
127142

128143
rank = len(inputs.shape)
129144
if rank > 2:

0 commit comments

Comments
 (0)