Skip to content

Commit df889c8

Browse files
committed
change DenseWithSparseWeights
1 parent 5e0eb59 commit df889c8

File tree

1 file changed

+72
-12
lines changed

1 file changed

+72
-12
lines changed

rasa/utils/tensorflow/layers.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,18 @@
22
from typing import List, Optional, Text, Tuple, Callable, Union
33
import tensorflow as tf
44
import tensorflow_addons as tfa
5-
from tensorflow.python.keras.utils import tf_utils
5+
6+
from tensorflow.python.eager import context
7+
from tensorflow.python.framework import dtypes
8+
from tensorflow.python.framework import tensor_shape
69
from tensorflow.python.keras import backend as K
7-
from tensorflow.python.keras import initializers
10+
from tensorflow.python.keras.engine.input_spec import InputSpec
11+
from tensorflow.python.keras.utils import tf_utils
12+
from tensorflow.python.ops import gen_math_ops
13+
from tensorflow.python.ops import math_ops
14+
from tensorflow.python.ops import nn
15+
from tensorflow.python.ops import sparse_ops
16+
from tensorflow.python.ops import standard_ops
817

918
logger = logging.getLogger(__name__)
1019

@@ -71,20 +80,71 @@ def __init__(self, sparsity: int = 0.8, **kwargs) -> None:
7180
self.sparsity = sparsity
7281

7382
def build(self, input_shape: tf.TensorShape) -> None:
    """Create the sparse kernel parametrisation and the optional bias.

    Instead of a full ``[last_dim, units]`` kernel variable, only the entries
    that are kept are trainable. They live in the 1-D weight
    ``kernel_values``; ``kernel_indices`` holds the ``[row, col]`` position of
    each of them inside a kernel of shape ``kernel_shape``.

    The number of kept entries is derived deterministically from
    ``self.sparsity`` and the positions are stored in a non-trainable
    variable, so that the weight shapes are identical on every instantiation
    and the sparsity pattern is tracked together with the values.

    Args:
        input_shape: Shape of the layer input; its last dimension must be
            statically known.

    Raises:
        TypeError: If the layer dtype is neither floating point nor complex.
        ValueError: If the last dimension of ``input_shape`` is ``None``.
    """
    dtype = dtypes.as_dtype(self.dtype or K.floatx())
    if not (dtype.is_floating or dtype.is_complex):
        raise TypeError('Unable to build `Dense` layer with non-floating point '
                        'dtype %s' % (dtype,))

    input_shape = tensor_shape.TensorShape(input_shape)
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    if last_dim is None:
        raise ValueError('The last dimension of the inputs to `Dense` '
                         'should be defined. Found `None`.')
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})

    self.kernel_shape = tensor_shape.TensorShape([last_dim, self.units])

    # Keep a fixed number of weights rather than a randomly sized subset:
    # a random count would give `kernel_values` a different shape on every
    # instantiation, which makes previously saved weights unloadable.
    total = last_dim * self.units
    num_kept = int(round(total * (1.0 - self.sparsity)))
    num_kept = max(0, min(total, num_kept))

    # Pick `num_kept` distinct flat positions uniformly at random and sort
    # them so the indices are in row-major order.
    flat_positions = tf.random.shuffle(tf.range(total, dtype=tf.int64))
    flat_positions = tf.sort(flat_positions[:num_kept])
    indices = tf.stack(
        [flat_positions // self.units, flat_positions % self.units], axis=1
    )
    # Stored as a non-trainable variable so that the pattern is tracked by
    # the layer alongside `kernel_values`.
    self.kernel_indices = tf.Variable(
        initial_value=indices, trainable=False, name="kernel_indices"
    )

    # NOTE(review): the initializer receives the 1-D shape `[num_kept]`, not
    # the 2-D kernel shape; fan-based initializers may therefore scale the
    # values differently than for a regular dense kernel - confirm intended.
    self.kernel_values = self.add_weight(
        'kernel_values',
        shape=[num_kept, ],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True)

    if self.use_bias:
        self.bias = self.add_weight(
            'bias',
            shape=[self.units, ],
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            dtype=self.dtype,
            trainable=True)
    else:
        self.bias = None
    self.built = True
83123

84124
def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Apply the layer using a kernel assembled from the sparse weights.

    The kernel is rebuilt on every call by scattering ``kernel_values`` to
    the positions given by ``kernel_indices`` inside ``kernel_shape``. The
    result is then multiplied with ``inputs``, followed by the optional bias
    and activation.

    Args:
        inputs: Input tensor; its last axis is contracted with the kernel.

    Returns:
        The transformed tensor.
    """
    # set some weights to 0 according to precomputed mask
    dense_kernel = tf.scatter_nd(
        self.kernel_indices, self.kernel_values, self.kernel_shape
    )

    ndims = len(inputs.shape)
    if ndims <= 2:
        inputs = math_ops.cast(inputs, self._compute_dtype)
        matmul = (
            sparse_ops.sparse_tensor_dense_matmul
            if K.is_sparse(inputs)
            else gen_math_ops.mat_mul
        )
        result = matmul(inputs, dense_kernel)
    else:
        # Inputs with more than two axes: contract the last input axis with
        # the first kernel axis.
        result = standard_ops.tensordot(
            inputs, dense_kernel, [[ndims - 1], [0]]
        )
        if not context.executing_eagerly():
            # Outside eager execution, set the static output shape: same
            # leading axes as the input, `units` as the last axis.
            static_shape = inputs.shape.as_list()[:-1] + [self.units]
            result.set_shape(static_shape)

    if self.use_bias:
        result = nn.bias_add(result, self.bias)
    if self.activation is None:
        return result
    return self.activation(result)  # pylint: disable=not-callable
88148

89149

90150
class Ffnn(tf.keras.layers.Layer):

0 commit comments

Comments
 (0)