Skip to content

Commit

Permalink
Delete confusing parameter to make FT JIT easy to activate (PaddlePad…
Browse files Browse the repository at this point in the history
…dle#1495)

* fix need build of faster generation

* update

* update
  • Loading branch information
FrostML authored Dec 22, 2021
1 parent 8aa407e commit f631d3e
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 36 deletions.
12 changes: 10 additions & 2 deletions paddlenlp/ops/ext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
# Clear it for other non-CUDA situations.
CUDA_HOME = None

LOADED_EXT = {}


def _get_files(path):
"""
Expand Down Expand Up @@ -221,6 +223,8 @@ def load(name, build_dir=None, force=False, verbose=False, **kwargs):
logger.warning("%s is not available because CUDA can not be found." %
name)
raise NotImplementedError
if name in LOADED_EXT.keys():
return LOADED_EXT[name]
if build_dir is None:
# Maybe under package dir is better to avoid cmake source path conflict
# with different source path.
Expand All @@ -247,7 +251,9 @@ def load(name, build_dir=None, force=False, verbose=False, **kwargs):
ext_sources, ext_filepath, 'newer'):
logger.debug("skipping '%s' extension (up-to-date) build" %
name)
return load_op_meta_info_and_register_op(ext_filepath)
ops = load_op_meta_info_and_register_op(ext_filepath)
LOADED_EXT[name] = ops
return LOADED_EXT[name]

# write setup file and jit compile
file_path = os.path.join(build_dir, "{}_setup.py".format(name))
Expand All @@ -256,7 +262,9 @@ def load(name, build_dir=None, force=False, verbose=False, **kwargs):
if isinstance(extension, CMakeExtension):
# Load a shared library (if exists) only to register op.
if os.path.exists(ext_filepath):
load_op_meta_info_and_register_op(ext_filepath)
ops = load_op_meta_info_and_register_op(ext_filepath)
LOADED_EXT[name] = ops
return LOADED_EXT[name]
else:
# Import as callable python api
return _import_module_from_library(name, build_base_dir, verbose)
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ def do_predict(args):

if args.enable_faster_encoder:
transformer = enable_faster_encoder(
transformer, need_build=False, use_fp16=args.use_fp16_encoder)
transformer, use_fp16=args.use_fp16_encoder)

src_word = generate_src_word(
batch_size=args.infer_batch_size,
Expand Down
8 changes: 5 additions & 3 deletions paddlenlp/ops/faster_transformer/transformer/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from paddle.fluid.layer_helper import LayerHelper
from paddlenlp.transformers import WordEmbedding, PositionalEmbedding, position_encoding_init
from paddlenlp.utils.log import logger
from paddlenlp.ops.ext_utils import load
from paddlenlp.ops.ext_utils import load, LOADED_EXT
from paddlenlp.ops import transfer_param


Expand Down Expand Up @@ -160,8 +160,10 @@ def __init__(self,

if decoder_lib is not None and os.path.isfile(decoder_lib):
# Maybe it has been loaded by `ext_utils.load`
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoder_lib)
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoder_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
if decoder_lib is not None:
logger.warning(
Expand Down
32 changes: 21 additions & 11 deletions paddlenlp/ops/faster_transformer/transformer/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from paddle.fluid.layer_helper import LayerHelper
import paddle

from paddlenlp.ops.ext_utils import load
from paddlenlp.ops.ext_utils import load, LOADED_EXT
from paddlenlp.utils.log import logger


Expand Down Expand Up @@ -610,8 +610,10 @@ def __init__(self,
# raise ValueError("The path to decoding lib is not exist.")
if decoding_lib is not None and os.path.isfile(decoding_lib):
# Maybe it has been loaded by `ext_utils.load`
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
if decoding_lib is not None:
logger.warning(
Expand Down Expand Up @@ -870,8 +872,10 @@ def parse_function(func_name):
class InferGptDecoding(nn.Layer):
def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
if decoding_lib is not None and os.path.isfile(decoding_lib):
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
if decoding_lib is not None:
logger.warning(
Expand Down Expand Up @@ -1078,8 +1082,10 @@ def __init__(self,
hidden_act="gelu"):
if decoding_lib is not None and os.path.isfile(decoding_lib):
# Maybe it has been loaded by `ext_utils.load`
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
if decoding_lib is not None:
logger.warning(
Expand Down Expand Up @@ -1442,8 +1448,10 @@ class InferBartDecoding(nn.Layer):
def __init__(self, model, decoding_lib=None, use_fp16_decoding=False):
if decoding_lib is not None and os.path.isfile(decoding_lib):
# Maybe it has been loaded by `ext_utils.load`
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
if decoding_lib is not None:
logger.warning(
Expand Down Expand Up @@ -1683,8 +1691,10 @@ def __init__(self,
hidden_act="gelu"):
if decoding_lib is not None and os.path.isfile(decoding_lib):
# Maybe it has been loaded by `ext_utils.load`
paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
if decoding_lib is not None:
logger.warning(
Expand Down
33 changes: 16 additions & 17 deletions paddlenlp/ops/faster_transformer/transformer/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,10 +238,7 @@ def encoder_forward(self, src, src_mask=None, cache=None):
return output


def enable_faster_encoder(self,
need_build=True,
use_fp16=False,
encoder_lib=None):
def enable_faster_encoder(self, use_fp16=False, encoder_lib=None):
"""
Compiles fusion encoder operator intergrated FasterTransformer using the
method of JIT(Just-In-Time) and replaces the `forward` function of
Expand Down Expand Up @@ -281,19 +278,21 @@ def init_func(layer):
convert_to_fp16(layer)

if not self.training:
if need_build:
try:
# Pass decoding lib to prevent re-building encoder.
# TODO: check whether the decoding lib contains the encoder or not.
if encoder_lib is not None:
load_op_meta_info_and_register_op(encoder_lib)
else:
load("FasterTransformer", verbose=True)
except Exception:
logger.warning(
"Exception occurs when using FasterEncoder. " \
"The original forward will be involved. ")
return self
try:
# Pass decoding lib to prevent re-building encoder.
# TODO: check whether the decoding lib contains the encoder or not.
if encoder_lib is not None:
if "FasterTransformer" not in LOADED_EXT.keys():
ops = paddle.utils.cpp_extension.load_op_meta_info_and_register_op(
decoding_lib)
LOADED_EXT["FasterTransformer"] = ops
else:
load("FasterTransformer", verbose=True)
except Exception:
logger.warning(
"Exception occurs when using FasterEncoder. " \
"The original forward will be involved. ")
return self
for layer in self.children():
layer.apply(init_func)
return self
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1161,7 +1161,7 @@ def forward(self,
**model_kwargs):

if encoder_output is None:
self.encoder = enable_faster_encoder(self.encoder, need_build=False)
self.encoder = enable_faster_encoder(self.encoder)
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
Expand Down Expand Up @@ -1265,7 +1265,7 @@ def forward(self,

#(gongenlei) Not enable_faster_encoder temporarily
if encoder_output is None:
self.encoder = enable_faster_encoder(self.encoder, need_build=False)
self.encoder = enable_faster_encoder(self.encoder)
assert input_ids is not None, "You have to specify either input_ids or encoder_output."
encoder_output = self.prepare_encoder_decoder_kwargs_for_generation(
input_ids, model_kwargs)["encoder_output"]
Expand Down

0 comments on commit f631d3e

Please sign in to comment.