Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
JunnYu authored Jul 13, 2021
1 parent 83f0251 commit 4cdb7ee
Showing 2 changed files with 8 additions and 6 deletions.
7 changes: 4 additions & 3 deletions src/transformers/models/marian/modeling_tf_marian.py
Original file line number Diff line number Diff line change
@@ -151,11 +151,12 @@ def _init_weight(n_pos: int, dim: int):
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
table = np.zeros_like(position_enc)
# index 0 is all zero
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor
table = tf.convert_to_tensor(position_enc)
table = tf.convert_to_tensor(table)
tf.stop_gradient(table)
return table

7 changes: 4 additions & 3 deletions src/transformers/models/pegasus/modeling_tf_pegasus.py
Original file line number Diff line number Diff line change
@@ -152,11 +152,12 @@ def _init_weight(n_pos: int, dim: int):
position_enc = np.array(
[[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
table = np.zeros_like(position_enc)
# index 0 is all zero
position_enc[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
position_enc[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
table[:, 0 : dim // 2] = np.sin(position_enc[:, 0::2])
table[:, dim // 2 :] = np.cos(position_enc[:, 1::2])
# convert to tensor
table = tf.convert_to_tensor(position_enc)
table = tf.convert_to_tensor(table)
tf.stop_gradient(table)
return table

0 comments on commit 4cdb7ee

Please sign in to comment.