|
2 | 2 | from typing import List, Optional, Text, Tuple, Callable, Union
|
3 | 3 | import tensorflow as tf
|
4 | 4 | import tensorflow_addons as tfa
|
5 |
| -from tensorflow.python.keras.utils import tf_utils |
| 5 | + |
| 6 | +from tensorflow.python.eager import context |
| 7 | +from tensorflow.python.framework import dtypes |
| 8 | +from tensorflow.python.framework import tensor_shape |
6 | 9 | from tensorflow.python.keras import backend as K
|
7 |
| -from tensorflow.python.keras import initializers |
| 10 | +from tensorflow.python.keras.engine.input_spec import InputSpec |
| 11 | +from tensorflow.python.keras.utils import tf_utils |
| 12 | +from tensorflow.python.ops import gen_math_ops |
| 13 | +from tensorflow.python.ops import math_ops |
| 14 | +from tensorflow.python.ops import nn |
| 15 | +from tensorflow.python.ops import sparse_ops |
| 16 | +from tensorflow.python.ops import standard_ops |
8 | 17 |
|
9 | 18 | logger = logging.getLogger(__name__)
|
10 | 19 |
|
@@ -71,20 +80,71 @@ def __init__(self, sparsity: int = 0.8, **kwargs) -> None:
|
71 | 80 | self.sparsity = sparsity
|
72 | 81 |
|
def build(self, input_shape: tf.TensorShape) -> None:
    """Create the trainable sparse kernel values and the optional bias.

    A random boolean mask over a ``[last_dim, units]`` kernel is drawn once;
    only the positions where the uniform sample is ``>= self.sparsity`` get
    a trainable value. Their coordinates are stored in ``kernel_indices``
    and the values themselves in the 1-D weight ``kernel_values``.

    Args:
        input_shape: Shape of the layer input; its last dimension must be
            statically known.

    Raises:
        TypeError: If the layer dtype is neither floating nor complex.
        ValueError: If the last input dimension is undefined.
    """
    layer_dtype = dtypes.as_dtype(self.dtype or K.floatx())
    if not layer_dtype.is_floating and not layer_dtype.is_complex:
        raise TypeError(
            'Unable to build `Dense` layer with non-floating point '
            'dtype %s' % (layer_dtype,)
        )

    input_shape = tensor_shape.TensorShape(input_shape)
    last_dim = tensor_shape.dimension_value(input_shape[-1])
    if last_dim is None:
        raise ValueError(
            'The last dimension of the inputs to `Dense` '
            'should be defined. Found `None`.'
        )
    self.input_spec = InputSpec(min_ndim=2, axes={-1: last_dim})

    self.kernel_shape = tensor_shape.TensorShape([last_dim, self.units])

    # Draw the random mask that decides which kernel entries exist at all;
    # every masked-out entry is never stored and never trained.
    uniform_sample = tf.random.uniform(self.kernel_shape, 0, 1)
    keep_mask = tf.greater_equal(uniform_sample, self.sparsity)
    self.kernel_indices = tf.where(keep_mask)
    # The weight shape has to be a concrete number, hence the `.numpy()`
    # (NOTE(review): this presumably requires build to run eagerly - confirm).
    num_kept = tf.math.count_nonzero(keep_mask).numpy()
    self.kernel_values = self.add_weight(
        'kernel_values',
        shape=[num_kept],
        initializer=self.kernel_initializer,
        regularizer=self.kernel_regularizer,
        constraint=self.kernel_constraint,
        dtype=self.dtype,
        trainable=True,
    )

    self.bias = (
        self.add_weight(
            'bias',
            shape=[self.units],
            initializer=self.bias_initializer,
            regularizer=self.bias_regularizer,
            constraint=self.bias_constraint,
            dtype=self.dtype,
            trainable=True,
        )
        if self.use_bias
        else None
    )
    self.built = True
83 | 123 |
|
def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Apply the layer: ``activation(inputs @ kernel + bias)``.

    The dense kernel is rebuilt on every call by scattering
    ``kernel_values`` into a tensor of shape ``kernel_shape`` at
    ``kernel_indices``.

    Args:
        inputs: Input tensor. Inputs of rank <= 2 are cast to the layer's
            compute dtype and may be sparse; higher ranks are contracted
            over their last axis.

    Returns:
        The transformed tensor, with the activation applied if one is set.
    """
    # Materialise the kernel from the stored values and their coordinates.
    dense_kernel = tf.scatter_nd(
        self.kernel_indices, self.kernel_values, self.kernel_shape
    )

    ndim = len(inputs.shape)
    if ndim <= 2:
        inputs = math_ops.cast(inputs, self._compute_dtype)
        multiply = (
            sparse_ops.sparse_tensor_dense_matmul
            if K.is_sparse(inputs)
            else gen_math_ops.mat_mul
        )
        outputs = multiply(inputs, dense_kernel)
    else:
        # Contract the last input axis with the first kernel axis so the
        # leading dimensions are carried through unchanged.
        outputs = standard_ops.tensordot(inputs, dense_kernel, [[ndim - 1], [0]])
        if not context.executing_eagerly():
            # Outside eager execution, set the static output shape explicitly.
            static_shape = inputs.shape.as_list()
            outputs.set_shape(static_shape[:-1] + [self.units])

    if self.use_bias:
        outputs = nn.bias_add(outputs, self.bias)

    if self.activation is None:
        return outputs
    return self.activation(outputs)  # pylint: disable=not-callable
88 | 148 |
|
89 | 149 |
|
90 | 150 | class Ffnn(tf.keras.layers.Layer):
|
|
0 commit comments