Skip to content

Commit

Permalink
[ModelZoo] Support variable type BF16 in DCN model. (DeepRec-AI#480)
Browse files Browse the repository at this point in the history
  • Loading branch information
Duyi-Wang authored Oct 13, 2022
1 parent cd0c16a commit 28ac395
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions modelzoo/dcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,10 +166,10 @@ def _cross_net(self, cross_input, layer_num=2, layer_name=''):
reuse=tf.AUTO_REUSE) as cross_layer_scope:
w = tf.get_variable(
name=layer_name+'_w',
dtype=tf.float32,
dtype=cross_input.dtype,
shape=(last_dim),
)
b = tf.get_variable(name=layer_name+'_b', dtype=tf.float32, shape=(last_dim))
b = tf.get_variable(name=layer_name+'_b', dtype=cross_input.dtype, shape=(last_dim))
xw = tf.reduce_sum(x * w, axis=1, keepdims=True)
x = tf.math.add(tf.math.add(x0 * xw, b), x)

Expand Down

0 comments on commit 28ac395

Please sign in to comment.