@@ -83,47 +83,62 @@ def build(self, input_shape: tf.TensorShape) -> None:
83
83
84
84
dtype = dtypes .as_dtype (self .dtype or K .floatx ())
85
85
if not (dtype .is_floating or dtype .is_complex ):
86
- raise TypeError ('Unable to build `Dense` layer with non-floating point '
87
- 'dtype %s' % (dtype ,))
86
+ raise TypeError (
87
+ "Unable to build `Dense` layer with non-floating point "
88
+ "dtype %s" % (dtype ,)
89
+ )
88
90
input_shape = tensor_shape .TensorShape (input_shape )
89
91
if tensor_shape .dimension_value (input_shape [- 1 ]) is None :
90
- raise ValueError ('The last dimension of the inputs to `Dense` '
91
- 'should be defined. Found `None`.' )
92
+ raise ValueError (
93
+ "The last dimension of the inputs to `Dense` "
94
+ "should be defined. Found `None`."
95
+ )
92
96
last_dim = tensor_shape .dimension_value (input_shape [- 1 ])
93
- self .input_spec = InputSpec (min_ndim = 2 ,
94
- axes = {- 1 : last_dim })
97
+ self .input_spec = InputSpec (min_ndim = 2 , axes = {- 1 : last_dim })
95
98
96
99
self .kernel_shape = tensor_shape .TensorShape ([last_dim , self .units ])
97
100
# create random mask to set some weights to 0
98
101
kernel_mask = tf .random .uniform (self .kernel_shape , 0 , 1 )
99
- kernel_mask = tf .greater_equal (kernel_mask , self .sparsity )
100
- self .kernel_indices = tf .where (kernel_mask )
101
- size = tf .math .count_nonzero (kernel_mask ).numpy ()
102
+
103
+ size = int (last_dim * self .units * (1 - self .sparsity ))
104
+ # the probability that there are identical numbers is negligible
105
+ threshold = tf .sort (tf .reshape (kernel_mask , (- 1 ,)), direction = "DESCENDING" )[
106
+ size - 1
107
+ ]
108
+ kernel_mask = tf .greater_equal (kernel_mask , threshold )
109
+
110
+ self .kernel_indices = tf .Variable (
111
+ initial_value = tf .where (kernel_mask ), trainable = False , name = "kernel_indices"
112
+ )
102
113
self .kernel_values = self .add_weight (
103
- ' kernel_values' ,
104
- shape = [size , ],
114
+ " kernel_values" ,
115
+ shape = [size ,],
105
116
initializer = self .kernel_initializer ,
106
117
regularizer = self .kernel_regularizer ,
107
118
constraint = self .kernel_constraint ,
108
119
dtype = self .dtype ,
109
- trainable = True )
120
+ trainable = True ,
121
+ )
110
122
111
123
if self .use_bias :
112
124
self .bias = self .add_weight (
113
- ' bias' ,
114
- shape = [self .units , ],
125
+ " bias" ,
126
+ shape = [self .units ,],
115
127
initializer = self .bias_initializer ,
116
128
regularizer = self .bias_regularizer ,
117
129
constraint = self .bias_constraint ,
118
130
dtype = self .dtype ,
119
- trainable = True )
131
+ trainable = True ,
132
+ )
120
133
else :
121
134
self .bias = None
122
135
self .built = True
123
136
124
137
def call (self , inputs : tf .Tensor ) -> tf .Tensor :
125
- # set some weights to 0 according to precomputed mask
126
- kernel = tf .scatter_nd (self .kernel_indices , self .kernel_values , self .kernel_shape )
138
+ # create dense kernel
139
+ kernel = tf .scatter_nd (
140
+ self .kernel_indices , self .kernel_values , self .kernel_shape
141
+ )
127
142
128
143
rank = len (inputs .shape )
129
144
if rank > 2 :
0 commit comments