Skip to content

Commit

Permalink
允许roformer自定义位置id
Browse files Browse the repository at this point in the history
修复roformer自定义位置id不起作用
  • Loading branch information
bojone authored Mar 30, 2022
2 parents 56407a5 + 1fae360 commit 9769d5e
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion bert4keras/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,8 @@ def call(self, inputs):
"""如果custom_position_ids,那么第二个输入为自定义的位置id
"""
if self.custom_position_ids:
seq_len = K.shape(inputs)[1]
inputs, position_ids = inputs
seq_len = K.shape(inputs)[1]
if 'float' not in K.dtype(position_ids):
position_ids = K.cast(position_ids, K.floatx())
else:
Expand Down
7 changes: 5 additions & 2 deletions bert4keras/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1197,13 +1197,16 @@ def compute_position_bias(self, inputs=None):
"""Sinusoidal位置编码(直接返回)
"""
if self.position_bias is None:

x = inputs
if self.custom_position_ids:
x = [inputs, self.inputs[2]]
else:
x = inputs
self.position_bias = self.apply(
inputs=x,
layer=SinusoidalPositionEmbedding,
output_dim=self.attention_key_size,
merge_mode='zero',
custom_position_ids=self.custom_position_ids,
name='Embedding-Rotary-Position'
)

Expand Down

0 comments on commit 9769d5e

Please sign in to comment.