Refine transfer_param in FT (PaddlePaddle#1405)
* Optimize the cost of fp16 init in FT.

* Remove the reserve_var argument in FT

* Remove unused import
guoshengCS authored Dec 8, 2021
1 parent c3a467b commit 235cc50
Showing 1 changed file with 49 additions and 87 deletions.
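
For context, the first bullet replaces the slow `paddle.create_parameter` + Assign-initializer path with a direct fp16 tensor copy when running in dygraph mode, and the second drops `reserve_var`, so `transfer_param` no longer deletes the original parameter. Below is a minimal, self-contained sketch of that dygraph fast path, mirroring the new code in the diff; the helper name `cast_param_to_fp16` and the `paddle.nn.Linear` usage are illustrative only, and it assumes a Paddle release contemporary with this commit (2.2.x), where dygraph `ParamBase` and `paddle.fluid.framework._current_expected_place` are available.

import paddle


def cast_param_to_fp16(p, is_bias=False):
    """Sketch of the dygraph fast path introduced by this commit."""
    if p.dtype == paddle.float16:
        return p
    # Cast the existing data once on the numpy side ...
    data = p.numpy().astype("float16")
    # ... then build a new fp16 parameter of the same class (a ParamBase in
    # dygraph) and write the tensor in place, instead of going through
    # paddle.create_parameter with an Assign initializer.
    new_p = type(p)(shape=p.shape, dtype="float16", is_bias=is_bias)
    new_p.value().get_tensor().set(
        data, paddle.fluid.framework._current_expected_place())
    return new_p


# Hypothetical usage: convert one layer's weight and bias to fp16.
layer = paddle.nn.Linear(8, 8)
w16 = cast_param_to_fp16(layer.weight)
b16 = cast_param_to_fp16(layer.bias, is_bias=True)
# Both now report a float16 dtype; the original fp32 parameters are untouched.
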
136 changes: 49 additions & 87 deletions paddlenlp/ops/faster_transformer/transformer/decoding.py
@@ -160,7 +160,7 @@ def infer_force_decoding(
            'EmbBias': linear_bias,
            'PositionEncEmb': pos_emb,
            # The input of custom op must be given.
-            # Dispensable() and Intermediate() are not supported.
+            # Dispensable() and Intermediate() are not supported.
            'TrgWord': trg_word
        }

@@ -460,16 +460,25 @@ def finalize(beam_size,
    return ids


-def transfer_param(p, is_bias=False, restore_data=False, reserve_var=False):
+def transfer_param(p, is_bias=False, restore_data=False):
    param_shape = p.shape
    # Maybe we need allow users using `model.to('float16')` to use fp16 by this.
    if (p.dtype == paddle.float16): return p
    if restore_data:
        if in_dygraph_mode():
            param_data = p.numpy()
+            # Creating parameters with the Assign initializer is too slow. We
+            # could cast to fp16 directly and get a tensor, but we do it more
+            # elaborately here to get a ParamBase. Also note `VarBase.set_value`
+            # enforces the same dtype and cannot be used directly.
+            new_p = type(p)(shape=param_shape, dtype="float16", is_bias=is_bias)
+            new_p.value().get_tensor().set(
+                param_data.astype("float16"),
+                paddle.fluid.framework._current_expected_place())
+            return new_p
        else:
            param_data = np.array(paddle.static.global_scope().find_var(p.name)
                                  .get_tensor())
-    if not reserve_var:
-        del p
    return paddle.create_parameter(
        shape=param_shape,
        dtype="float16",
@@ -986,15 +995,11 @@ def __init__(self,
                paddle.concat(
                    [
                        transfer_param(
-                            mod.self_attn.q_proj.weight,
-                            restore_data=True,
-                            reserve_var=True), transfer_param(
-                            mod.self_attn.k_proj.weight,
-                            restore_data=True,
-                            reserve_var=True), transfer_param(
-                            mod.self_attn.v_proj.weight,
-                            restore_data=True,
-                            reserve_var=True)
+                            mod.self_attn.q_proj.weight, restore_data=True),
+                        transfer_param(
+                            mod.self_attn.k_proj.weight, restore_data=True),
+                        transfer_param(
+                            mod.self_attn.v_proj.weight, restore_data=True)
                    ],
                    axis=-1))
self.sub_modules["slf_q_bias"].append(
@@ -1003,175 +1008,132 @@
                        transfer_param(
                            mod.self_attn.q_proj.bias,
                            is_bias=True,
-                            restore_data=True,
-                            reserve_var=True), transfer_param(
+                            restore_data=True), transfer_param(
                            mod.self_attn.k_proj.bias,
                            is_bias=True,
-                            restore_data=True,
-                            reserve_var=True), transfer_param(
+                            restore_data=True), transfer_param(
                            mod.self_attn.v_proj.bias,
                            is_bias=True,
-                            restore_data=True,
-                            reserve_var=True)
+                            restore_data=True)
                    ],
                    axis=-1))
                self.sub_modules["slf_k_weight"].append(
                    transfer_param(
-                        mod.self_attn.k_proj.weight,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.self_attn.k_proj.weight, restore_data=True))
                self.sub_modules["slf_k_bias"].append(
                    transfer_param(
                        mod.self_attn.k_proj.bias,
                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        restore_data=True))
                self.sub_modules["slf_v_weight"].append(
                    transfer_param(
-                        mod.self_attn.v_proj.weight,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.self_attn.v_proj.weight, restore_data=True))
                self.sub_modules["slf_v_bias"].append(
                    transfer_param(
                        mod.self_attn.v_proj.bias,
                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        restore_data=True))
                self.sub_modules["slf_out_weight"].append(
                    transfer_param(
-                        mod.self_attn.out_proj.weight,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.self_attn.out_proj.weight, restore_data=True))
                self.sub_modules["slf_out_bias"].append(
                    transfer_param(
                        mod.self_attn.out_proj.bias,
                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        restore_data=True))
                self.sub_modules["ffn_inter_weight"].append(
                    transfer_param(
-                        mod.linear1.weight, restore_data=True,
-                        reserve_var=True))
+                        mod.linear1.weight, restore_data=True))
                self.sub_modules["ffn_inter_bias"].append(
                    transfer_param(
-                        mod.linear1.bias,
-                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.linear1.bias, is_bias=True, restore_data=True))
                self.sub_modules["ffn_out_weight"].append(
                    transfer_param(
-                        mod.linear2.weight, restore_data=True,
-                        reserve_var=True))
+                        mod.linear2.weight, restore_data=True))
                self.sub_modules["ffn_out_bias"].append(
                    transfer_param(
-                        mod.linear2.bias,
-                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.linear2.bias, is_bias=True, restore_data=True))
                self.sub_modules["slf_ln_weight"].append(
                    transfer_param(
-                        mod.norm1.weight, restore_data=True, reserve_var=True))
+                        mod.norm1.weight, restore_data=True))
                self.sub_modules["slf_ln_bias"].append(
                    transfer_param(
-                        mod.norm1.bias,
-                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.norm1.bias, is_bias=True, restore_data=True))
                self.sub_modules["ffn_ln_weight"].append(
                    transfer_param(
-                        mod.norm2.weight, restore_data=True, reserve_var=True))
+                        mod.norm2.weight, restore_data=True))
                self.sub_modules["ffn_ln_bias"].append(
                    transfer_param(
-                        mod.norm2.bias,
-                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True))
+                        mod.norm2.bias, is_bias=True, restore_data=True))

            self.sub_modules["word_emb"] = [
                transfer_param(
                    self._model.embeddings.word_embeddings.weight,
-                    restore_data=True,
-                    reserve_var=True)
+                    restore_data=True)
            ]
            self.sub_modules["pos_emb"] = [
                transfer_param(
                    self._model.embeddings.position_embeddings.weight,
-                    restore_data=True,
-                    reserve_var=True)
+                    restore_data=True)
            ]
            self.sub_modules["type_emb"] = [
                transfer_param(
                    self._model.embeddings.token_type_embeddings.weight,
-                    restore_data=True,
-                    reserve_var=True)
+                    restore_data=True)
            ]
            if self._normalize_before:
                self.sub_modules["decoder_ln_weight"] = [
                    transfer_param(
-                        self._model.encoder.norm.weight,
-                        restore_data=True,
-                        reserve_var=True)
+                        self._model.encoder.norm.weight, restore_data=True)
                ]
                self.sub_modules["decoder_ln_bias"] = [
                    transfer_param(
                        self._model.encoder.norm.bias,
                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True)
+                        restore_data=True)
                ]
            else:
                self.sub_modules["decoder_ln_weight"] = [
                    transfer_param(
-                        self._model.encoder_norm.weight,
-                        restore_data=True,
-                        reserve_var=True)
+                        self._model.encoder_norm.weight, restore_data=True)
                ]
                self.sub_modules["decoder_ln_bias"] = [
                    transfer_param(
                        self._model.encoder_norm.bias,
                        is_bias=True,
-                        restore_data=True,
-                        reserve_var=True)
+                        restore_data=True)
                ]
            self.sub_modules["trans_weight"] = [
                transfer_param(
-                    self._model.lm_head.transform.weight,
-                    restore_data=True,
-                    reserve_var=True)
+                    self._model.lm_head.transform.weight, restore_data=True)
            ]
            self.sub_modules["trans_bias"] = [
                transfer_param(
                    self._model.lm_head.transform.bias,
                    is_bias=True,
-                    restore_data=True,
-                    reserve_var=True)
+                    restore_data=True)
            ]
            self.sub_modules["lm_ln_weight"] = [
                transfer_param(
-                    self._model.lm_head.layer_norm.weight,
-                    restore_data=True,
-                    reserve_var=True)
+                    self._model.lm_head.layer_norm.weight, restore_data=True)
            ]
            self.sub_modules["lm_ln_bias"] = [
                transfer_param(
                    self._model.lm_head.layer_norm.bias,
                    is_bias=True,
-                    restore_data=True,
-                    reserve_var=True)
+                    restore_data=True)
            ]
            self.sub_modules["linear_weight"] = [
                paddle.transpose(
                    transfer_param(
-                        self._model.lm_head.decoder_weight,
-                        restore_data=True,
-                        reserve_var=True), [1, 0])
+                        self._model.lm_head.decoder_weight, restore_data=True),
+                    [1, 0])
            ]
            self.sub_modules["linear_bias"] = [
                transfer_param(
                    self._model.lm_head.decoder_bias,
                    is_bias=True,
-                    restore_data=True,
-                    reserve_var=True)
+                    restore_data=True)
            ]
        else:
            for mod in self._model.encoder.layers:
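After this change the call-site pattern above is uniform: weights pass only restore_data=True, biases additionally pass is_bias=True, and no caller passes reserve_var any more because the source parameter is left intact. The sketch below shows that pattern for one encoder layer; it assumes paddlenlp is installed so that transfer_param can be imported from the module shown in this diff, and uses a standard paddle.nn.TransformerEncoderLayer (with self_attn, linear1, linear2, norm1, norm2 sub-layers) purely for illustration.

import paddle
from paddlenlp.ops.faster_transformer.transformer.decoding import transfer_param

mod = paddle.nn.TransformerEncoderLayer(d_model=64, nhead=4, dim_feedforward=256)

# Fused QKV weight, as in the "slf_q_weight" entry above.
qkv_weight = paddle.concat(
    [
        transfer_param(mod.self_attn.q_proj.weight, restore_data=True),
        transfer_param(mod.self_attn.k_proj.weight, restore_data=True),
        transfer_param(mod.self_attn.v_proj.weight, restore_data=True),
    ],
    axis=-1)

# Plain weights and biases: biases get is_bias=True, nothing needs reserve_var.
ffn_inter_weight = transfer_param(mod.linear1.weight, restore_data=True)
ffn_inter_bias = transfer_param(mod.linear1.bias, is_bias=True, restore_data=True)
slf_ln_weight = transfer_param(mod.norm1.weight, restore_data=True)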
