diff --git a/.flake8 b/.flake8 index 16af86b67baf5..d5bfc9f5be2b3 100644 --- a/.flake8 +++ b/.flake8 @@ -22,3 +22,4 @@ exclude = ./orttraining, # ignore server code for now ./server, +ignore = W503, E203 diff --git a/cgmanifests/generate_cgmanifest.py b/cgmanifests/generate_cgmanifest.py index bb27bfd76893e..a6cdfb1065d15 100644 --- a/cgmanifests/generate_cgmanifest.py +++ b/cgmanifests/generate_cgmanifest.py @@ -15,8 +15,10 @@ registrations = [] -with open(os.path.join(REPO_DIR, 'tools', 'ci_build', 'github', 'linux', 'docker', 'Dockerfile.manylinux2014_cuda11'), - mode="r") as f: +with open( + os.path.join(REPO_DIR, "tools", "ci_build", "github", "linux", "docker", "Dockerfile.manylinux2014_cuda11"), + mode="r", +) as f: for line in f: if not line.strip(): package_name = None @@ -36,15 +38,12 @@ m = re.match(r"(.+?)_DOWNLOAD_URL=(\S+)", line) if m is not None: package_url = m.group(2) - if package_name == 'LIBXCRYPT': - package_url = m.group(2) + "/v" + \ - package_filename + ".tar.gz" - elif package_name == 'CMAKE': - package_url = m.group( - 2) + "/v" + package_filename + "/cmake-" + package_filename + ".tar.gz" + if package_name == "LIBXCRYPT": + package_url = m.group(2) + "/v" + package_filename + ".tar.gz" + elif package_name == "CMAKE": + package_url = m.group(2) + "/v" + package_filename + "/cmake-" + package_filename + ".tar.gz" else: - package_url = m.group(2) + "/" + \ - package_filename + ".tar.gz" + package_url = m.group(2) + "/" + package_filename + ".tar.gz" registration = { "Component": { "Type": "other", @@ -53,7 +52,7 @@ "Version": package_filename.split("-")[-1], "DownloadUrl": package_url, }, - "comments": "manylinux dependency" + "comments": "manylinux dependency", } } registrations.append(registration) @@ -67,14 +66,23 @@ def normalize_path_separators(path): proc = subprocess.run( - ["git", "submodule", "foreach", "--quiet", "--recursive", "{} {} $toplevel/$sm_path".format( - normalize_path_separators(sys.executable), - normalize_path_separators(os.path.join(SCRIPT_DIR, "print_submodule_info.py")))], + [ + "git", + "submodule", + "foreach", + "--quiet", + "--recursive", + "{} {} $toplevel/$sm_path".format( + normalize_path_separators(sys.executable), + normalize_path_separators(os.path.join(SCRIPT_DIR, "print_submodule_info.py")), + ), + ], check=True, cwd=REPO_DIR, stdout=subprocess.PIPE, stderr=subprocess.PIPE, - universal_newlines=True) + universal_newlines=True, +) submodule_lines = proc.stdout.splitlines() @@ -88,7 +96,8 @@ def normalize_path_separators(path): "repositoryUrl": url, }, "comments": "git submodule at {}".format( - normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR))) + normalize_path_separators(os.path.relpath(absolute_path, REPO_DIR)) + ), } } registrations.append(registration) diff --git a/cgmanifests/print_submodule_info.py b/cgmanifests/print_submodule_info.py index e7ff463a442c3..362d168930451 100644 --- a/cgmanifests/print_submodule_info.py +++ b/cgmanifests/print_submodule_info.py @@ -10,19 +10,19 @@ path = sys.argv[1] -proc = subprocess.run(["git", "config", "--get", "remote.origin.url"], - check=True, - cwd=path, - stdout=subprocess.PIPE, - universal_newlines=True) +proc = subprocess.run( + ["git", "config", "--get", "remote.origin.url"], + check=True, + cwd=path, + stdout=subprocess.PIPE, + universal_newlines=True, +) url = proc.stdout.strip() -proc = subprocess.run(["git", "rev-parse", "HEAD"], - check=True, - cwd=path, - stdout=subprocess.PIPE, - universal_newlines=True) +proc = subprocess.run( + ["git", "rev-parse", "HEAD"], check=True, cwd=path, stdout=subprocess.PIPE, universal_newlines=True +) commit = proc.stdout.strip() diff --git a/csharp/testdata/test_input_BFLOAT16.py b/csharp/testdata/test_input_BFLOAT16.py index 862b57d10f1c4..414edd3e67628 100644 --- a/csharp/testdata/test_input_BFLOAT16.py +++ b/csharp/testdata/test_input_BFLOAT16.py @@ -2,28 +2,21 @@ # Licensed under the MIT License. import onnx -from onnx import helper +from onnx import TensorProto, helper from onnx.helper import make_opsetid -from onnx import TensorProto -input_info = helper.make_tensor_value_info('input', TensorProto.BFLOAT16, [1, 5]) -output_info = helper.make_tensor_value_info('output', TensorProto.BFLOAT16, [1, 5]) +input_info = helper.make_tensor_value_info("input", TensorProto.BFLOAT16, [1, 5]) +output_info = helper.make_tensor_value_info("output", TensorProto.BFLOAT16, [1, 5]) # Create a node (NodeProto) - This is based on Pad-11 -node_def = helper.make_node( - 'Identity', # node name - ['input'], # inputs - ['output'] # outputs -) +node_def = helper.make_node("Identity", ["input"], ["output"]) # node name # inputs # outputs -graph_def = helper.make_graph(nodes=[node_def], name='test_types_BLOAT16', - inputs=[input_info], outputs=[output_info]) +graph_def = helper.make_graph(nodes=[node_def], name="test_types_BLOAT16", inputs=[input_info], outputs=[output_info]) -model_def = helper.make_model(graph_def, producer_name='AIInfra', - opset_imports=[make_opsetid('', 13)]) +model_def = helper.make_model(graph_def, producer_name="AIInfra", opset_imports=[make_opsetid("", 13)]) onnx.checker.check_model(model_def) onnx.helper.strip_doc_string(model_def) final_model = onnx.shape_inference.infer_shapes(model_def) onnx.checker.check_model(final_model) -onnx.save(final_model, 'test_types_BFLOAT16.onnx') +onnx.save(final_model, "test_types_BFLOAT16.onnx") diff --git a/csharp/testdata/test_input_FLOAT16.py b/csharp/testdata/test_input_FLOAT16.py index c787cf1c06e6f..34aaa0ea04c8e 100644 --- a/csharp/testdata/test_input_FLOAT16.py +++ b/csharp/testdata/test_input_FLOAT16.py @@ -2,31 +2,28 @@ # Licensed under the MIT License. import onnx -from onnx import helper +from onnx import TensorProto, helper from onnx.helper import make_opsetid -from onnx import TensorProto -input_info = helper.make_tensor_value_info('input', TensorProto.FLOAT16, [1, 5]) -output_info = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 5]) +input_info = helper.make_tensor_value_info("input", TensorProto.FLOAT16, [1, 5]) +output_info = helper.make_tensor_value_info("output", TensorProto.FLOAT16, [1, 5]) # Create a node (NodeProto) - This is based on Pad-11 node_def = helper.make_node( - 'Slice', # node name - ['input'], # inputs - ['output'], # outputs + "Slice", # node name + ["input"], # inputs + ["output"], # outputs axes=[0, 1], # attributes ends=[1, 5], - starts=[0, 0] + starts=[0, 0], ) -graph_def = helper.make_graph(nodes=[node_def], name='test_input_FLOAT16', - inputs=[input_info], outputs=[output_info]) +graph_def = helper.make_graph(nodes=[node_def], name="test_input_FLOAT16", inputs=[input_info], outputs=[output_info]) -model_def = helper.make_model(graph_def, producer_name='AIInfra', - opset_imports=[make_opsetid('', 7)]) +model_def = helper.make_model(graph_def, producer_name="AIInfra", opset_imports=[make_opsetid("", 7)]) onnx.checker.check_model(model_def) onnx.helper.strip_doc_string(model_def) final_model = onnx.shape_inference.infer_shapes(model_def) onnx.checker.check_model(final_model) -onnx.save(final_model, 'test_types_FLOAT16.onnx') +onnx.save(final_model, "test_types_FLOAT16.onnx") diff --git a/docs/python/inference/conf.py b/docs/python/inference/conf.py index 18cbd532ef8e8..b299aa008100f 100644 --- a/docs/python/inference/conf.py +++ b/docs/python/inference/conf.py @@ -6,16 +6,18 @@ # Configuration file for the Sphinx documentation builder. import os -import sys import shutil +import sys + import onnxruntime + # import recommonmark # -- Project information ----------------------------------------------------- -project = 'ONNX Runtime' -copyright = '2018-2021, Microsoft' -author = 'Microsoft' +project = "ONNX Runtime" +copyright = "2018-2021, Microsoft" +author = "Microsoft" version = onnxruntime.__version__ release = version @@ -23,70 +25,72 @@ extensions = [ "alabaster", - 'sphinx.ext.intersphinx', - 'sphinx.ext.imgmath', - 'sphinx.ext.ifconfig', - 'sphinx.ext.viewcode', + "sphinx.ext.intersphinx", + "sphinx.ext.imgmath", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", "sphinx.ext.autodoc", - 'sphinx.ext.githubpages', + "sphinx.ext.githubpages", "sphinx_gallery.gen_gallery", - 'sphinx.ext.graphviz', + "sphinx.ext.graphviz", "pyquickhelper.sphinxext.sphinx_runpython_extension", ] -templates_path = ['_templates'] +templates_path = ["_templates"] source_parsers = { - '.md': 'recommonmark.parser.CommonMarkParser', + ".md": "recommonmark.parser.CommonMarkParser", } -source_suffix = ['.rst'] # , '.md'] +source_suffix = [".rst"] # , '.md'] -master_doc = 'index' +master_doc = "index" language = "en" exclude_patterns = [] -pygments_style = 'default' -autoclass_content = 'both' +pygments_style = "default" +autoclass_content = "both" # -- Options for HTML output ------------------------------------------------- html_theme = "alabaster" html_logo = "ONNX_Runtime_icon.png" -html_static_path = ['_static'] +html_static_path = ["_static"] graphviz_output_format = "svg" # -- Options for intersphinx extension --------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. -intersphinx_mapping = {'https://docs.python.org/': None} +intersphinx_mapping = {"https://docs.python.org/": None} # -- Options for Sphinx Gallery ---------------------------------------------- sphinx_gallery_conf = { - 'examples_dirs': 'examples', - 'gallery_dirs': 'auto_examples', + "examples_dirs": "examples", + "gallery_dirs": "auto_examples", } # -- markdown options ----------------------------------------------------------- md_image_dest = "media" md_link_replace = { - '#onnxruntimesessionoptionsenable-profiling)': '#class-onnxruntimesessionoptions)', + "#onnxruntimesessionoptionsenable-profiling)": "#class-onnxruntimesessionoptions)", } # -- Setup actions ----------------------------------------------------------- + def setup(app): # download examples for the documentation this = os.path.abspath(os.path.dirname(__file__)) dest = os.path.join(this, "model.onnx") if not os.path.exists(dest): import urllib.request - url = 'https://raw.githubusercontent.com/onnx/onnx/master/onnx/backend/test/data/node/test_sigmoid/model.onnx' + + url = "https://raw.githubusercontent.com/onnx/onnx/master/onnx/backend/test/data/node/test_sigmoid/model.onnx" urllib.request.urlretrieve(url, dest) loc = os.path.split(dest)[-1] if not os.path.exists(loc): import shutil + shutil.copy(dest, loc) return app - diff --git a/docs/python/inference/examples/plot_backend.py b/docs/python/inference/examples/plot_backend.py index 68096bb8a682e..b441012e637ae 100644 --- a/docs/python/inference/examples/plot_backend.py +++ b/docs/python/inference/examples/plot_backend.py @@ -15,15 +15,16 @@ of a simple logistic regression model. """ import numpy as np -from onnxruntime import datasets -from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument -import onnxruntime.backend as backend from onnx import load +import onnxruntime.backend as backend + ######################################## # The device depends on how the package was compiled, # GPU or CPU. -from onnxruntime import get_device +from onnxruntime import datasets, get_device +from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument + device = get_device() name = datasets.get_example("logreg_iris.onnx") diff --git a/docs/python/inference/examples/plot_common_errors.py b/docs/python/inference/examples/plot_common_errors.py index b474574c0fdf6..e4b2b1e05a3ec 100644 --- a/docs/python/inference/examples/plot_common_errors.py +++ b/docs/python/inference/examples/plot_common_errors.py @@ -15,9 +15,10 @@ trained on *Iris* datasets. The model takes a vector of dimension 2 and returns a class among three. """ +import numpy + import onnxruntime as rt from onnxruntime.capi.onnxruntime_pybind11_state import InvalidArgument -import numpy from onnxruntime.datasets import get_example example2 = get_example("logreg_iris.onnx") @@ -37,7 +38,7 @@ except Exception as e: print("Unexpected type") print("{0}: {1}".format(type(e), e)) - + ######################### # The model fails to return an output if the name # is misspelled. @@ -76,12 +77,12 @@ # dimension is a multiple of the expected input dimension. for x in [ - numpy.array([1.0, 2.0, 3.0, 4.0], dtype=numpy.float32), - numpy.array([[1.0, 2.0, 3.0, 4.0]], dtype=numpy.float32), - numpy.array([[1.0, 2.0], [3.0, 4.0]], dtype=numpy.float32), - numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32), - numpy.array([[1.0, 2.0, 3.0]], dtype=numpy.float32), - ]: + numpy.array([1.0, 2.0, 3.0, 4.0], dtype=numpy.float32), + numpy.array([[1.0, 2.0, 3.0, 4.0]], dtype=numpy.float32), + numpy.array([[1.0, 2.0], [3.0, 4.0]], dtype=numpy.float32), + numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32), + numpy.array([[1.0, 2.0, 3.0]], dtype=numpy.float32), +]: try: r = sess.run([output_name], {input_name: x}) print("Shape={0} and predicted labels={1}".format(x.shape, r)) @@ -89,12 +90,12 @@ print("ERROR with Shape={0} - {1}".format(x.shape, e)) for x in [ - numpy.array([1.0, 2.0, 3.0, 4.0], dtype=numpy.float32), - numpy.array([[1.0, 2.0, 3.0, 4.0]], dtype=numpy.float32), - numpy.array([[1.0, 2.0], [3.0, 4.0]], dtype=numpy.float32), - numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32), - numpy.array([[1.0, 2.0, 3.0]], dtype=numpy.float32), - ]: + numpy.array([1.0, 2.0, 3.0, 4.0], dtype=numpy.float32), + numpy.array([[1.0, 2.0, 3.0, 4.0]], dtype=numpy.float32), + numpy.array([[1.0, 2.0], [3.0, 4.0]], dtype=numpy.float32), + numpy.array([1.0, 2.0, 3.0], dtype=numpy.float32), + numpy.array([[1.0, 2.0, 3.0]], dtype=numpy.float32), +]: try: r = sess.run(None, {input_name: x}) print("Shape={0} and predicted probabilities={1}".format(x.shape, r[1])) @@ -106,10 +107,10 @@ # is higher than expects but produces a warning. for x in [ - numpy.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=numpy.float32), - numpy.array([[[1.0, 2.0, 3.0]]], dtype=numpy.float32), - numpy.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=numpy.float32), - ]: + numpy.array([[[1.0, 2.0], [3.0, 4.0]]], dtype=numpy.float32), + numpy.array([[[1.0, 2.0, 3.0]]], dtype=numpy.float32), + numpy.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=numpy.float32), +]: try: r = sess.run([output_name], {input_name: x}) print("Shape={0} and predicted labels={1}".format(x.shape, r)) diff --git a/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py b/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py index af1351d0c87ff..3df6d6dfea9bf 100644 --- a/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py +++ b/docs/python/inference/examples/plot_convert_pipeline_vectorizer.py @@ -21,24 +21,25 @@ """ import pandas from sklearn.datasets import load_boston + boston = load_boston() X, y = boston.data, boston.target from sklearn.model_selection import train_test_split + X_train, X_test, y_train, y_test = train_test_split(X, y) -X_train_dict = pandas.DataFrame(X_train[:,1:]).T.to_dict().values() -X_test_dict = pandas.DataFrame(X_test[:,1:]).T.to_dict().values() +X_train_dict = pandas.DataFrame(X_train[:, 1:]).T.to_dict().values() +X_test_dict = pandas.DataFrame(X_test[:, 1:]).T.to_dict().values() #################################### # We create a pipeline. -from sklearn.pipeline import make_pipeline from sklearn.ensemble import GradientBoostingRegressor from sklearn.feature_extraction import DictVectorizer -pipe = make_pipeline( - DictVectorizer(sparse=False), - GradientBoostingRegressor()) - +from sklearn.pipeline import make_pipeline + +pipe = make_pipeline(DictVectorizer(sparse=False), GradientBoostingRegressor()) + pipe.fit(X_train_dict, y_train) #################################### @@ -53,15 +54,15 @@ # Conversion to ONNX format # +++++++++++++++++++++++++ # -# We use module +# We use module # `sklearn-onnx `_ # to convert the model into ONNX format. from skl2onnx import convert_sklearn -from skl2onnx.common.data_types import FloatTensorType, Int64TensorType, DictionaryType, SequenceType +from skl2onnx.common.data_types import DictionaryType, FloatTensorType, Int64TensorType, SequenceType # initial_type = [('float_input', DictionaryType(Int64TensorType([1]), FloatTensorType([])))] -initial_type = [('float_input', DictionaryType(Int64TensorType([1]), FloatTensorType([])))] +initial_type = [("float_input", DictionaryType(Int64TensorType([1]), FloatTensorType([])))] onx = convert_sklearn(pipe, initial_types=initial_type) with open("pipeline_vectorize.onnx", "wb") as f: f.write(onx.SerializeToString()) @@ -75,6 +76,7 @@ sess = rt.InferenceSession("pipeline_vectorize.onnx", providers=rt.get_available_providers()) import numpy + inp, out = sess.get_inputs()[0], sess.get_outputs()[0] print("input name='{}' and shape={} and type={}".format(inp.name, inp.shape, inp.type)) print("output name='{}' and shape={} and type={}".format(out.name, out.shape, out.type)) @@ -100,4 +102,3 @@ ######################### # Very similar. *ONNX Runtime* uses floats instead of doubles, # that explains the small discrepencies. - diff --git a/docs/python/inference/examples/plot_load_and_predict.py b/docs/python/inference/examples/plot_load_and_predict.py index 9bfdc5795758d..44e343163a420 100644 --- a/docs/python/inference/examples/plot_load_and_predict.py +++ b/docs/python/inference/examples/plot_load_and_predict.py @@ -12,8 +12,9 @@ retrieve the definition of its inputs and outputs. """ -import onnxruntime as rt import numpy + +import onnxruntime as rt from onnxruntime.datasets import get_example ######################### @@ -37,7 +38,7 @@ # Let's see the output name and shape. output_name = sess.get_outputs()[0].name -print("output name", output_name) +print("output name", output_name) output_shape = sess.get_outputs()[0].shape print("output shape", output_shape) output_type = sess.get_outputs()[0].type @@ -47,7 +48,8 @@ # Let's compute its outputs (or predictions if it is a machine learned model). import numpy.random -x = numpy.random.random((3,4,5)) + +x = numpy.random.random((3, 4, 5)) x = x.astype(numpy.float32) res = sess.run([output_name], {input_name: x}) print(res) diff --git a/docs/python/inference/examples/plot_metadata.py b/docs/python/inference/examples/plot_metadata.py index 94c45e688f27f..c76f2e8d9fa7f 100644 --- a/docs/python/inference/examples/plot_metadata.py +++ b/docs/python/inference/examples/plot_metadata.py @@ -15,9 +15,11 @@ """ from onnxruntime.datasets import get_example + example = get_example("logreg_iris.onnx") import onnx + model = onnx.load(example) print("doc_string={}".format(model.doc_string)) @@ -32,6 +34,7 @@ # With *ONNX Runtime*: import onnxruntime as rt + sess = rt.InferenceSession(example, providers=rt.get_available_providers()) meta = sess.get_modelmeta() diff --git a/docs/python/inference/examples/plot_pipeline.py b/docs/python/inference/examples/plot_pipeline.py index 0a002f6223e1b..366aadee1c3ae 100644 --- a/docs/python/inference/examples/plot_pipeline.py +++ b/docs/python/inference/examples/plot_pipeline.py @@ -21,12 +21,14 @@ """ from onnxruntime.datasets import get_example + example1 = get_example("mul_1.onnx") import onnx + model = onnx.load(example1) # model is a ModelProto protobuf message -print(model) +print(model) ################################# @@ -39,31 +41,30 @@ from onnx import ModelProto + model = ModelProto() -with open(example1, 'rb') as fid: +with open(example1, "rb") as fid: content = fid.read() model.ParseFromString(content) ################################### # We convert it into a graph. -from onnx.tools.net_drawer import GetPydotGraph, GetOpNodeProducer -pydot_graph = GetPydotGraph(model.graph, name=model.graph.name, rankdir="LR", - node_producer=GetOpNodeProducer("docstring")) +from onnx.tools.net_drawer import GetOpNodeProducer, GetPydotGraph + +pydot_graph = GetPydotGraph( + model.graph, name=model.graph.name, rankdir="LR", node_producer=GetOpNodeProducer("docstring") +) pydot_graph.write_dot("graph.dot") ####################################### # Then into an image import os -os.system('dot -O -Tpng graph.dot') + +os.system("dot -O -Tpng graph.dot") ################################ # Which we display... import matplotlib.pyplot as plt + image = plt.imread("graph.dot.png") plt.imshow(image) - - - - - - diff --git a/docs/python/inference/examples/plot_profiling.py b/docs/python/inference/examples/plot_profiling.py index 402e7b3baee10..3236f954cc052 100644 --- a/docs/python/inference/examples/plot_profiling.py +++ b/docs/python/inference/examples/plot_profiling.py @@ -11,9 +11,10 @@ *ONNX Runtime* can profile the execution of the model. This example shows how to interpret the results. """ +import numpy import onnx + import onnxruntime as rt -import numpy from onnxruntime.datasets import get_example @@ -27,8 +28,6 @@ def change_ir_version(filename, ir_version=6): return model - - ######################### # Let's load a very simple model and compute some prediction. @@ -61,10 +60,9 @@ def change_ir_version(filename, ir_version=6): # The results are stored un a file in JSON format. # Let's see what it contains. import json + with open(prof_file, "r") as f: sess_time = json.load(f) import pprint -pprint.pprint(sess_time) - - +pprint.pprint(sess_time) diff --git a/docs/python/inference/examples/plot_train_convert_predict.py b/docs/python/inference/examples/plot_train_convert_predict.py index 4aa36b3dce25c..b5033b503b3eb 100644 --- a/docs/python/inference/examples/plot_train_convert_predict.py +++ b/docs/python/inference/examples/plot_train_convert_predict.py @@ -22,16 +22,19 @@ """ from sklearn.datasets import load_iris + iris = load_iris() X, y = iris.data, iris.target from sklearn.model_selection import train_test_split + X_train, X_test, y_train, y_test = train_test_split(X, y) #################################### # Then we fit a model. from sklearn.linear_model import LogisticRegression + clr = LogisticRegression() clr.fit(X_train, y_train) @@ -47,14 +50,14 @@ # Conversion to ONNX format # +++++++++++++++++++++++++ # -# We use module +# We use module # `sklearn-onnx `_ # to convert the model into ONNX format. from skl2onnx import convert_sklearn from skl2onnx.common.data_types import FloatTensorType -initial_type = [('float_input', FloatTensorType([None, 4]))] +initial_type = [("float_input", FloatTensorType([None, 4]))] onx = convert_sklearn(clr, initial_types=initial_type) with open("logreg_iris.onnx", "wb") as f: f.write(onx.SerializeToString()) @@ -64,12 +67,11 @@ # its input and output. import onnxruntime as rt + sess = rt.InferenceSession("logreg_iris.onnx", providers=rt.get_available_providers()) -print("input name='{}' and shape={}".format( - sess.get_inputs()[0].name, sess.get_inputs()[0].shape)) -print("output name='{}' and shape={}".format( - sess.get_outputs()[0].name, sess.get_outputs()[0].shape)) +print("input name='{}' and shape={}".format(sess.get_inputs()[0].name, sess.get_inputs()[0].shape)) +print("output name='{}' and shape={}".format(sess.get_outputs()[0].name, sess.get_outputs()[0].shape)) ################################## # We compute the predictions. @@ -78,6 +80,7 @@ label_name = sess.get_outputs()[0].name import numpy + pred_onx = sess.run([label_name], {input_name: X_test.astype(numpy.float32)})[0] print(confusion_matrix(pred, pred_onx)) @@ -97,18 +100,20 @@ ############################# # And then with ONNX Runtime. -# The probabilies appear to be +# The probabilies appear to be prob_name = sess.get_outputs()[1].name prob_rt = sess.run([prob_name], {input_name: X_test.astype(numpy.float32)})[0] import pprint + pprint.pprint(prob_rt[0:3]) ############################### # Let's benchmark. from timeit import Timer + def speed(inst, number=10, repeat=20): timer = Timer(inst, globals=globals()) raw = numpy.array(timer.repeat(repeat, number=number)) @@ -117,6 +122,7 @@ def speed(inst, number=10, repeat=20): print("Average %1.3g min=%1.3g max=%1.3g" % (ave, mi, ma)) return ave + print("Execution time for clr.predict") speed("clr.predict(X_test)") @@ -128,20 +134,24 @@ def speed(inst, number=10, repeat=20): # experiences: the model has to do one prediction at a time # as opposed to a batch of prediction. + def loop(X_test, fct, n=None): nrow = X_test.shape[0] if n is None: n = nrow for i in range(0, n): im = i % nrow - fct(X_test[im: im+1]) + fct(X_test[im : im + 1]) + print("Execution time for clr.predict") speed("loop(X_test, clr.predict, 100)") + def sess_predict(x): return sess.run([label_name], {input_name: x.astype(numpy.float32)})[0] + print("Execution time for sess_predict") speed("loop(X_test, sess_predict, 100)") @@ -151,14 +161,16 @@ def sess_predict(x): print("Execution time for predict_proba") speed("loop(X_test, clr.predict_proba, 100)") + def sess_predict_proba(x): return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0] + print("Execution time for sess_predict_proba") speed("loop(X_test, sess_predict_proba, 100)") ##################################### -# This second comparison is better as +# This second comparison is better as # ONNX Runtime, in this experience, # computes the label and the probabilities # in every case. @@ -169,10 +181,11 @@ def sess_predict_proba(x): # # We first train and save a model in ONNX format. from sklearn.ensemble import RandomForestClassifier + rf = RandomForestClassifier() rf.fit(X_train, y_train) -initial_type = [('float_input', FloatTensorType([1, 4]))] +initial_type = [("float_input", FloatTensorType([1, 4]))] onx = convert_sklearn(rf, initial_types=initial_type) with open("rf_iris.onnx", "wb") as f: f.write(onx.SerializeToString()) @@ -182,9 +195,11 @@ def sess_predict_proba(x): sess = rt.InferenceSession("rf_iris.onnx", providers=rt.get_available_providers()) + def sess_predict_proba_rf(x): return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0] + print("Execution time for predict_proba") speed("loop(X_test, rf.predict_proba, 100)") @@ -196,26 +211,28 @@ def sess_predict_proba_rf(x): measures = [] -for n_trees in range(5, 51, 5): +for n_trees in range(5, 51, 5): print(n_trees) rf = RandomForestClassifier(n_estimators=n_trees) rf.fit(X_train, y_train) - initial_type = [('float_input', FloatTensorType([1, 4]))] + initial_type = [("float_input", FloatTensorType([1, 4]))] onx = convert_sklearn(rf, initial_types=initial_type) with open("rf_iris_%d.onnx" % n_trees, "wb") as f: f.write(onx.SerializeToString()) sess = rt.InferenceSession("rf_iris_%d.onnx" % n_trees, providers=rt.get_available_providers()) + def sess_predict_proba_loop(x): return sess.run([prob_name], {input_name: x.astype(numpy.float32)})[0] + tsk = speed("loop(X_test, rf.predict_proba, 100)", number=5, repeat=5) trt = speed("loop(X_test, sess_predict_proba_loop, 100)", number=5, repeat=5) - measures.append({'n_trees': n_trees, 'sklearn': tsk, 'rt': trt}) + measures.append({"n_trees": n_trees, "sklearn": tsk, "rt": trt}) from pandas import DataFrame + df = DataFrame(measures) ax = df.plot(x="n_trees", y="sklearn", label="scikit-learn", c="blue", logy=True) -df.plot(x="n_trees", y="rt", label="onnxruntime", - ax=ax, c="green", logy=True) +df.plot(x="n_trees", y="rt", label="onnxruntime", ax=ax, c="green", logy=True) ax.set_xlabel("Number of trees") ax.set_ylabel("Prediction time (s)") ax.set_title("Speed comparison between scikit-learn and ONNX Runtime\nFor a random forest on Iris dataset") diff --git a/docs/python/training/conf.py b/docs/python/training/conf.py index 2e3c084bb132c..5dd0f60506433 100644 --- a/docs/python/training/conf.py +++ b/docs/python/training/conf.py @@ -7,30 +7,28 @@ # -- Project information ----------------------------------------------------- -project = 'ORTModule' -copyright = '2018-2021, Microsoft' -author = 'Microsoft' -version = '0.1' # TODO: Should use `onnxruntime.__version__` instead? +project = "ORTModule" +copyright = "2018-2021, Microsoft" +author = "Microsoft" +version = "0.1" # TODO: Should use `onnxruntime.__version__` instead? release = version # -- General configuration --------------------------------------------------- -extensions = ['sphinx.ext.autodoc', - 'sphinx.ext.intersphinx' -] -templates_path = ['_templates'] +extensions = ["sphinx.ext.autodoc", "sphinx.ext.intersphinx"] +templates_path = ["_templates"] exclude_patterns = [] -autoclass_content = 'both' +autoclass_content = "both" # -- Options for HTML output ------------------------------------------------- -html_theme = 'sphinx_rtd_theme' -html_static_path = ['_static'] +html_theme = "sphinx_rtd_theme" +html_static_path = ["_static"] # -- Options for intersphinx extension --------------------------------------- intersphinx_mapping = { - 'python': ('https://docs.python.org/3', None), - 'numpy': ('https://numpy.org/doc/stable', None), - 'torch': ('https://pytorch.org/docs/stable/', None), + "python": ("https://docs.python.org/3", None), + "numpy": ("https://numpy.org/doc/stable", None), + "torch": ("https://pytorch.org/docs/stable/", None), } diff --git a/objectivec/test/testdata/single_add_gen.py b/objectivec/test/testdata/single_add_gen.py index af8722e099c54..15b43726e4a9e 100644 --- a/objectivec/test/testdata/single_add_gen.py +++ b/objectivec/test/testdata/single_add_gen.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( [ # nodes @@ -8,12 +7,13 @@ ], "SingleAdd", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, [1]), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT, [1]), - ]) + helper.make_tensor_value_info("C", TensorProto.FLOAT, [1]), + ], +) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)]) -onnx.save(model, r'single_add.onnx') +onnx.save(model, r"single_add.onnx") diff --git a/onnxruntime/__init__.py b/onnxruntime/__init__.py index 1a5d0a7cf816e..7079abd94da48 100644 --- a/onnxruntime/__init__.py +++ b/onnxruntime/__init__.py @@ -20,11 +20,31 @@ # meaningful messages to the user. # the saved exception is raised after device version validation. try: - from onnxruntime.capi._pybind_state import get_all_providers, get_available_providers, get_device, set_seed, \ - RunOptions, SessionOptions, set_default_logger_severity, enable_telemetry_events, disable_telemetry_events, \ - NodeArg, ModelMetadata, GraphOptimizationLevel, ExecutionMode, ExecutionOrder, SessionIOBinding, \ - OrtAllocatorType, OrtMemType, OrtArenaCfg, OrtMemoryInfo, create_and_register_allocator, OrtSparseFormat, \ - set_default_logger_verbosity + from onnxruntime.capi._pybind_state import ( + ExecutionMode, + ExecutionOrder, + GraphOptimizationLevel, + ModelMetadata, + NodeArg, + OrtAllocatorType, + OrtArenaCfg, + OrtMemoryInfo, + OrtMemType, + OrtSparseFormat, + RunOptions, + SessionIOBinding, + SessionOptions, + create_and_register_allocator, + disable_telemetry_events, + enable_telemetry_events, + get_all_providers, + get_available_providers, + get_device, + set_default_logger_severity, + set_default_logger_verbosity, + set_seed, + ) + import_capi_exception = None except Exception as e: import_capi_exception = e @@ -34,9 +54,13 @@ if import_capi_exception: raise import_capi_exception -from onnxruntime.capi.onnxruntime_inference_collection import InferenceSession, IOBinding, OrtValue, SparseTensor, \ - OrtDevice - +from onnxruntime.capi.onnxruntime_inference_collection import ( + InferenceSession, + IOBinding, + OrtDevice, + OrtValue, + SparseTensor, +) from onnxruntime.capi.training import * # noqa: F403 # TODO: thiagofc: Temporary experimental namespace for new PyTorch front-end @@ -45,7 +69,8 @@ except ImportError: pass -from onnxruntime.capi.onnxruntime_validation import package_name, version, cuda_version +from onnxruntime.capi.onnxruntime_validation import cuda_version, package_name, version + if version: __version__ = version diff --git a/onnxruntime/core/providers/nuphar/scripts/__init__.py b/onnxruntime/core/providers/nuphar/scripts/__init__.py index 1a8a615c070e1..862c45ce31b25 100644 --- a/onnxruntime/core/providers/nuphar/scripts/__init__.py +++ b/onnxruntime/core/providers/nuphar/scripts/__init__.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- diff --git a/onnxruntime/core/providers/nuphar/scripts/create_shared.py b/onnxruntime/core/providers/nuphar/scripts/create_shared.py index a68e1fdc7580b..28b451ca59995 100644 --- a/onnxruntime/core/providers/nuphar/scripts/create_shared.py +++ b/onnxruntime/core/providers/nuphar/scripts/create_shared.py @@ -8,14 +8,16 @@ import subprocess import sys + def is_windows(): return sys.platform.startswith("win") + def gen_md5(filename): if not os.path.exists(filename): return False hash_md5 = hashlib.md5() - BLOCKSIZE = 1024*64 + BLOCKSIZE = 1024 * 64 with open(filename, "rb") as f: buf = f.read(BLOCKSIZE) while len(buf) > 0: @@ -23,54 +25,61 @@ def gen_md5(filename): buf = f.read(BLOCKSIZE) return hash_md5.hexdigest() + def gen_checksum(file_checksum, input_dir): if not file_checksum: return - name = 'ORTInternal_checksum' - with open(os.path.join(input_dir, name + '.cc'), 'w') as checksum_cc: - print('#include ', file=checksum_cc) + name = "ORTInternal_checksum" + with open(os.path.join(input_dir, name + ".cc"), "w") as checksum_cc: + print("#include ", file=checksum_cc) print('static const char model_checksum[] = "' + file_checksum + '";', file=checksum_cc) print('extern "C"', file=checksum_cc) if is_windows(): - print('__declspec(dllexport)', file=checksum_cc) - print('void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) {', file=checksum_cc) - print(' cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;', file=checksum_cc) - print('}', file=checksum_cc) + print("__declspec(dllexport)", file=checksum_cc) + print("void _ORTInternal_GetCheckSum(const char*& cs, size_t& len) {", file=checksum_cc) + print(" cs = model_checksum; len = sizeof(model_checksum)/sizeof(model_checksum[0]) - 1;", file=checksum_cc) + print("}", file=checksum_cc) + def gen_cache_version(input_dir): - name = 'ORTInternal_cache_version' - with open(os.path.join(input_dir, name + '.cc'), 'w') as cache_version_cc: - header_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'NUPHAR_CACHE_VERSION') + name = "ORTInternal_cache_version" + with open(os.path.join(input_dir, name + ".cc"), "w") as cache_version_cc: + header_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "NUPHAR_CACHE_VERSION") print('#include "{}"'.format(header_file), file=cache_version_cc) print('extern "C"', file=cache_version_cc) if is_windows(): - print('__declspec(dllexport)', file=cache_version_cc) - print('const char* _ORTInternal_GetCacheVersion() {', file=cache_version_cc) - print(' return __NUPHAR_CACHE_VERSION__;', file=cache_version_cc) - print('}', file=cache_version_cc) + print("__declspec(dllexport)", file=cache_version_cc) + print("const char* _ORTInternal_GetCacheVersion() {", file=cache_version_cc) + print(" return __NUPHAR_CACHE_VERSION__;", file=cache_version_cc) + print("}", file=cache_version_cc) + def compile_all_cc(path): for f in os.listdir(path): name, ext = os.path.splitext(f) - if ext != '.cc': + if ext != ".cc": continue if is_windows(): - subprocess.run(['cl', '/Fo' + name + '.o', '/c', f], cwd=path, check=True) + subprocess.run(["cl", "/Fo" + name + ".o", "/c", f], cwd=path, check=True) else: - subprocess.run(['g++', '-std=c++14', '-fPIC', '-o', name + '.o', '-c', f], cwd=path, check=True) + subprocess.run(["g++", "-std=c++14", "-fPIC", "-o", name + ".o", "-c", f], cwd=path, check=True) os.remove(os.path.join(path, f)) + def parse_arguments(): parser = argparse.ArgumentParser(description="Offline shared lib creation tool.") # Main arguments - parser.add_argument('--keep_input', action='store_true', help="Keep input files after created so.") - parser.add_argument('--input_dir', help="The input directory that contains obj files.", required=True) - parser.add_argument('--output_name', help="The output so file name.", default='jit.so') - parser.add_argument('--input_model', help="The input model file name to generate checksum into shared lib.", default=None) + parser.add_argument("--keep_input", action="store_true", help="Keep input files after created so.") + parser.add_argument("--input_dir", help="The input directory that contains obj files.", required=True) + parser.add_argument("--output_name", help="The output so file name.", default="jit.so") + parser.add_argument( + "--input_model", help="The input model file name to generate checksum into shared lib.", default=None + ) return parser.parse_args() -if __name__ == '__main__': + +if __name__ == "__main__": args = parse_arguments() if args.input_model: @@ -81,8 +90,8 @@ def parse_arguments(): if is_windows(): # create dllmain - name = 'ORTInternal_dllmain' - with open(os.path.join(args.input_dir, name + '.cc'), 'w') as dllmain_cc: + name = "ORTInternal_dllmain" + with open(os.path.join(args.input_dir, name + ".cc"), "w") as dllmain_cc: print("#include ", file=dllmain_cc) print("BOOL APIENTRY DllMain(HMODULE hModule,", file=dllmain_cc) print(" DWORD ul_reason_for_call,", file=dllmain_cc) @@ -90,12 +99,20 @@ def parse_arguments(): print(" {return TRUE;}", file=dllmain_cc) compile_all_cc(args.input_dir) - objs = [f for f in os.listdir(args.input_dir) if os.path.isfile(os.path.join(args.input_dir, f)) and '.o' == os.path.splitext(f)[1]] + objs = [ + f + for f in os.listdir(args.input_dir) + if os.path.isfile(os.path.join(args.input_dir, f)) and ".o" == os.path.splitext(f)[1] + ] if is_windows(): - subprocess.run(['link', '-dll', '-FORCE:MULTIPLE', '-EXPORT:__tvm_main__', '-out:' + args.output_name, '*.o'], cwd=args.input_dir, check=True) + subprocess.run( + ["link", "-dll", "-FORCE:MULTIPLE", "-EXPORT:__tvm_main__", "-out:" + args.output_name, "*.o"], + cwd=args.input_dir, + check=True, + ) else: - subprocess.run(['g++', '-shared', '-fPIC', '-o', args.output_name] + objs, cwd=args.input_dir, check=True) + subprocess.run(["g++", "-shared", "-fPIC", "-o", args.output_name] + objs, cwd=args.input_dir, check=True) if not args.keep_input: for f in objs: diff --git a/onnxruntime/core/providers/nuphar/scripts/model_editor.py b/onnxruntime/core/providers/nuphar/scripts/model_editor.py index 7ae5fbb9841e4..c4d7d7913ed6c 100644 --- a/onnxruntime/core/providers/nuphar/scripts/model_editor.py +++ b/onnxruntime/core/providers/nuphar/scripts/model_editor.py @@ -3,15 +3,18 @@ # -*- coding: UTF-8 -*- import argparse -from enum import Enum +import copy import warnings +from enum import Enum + import numpy as np -from numpy.testing import assert_array_equal import onnx +from numpy.testing import assert_array_equal from onnx import helper + from onnxruntime.nuphar.node_factory import NodeFactory, ensure_opset from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto -import copy + # trim outputs of LSTM/GRU/RNN if not used or outputed def trim_unused_outputs(node, graph): @@ -22,9 +25,10 @@ def trim_unused_outputs(node, graph): o = node.output[o_idx] use = [n for n in graph.node if o in list(n.input) + graph_outputs] if not use: - trimmed.output[o_idx] = '' + trimmed.output[o_idx] = "" return trimmed + # squeeze init states, and split forward/reverse for bidirectional def handle_init_state(init_state, nf, num_directions, onnx_opset_ver): if not init_state: @@ -32,36 +36,38 @@ def handle_init_state(init_state, nf, num_directions, onnx_opset_ver): if not nf.get_initializer(init_state) is None: return nf.get_initializer(init_state) if num_directions == 2: - split_names = [init_state + '_split_0', init_state + '_split_1'] - nf.make_node('Split', init_state, {'axis':0}, split_names) # [1, batch, hidden] - return [nf.make_node_with_axes('Squeeze', s, [0], onnx_opset_ver) for s in split_names] + split_names = [init_state + "_split_0", init_state + "_split_1"] + nf.make_node("Split", init_state, {"axis": 0}, split_names) # [1, batch, hidden] + return [nf.make_node_with_axes("Squeeze", s, [0], onnx_opset_ver) for s in split_names] else: - return [nf.make_node_with_axes('Squeeze', init_state, [0], onnx_opset_ver)] + return [nf.make_node_with_axes("Squeeze", init_state, [0], onnx_opset_ver)] + # handle some common attributes between LSTM/GRU/RNN def handle_common_attributes(node, default_activations): - direction = NodeFactory.get_attribute(node, 'direction') + direction = NodeFactory.get_attribute(node, "direction") if direction: - direction = str(direction, 'utf-8') + direction = str(direction, "utf-8") else: - direction = 'forward' - num_directions = 2 if direction == 'bidirectional' else 1 + direction = "forward" + num_directions = 2 if direction == "bidirectional" else 1 - activations = NodeFactory.get_attribute(node, 'activations') + activations = NodeFactory.get_attribute(node, "activations") if activations: - activations = [str(x, 'utf-8').lower().capitalize() for x in activations] + activations = [str(x, "utf-8").lower().capitalize() for x in activations] else: activations = default_activations * num_directions - activation_alpha = NodeFactory.get_attribute(node, 'activation_alpha') - activation_beta = NodeFactory.get_attribute(node, 'activation_beta') - clip_threshold = NodeFactory.get_attribute(node, 'clip') + activation_alpha = NodeFactory.get_attribute(node, "activation_alpha") + activation_beta = NodeFactory.get_attribute(node, "activation_beta") + clip_threshold = NodeFactory.get_attribute(node, "clip") # TODO: support these activation attributes assert not activation_alpha assert not activation_beta assert not clip_threshold return direction, num_directions, activations + # get batch_size, and create batch_node if needed def handle_batch_size(X, nf, need_batch_node, onnx_opset_ver): X_vi = nf.get_value_info(X) @@ -70,79 +76,95 @@ def handle_batch_size(X, nf, need_batch_node, onnx_opset_ver): if type(dim) == str and need_batch_node: # only need to create batch_node for symbolic batch_size # otherwise, just use numpy.zeros - X_shape = nf.make_node('Shape', X) + X_shape = nf.make_node("Shape", X) if onnx_opset_ver < 11: - node = nf.make_node('Slice', X_shape, {'axes':[0],'starts':[1],'ends':[2]}) + node = nf.make_node("Slice", X_shape, {"axes": [0], "starts": [1], "ends": [2]}) else: - node = nf.make_node('Slice', [X_shape, - np.asarray([1]), #starts - np.asarray([2]), #ends - np.asarray([0]) #axes - ]) + node = nf.make_node( + "Slice", [X_shape, np.asarray([1]), np.asarray([2]), np.asarray([0])] # starts # ends # axes + ) else: node = None return dim, node + # create default init state with zeros -def default_init_state(X, batch_size, batch_node, hidden_size, nf, postfix=''): +def default_init_state(X, batch_size, batch_node, hidden_size, nf, postfix=""): if batch_node: - shape = nf.make_node('Concat', [batch_node, np.asarray([hidden_size]).astype(np.int64)], {'axis':0}) - return nf.make_node('ConstantOfShape', shape) + shape = nf.make_node("Concat", [batch_node, np.asarray([hidden_size]).astype(np.int64)], {"axis": 0}) + return nf.make_node("ConstantOfShape", shape) else: assert type(batch_size) == int # add default init state to graph input - initializer_name = X + '_zero_init_state' + postfix + initializer_name = X + "_zero_init_state" + postfix initializer_shape = (batch_size, hidden_size) nf.make_value_info(initializer_name, onnx.TensorProto.FLOAT, initializer_shape, NodeFactory.ValueInfoType.input) return nf.make_initializer(np.zeros(initializer_shape, dtype=np.float32), initializer_name) + # declare seq_len_subgraph if needed # note rank-1 for seq_len is to differentiate it from rank-2 states def declare_seq_len_in_subgraph(seq_len, nf_body, prefix, batch_size): if seq_len: - seq_len_subgraph = prefix + '_seq_len_subgraph' - nf_body.make_value_info(seq_len_subgraph, - data_type=onnx.TensorProto.INT32, - shape=(batch_size,), - usage=NodeFactory.ValueInfoType.input) + seq_len_subgraph = prefix + "_seq_len_subgraph" + nf_body.make_value_info( + seq_len_subgraph, + data_type=onnx.TensorProto.INT32, + shape=(batch_size,), + usage=NodeFactory.ValueInfoType.input, + ) else: seq_len_subgraph = None return seq_len_subgraph + # hook subgraph outputs, with condition from seq_len_subgraph -def handle_subgraph_outputs(nf_body, seq_len_subgraph, batch_size, hidden_size, subgraph_output_or_default, onnx_opset_ver): +def handle_subgraph_outputs( + nf_body, seq_len_subgraph, batch_size, hidden_size, subgraph_output_or_default, onnx_opset_ver +): final_subgraph_output = [] if seq_len_subgraph: - seq_len_output = nf_body.make_node('Sub', [seq_len_subgraph, np.asarray([1]).astype(np.int32)]) - nf_body.make_value_info(seq_len_output, - data_type=onnx.TensorProto.INT32, - shape=(batch_size,), - usage=NodeFactory.ValueInfoType.output) + seq_len_output = nf_body.make_node("Sub", [seq_len_subgraph, np.asarray([1]).astype(np.int32)]) + nf_body.make_value_info( + seq_len_output, + data_type=onnx.TensorProto.INT32, + shape=(batch_size,), + usage=NodeFactory.ValueInfoType.output, + ) final_subgraph_output.append(seq_len_output) # since seq_len is rank-1, need to unsqueeze for Where op on rank-2 states - condition = nf_body.make_node_with_axes('Unsqueeze', nf_body.make_node('Greater', [seq_len_subgraph, np.zeros(shape=(), dtype=np.int32)]), [1], onnx_opset_ver) + condition = nf_body.make_node_with_axes( + "Unsqueeze", + nf_body.make_node("Greater", [seq_len_subgraph, np.zeros(shape=(), dtype=np.int32)]), + [1], + onnx_opset_ver, + ) for valid, default in subgraph_output_or_default: - final_subgraph_output.append(nf_body.make_node('Where', [condition, valid, default])) + final_subgraph_output.append(nf_body.make_node("Where", [condition, valid, default])) else: final_subgraph_output.append(None) for valid, default in subgraph_output_or_default: - final_subgraph_output.append(nf_body.make_node('Identity', valid)) + final_subgraph_output.append(nf_body.make_node("Identity", valid)) for subgraph_o in final_subgraph_output[1:]: - nf_body.make_value_info(subgraph_o, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size), - usage=NodeFactory.ValueInfoType.output) + nf_body.make_value_info( + subgraph_o, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, hidden_size), + usage=NodeFactory.ValueInfoType.output, + ) return final_subgraph_output + # unsqueeze/concat for the final outputs from scans, when the LSTM/GRU/RNN node is bidirectional def handle_final_scan_outputs(node, nf, scan_outputs, state_outputs, num_directions, onnx_opset_ver): if num_directions == 2: + def _bidirectional(outputs, axis, hook_output_name): - outputs = [nf.make_node_with_axes('Unsqueeze', x, [axis], onnx_opset_ver) for x in outputs] - nf.make_node('Concat', outputs, {'axis':axis}, output_names=hook_output_name) + outputs = [nf.make_node_with_axes("Unsqueeze", x, [axis], onnx_opset_ver) for x in outputs] + nf.make_node("Concat", outputs, {"axis": axis}, output_names=hook_output_name) if node.output[0]: _bidirectional(scan_outputs, 1, node.output[0]) @@ -150,15 +172,18 @@ def _bidirectional(outputs, axis, hook_output_name): _bidirectional(state_outputs[i_o - 1], 0, node.output[i_o]) else: if node.output[0]: - nf.make_node_with_axes('Unsqueeze', scan_outputs[0], [1], onnx_opset_ver, output_names=node.output[0]) + nf.make_node_with_axes("Unsqueeze", scan_outputs[0], [1], onnx_opset_ver, output_names=node.output[0]) for i_o in range(1, len(node.output)): - nf.make_node_with_axes('Unsqueeze', state_outputs[i_o - 1], [0], onnx_opset_ver, output_names=node.output[i_o]) + nf.make_node_with_axes( + "Unsqueeze", state_outputs[i_o - 1], [0], onnx_opset_ver, output_names=node.output[i_o] + ) + def convert_loop_to_scan(node, out_main_graph, keep_unconvertible_loop_ops): - assert node.op_type == 'Loop' + assert node.op_type == "Loop" # https://github.com/onnx/onnx/blob/master/docs/Operators.md#inputs-2--- - initial_state_names = node.input[2:] # exclude M and cond. + initial_state_names = node.input[2:] # exclude M and cond. loop_subgraph_input_i = node.attribute[0].g.input[0] subgraph_input_names = [] @@ -168,7 +193,7 @@ def convert_loop_to_scan(node, out_main_graph, keep_unconvertible_loop_ops): # Gather ops are to be removed from the subgraph gather_input_nodes = [] for n in node.attribute[0].g.node: - if n.op_type == 'Gather' and n.input[1] == loop_subgraph_input_i.name: + if n.op_type == "Gather" and n.input[1] == loop_subgraph_input_i.name: scan_input_names = [*scan_input_names, n.input[0]] subgraph_input_names = [*subgraph_input_names, n.output[0]] gather_input_nodes = [*gather_input_nodes, n] @@ -178,7 +203,7 @@ def convert_loop_to_scan(node, out_main_graph, keep_unconvertible_loop_ops): if keep_unconvertible_loop_ops: warnings.warn("Model contains a Loop op that cannot be converted to Scan. " + reason) return None - raise RuntimeError("To convert a Loop op to a Scan. " + reason) + raise RuntimeError("To convert a Loop op to a Scan. " + reason) scan_subgraph = copy.deepcopy(node.attribute[0].g) @@ -227,14 +252,15 @@ def convert_loop_to_scan(node, out_main_graph, keep_unconvertible_loop_ops): count = 0 for output_index2 in range(output_index + 1, len(scan_subgraph.output)): if scan_subgraph.output[output_index].name == scan_subgraph.output[output_index2].name: - new_output_name = scan_subgraph.output[output_index].name + '_extend_' + str(count) + new_output_name = scan_subgraph.output[output_index].name + "_extend_" + str(count) count = count + 1 identity_node = helper.make_node( - 'Identity', + "Identity", [scan_subgraph.output[output_index].name], - [new_output_name], - scan_subgraph.output[output_index].name + '_identity') - new_identity_node = scan_subgraph.node.add() + [new_output_name], + scan_subgraph.output[output_index].name + "_identity", + ) + new_identity_node = scan_subgraph.node.add() new_identity_node.CopyFrom(identity_node) scan_subgraph.output[output_index2].name = new_output_name @@ -242,17 +268,17 @@ def convert_loop_to_scan(node, out_main_graph, keep_unconvertible_loop_ops): new_input_names = [*initial_state_names, *scan_input_names] scan_output_names = [o for o in node.output] scan = nf.make_node( - 'Scan', + "Scan", new_input_names, - { - 'body': scan_subgraph, - 'num_scan_inputs': len(scan_input_names)}, - output_names=scan_output_names) + {"body": scan_subgraph, "num_scan_inputs": len(scan_input_names)}, + output_names=scan_output_names, + ) + + return scan - return scan def convert_lstm_to_scan(node, out_main_graph, onnx_opset_ver): - assert node.op_type == 'LSTM' + assert node.op_type == "LSTM" nf = NodeFactory(out_main_graph) with nf.scoped_prefix(node.output[0]) as scoped_prefix: X = node.input[0] @@ -268,10 +294,10 @@ def convert_lstm_to_scan(node, out_main_graph, onnx_opset_ver): # TODO: support peephole assert not PB - direction, num_directions, activations = handle_common_attributes(node, ['Sigmoid', 'Tanh', 'Tanh']) + direction, num_directions, activations = handle_common_attributes(node, ["Sigmoid", "Tanh", "Tanh"]) - hidden_size = NodeFactory.get_attribute(node, 'hidden_size') - input_forget = NodeFactory.get_attribute(node, 'input_forget') + hidden_size = NodeFactory.get_attribute(node, "hidden_size") + input_forget = NodeFactory.get_attribute(node, "input_forget") # TODO: implement input_forget = 1 assert not (input_forget != None and input_forget == 1) @@ -300,51 +326,55 @@ def convert_lstm_to_scan(node, out_main_graph, onnx_opset_ver): # init_c [batch_size, hidden_size] # PB [3*hidden_size] - name_prefix = node.output[0] + '_' + str(direction_index) + '_' + name_prefix = node.output[0] + "_" + str(direction_index) + "_" if InitHa is None: - init_h = default_init_state(X, batch_size, batch_node, hidden_size, nf, '_H') + init_h = default_init_state(X, batch_size, batch_node, hidden_size, nf, "_H") else: init_h = InitHa[direction_index] if InitCa is None: - init_c = default_init_state(X, batch_size, batch_node, hidden_size, nf, '_C') + init_c = default_init_state(X, batch_size, batch_node, hidden_size, nf, "_C") else: init_c = InitCa[direction_index] input_size = Wa.shape[len(Wa.shape) - 1] Wt = np.transpose(Wa[direction_index]) Rt = np.transpose(Ra[direction_index]) - B = Ba[direction_index].reshape(2, 4*hidden_size).sum(axis=0) # [4*hidden_size] - X_proj = nf.make_node('MatMul', [X, Wt]) #[seq_len, batch_size, 4*hidden_size] - X_proj = nf.make_node('Add', [X_proj, B]) + B = Ba[direction_index].reshape(2, 4 * hidden_size).sum(axis=0) # [4*hidden_size] + X_proj = nf.make_node("MatMul", [X, Wt]) # [seq_len, batch_size, 4*hidden_size] + X_proj = nf.make_node("Add", [X_proj, B]) if num_directions == 1: - is_backward = 0 if direction == 'forward' else 1 + is_backward = 0 if direction == "forward" else 1 else: is_backward = direction_index scan_body = onnx.GraphProto() - scan_body.name = name_prefix + '_subgraph' + scan_body.name = name_prefix + "_subgraph" nf_body = NodeFactory(out_main_graph, scan_body) with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix: # subgraph inputs - X_proj_subgraph = X_proj.name + '_subgraph' - prev_h_subgraph = name_prefix + '_h_subgraph' - prev_c_subgraph = name_prefix + '_c_subgraph' + X_proj_subgraph = X_proj.name + "_subgraph" + prev_h_subgraph = name_prefix + "_h_subgraph" + prev_c_subgraph = name_prefix + "_c_subgraph" seq_len_subgraph = declare_seq_len_in_subgraph(seq_len, nf_body, X_proj.name, batch_size) for subgraph_i in [prev_h_subgraph, prev_c_subgraph]: - nf_body.make_value_info(subgraph_i, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size), - usage=NodeFactory.ValueInfoType.input) - - nf_body.make_value_info(X_proj_subgraph, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, 4*hidden_size), - usage=NodeFactory.ValueInfoType.input) + nf_body.make_value_info( + subgraph_i, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, hidden_size), + usage=NodeFactory.ValueInfoType.input, + ) + + nf_body.make_value_info( + X_proj_subgraph, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, 4 * hidden_size), + usage=NodeFactory.ValueInfoType.input, + ) # subgraph nodes # it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) # ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) @@ -352,49 +382,55 @@ def convert_lstm_to_scan(node, out_main_graph, onnx_opset_ver): # Ct = ft (.) Ct-1 + it (.) ct # ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) # Ht = ot (.) h(Ct) - prev_h_proj = nf_body.make_node('MatMul', [prev_h_subgraph, Rt]) - sum_x_proj_h_proj_bias = nf_body.make_node('Add', [X_proj_subgraph, prev_h_proj]) - split_outputs = ['split_i', 'split_o', 'split_f', 'split_c'] - nf_body.make_split_node(sum_x_proj_h_proj_bias, [hidden_size]*4, onnx_opset_ver, {"axis":1}, output_names=split_outputs) + prev_h_proj = nf_body.make_node("MatMul", [prev_h_subgraph, Rt]) + sum_x_proj_h_proj_bias = nf_body.make_node("Add", [X_proj_subgraph, prev_h_proj]) + split_outputs = ["split_i", "split_o", "split_f", "split_c"] + nf_body.make_split_node( + sum_x_proj_h_proj_bias, [hidden_size] * 4, onnx_opset_ver, {"axis": 1}, output_names=split_outputs + ) # manually add shape inference to split outputs for split_o in split_outputs: - nf_body.make_value_info(split_o, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - activation_f, activation_g, activation_h = activations[direction_index*3:(direction_index+1)*3] - it = nf_body.make_node(activation_f, 'split_i') - ft = nf_body.make_node(activation_f, 'split_f') - ct = nf_body.make_node(activation_g, 'split_c') - c_subgraph = nf_body.make_node('Add', - [nf_body.make_node('Mul', [ft, prev_c_subgraph]), - nf_body.make_node('Mul', [it, ct])]) - ot = nf_body.make_node(activation_f, 'split_o') - h_subgraph = nf_body.make_node('Mul', [ot, nf_body.make_node(activation_h, c_subgraph)]) - - subgraph_outputs = handle_subgraph_outputs(nf_body, - seq_len_subgraph, - batch_size, - hidden_size, - [(h_subgraph, prev_h_subgraph), - (c_subgraph, prev_c_subgraph)] + - ([(h_subgraph, np.zeros(shape=(), dtype=np.float32))] if node.output[0] else []), # skip scan output if node.output[0] is empty - onnx_opset_ver) - - scan_attribs = {'body':scan_body, - 'scan_input_directions':[is_backward], - 'num_scan_inputs':1} + nf_body.make_value_info(split_o, data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size)) + activation_f, activation_g, activation_h = activations[direction_index * 3 : (direction_index + 1) * 3] + it = nf_body.make_node(activation_f, "split_i") + ft = nf_body.make_node(activation_f, "split_f") + ct = nf_body.make_node(activation_g, "split_c") + c_subgraph = nf_body.make_node( + "Add", [nf_body.make_node("Mul", [ft, prev_c_subgraph]), nf_body.make_node("Mul", [it, ct])] + ) + ot = nf_body.make_node(activation_f, "split_o") + h_subgraph = nf_body.make_node("Mul", [ot, nf_body.make_node(activation_h, c_subgraph)]) + + subgraph_outputs = handle_subgraph_outputs( + nf_body, + seq_len_subgraph, + batch_size, + hidden_size, + [(h_subgraph, prev_h_subgraph), (c_subgraph, prev_c_subgraph)] + + ( + [(h_subgraph, np.zeros(shape=(), dtype=np.float32))] if node.output[0] else [] + ), # skip scan output if node.output[0] is empty + onnx_opset_ver, + ) + + scan_attribs = {"body": scan_body, "scan_input_directions": [is_backward], "num_scan_inputs": 1} if node.output[0]: - scan_attribs.update({'scan_output_directions':[is_backward]}) - scan = nf.make_node('Scan', ([seq_len] if seq_len else []) + [init_h, init_c, X_proj], - scan_attribs, - output_names=[o.name for o in subgraph_outputs[(0 if seq_len else 1):]]) + scan_attribs.update({"scan_output_directions": [is_backward]}) + scan = nf.make_node( + "Scan", + ([seq_len] if seq_len else []) + [init_h, init_c, X_proj], + scan_attribs, + output_names=[o.name for o in subgraph_outputs[(0 if seq_len else 1) :]], + ) scan_h_outputs.append(subgraph_outputs[1]) scan_c_outputs.append(subgraph_outputs[2]) if node.output[0]: scan_outputs.append(subgraph_outputs[3]) - handle_final_scan_outputs(node, nf, scan_outputs, [scan_h_outputs, scan_c_outputs], num_directions, onnx_opset_ver) + handle_final_scan_outputs( + node, nf, scan_outputs, [scan_h_outputs, scan_c_outputs], num_directions, onnx_opset_ver + ) # remove old initializers nf.remove_initializer(node.input[1]) @@ -407,8 +443,9 @@ def convert_lstm_to_scan(node, out_main_graph, onnx_opset_ver): nf.remove_initializer(node.input[6], allow_empty=True) return True + def convert_gru_to_scan(node, out_main_graph, onnx_opset_ver): - assert node.op_type == 'GRU' + assert node.op_type == "GRU" nf = NodeFactory(out_main_graph) with nf.scoped_prefix(node.output[0]) as scoped_prefix: X = node.input[0] @@ -419,10 +456,10 @@ def convert_gru_to_scan(node, out_main_graph, onnx_opset_ver): seq_len = node.input[4] if num_inputs > 4 else None InitHa = node.input[5] if num_inputs > 5 else None - direction, num_directions, activations = handle_common_attributes(node, ['Sigmoid', 'Tanh']) + direction, num_directions, activations = handle_common_attributes(node, ["Sigmoid", "Tanh"]) - hidden_size = NodeFactory.get_attribute(node, 'hidden_size') - linear_before_reset = NodeFactory.get_attribute(node, 'linear_before_reset') + hidden_size = NodeFactory.get_attribute(node, "hidden_size") + linear_before_reset = NodeFactory.get_attribute(node, "linear_before_reset") InitHa = handle_init_state(InitHa, nf, num_directions, onnx_opset_ver) batch_size, batch_node = handle_batch_size(X, nf, InitHa is None, onnx_opset_ver) @@ -440,7 +477,7 @@ def convert_gru_to_scan(node, out_main_graph, onnx_opset_ver): # seq_len [batch_size] # init_h [batch_size, hidden_size] - name_prefix = node.output[0] + '_' + str(direction_index) + '_' + name_prefix = node.output[0] + "_" + str(direction_index) + "_" if InitHa is None: init_h = zero_init_state @@ -448,39 +485,47 @@ def convert_gru_to_scan(node, out_main_graph, onnx_opset_ver): init_h = InitHa[direction_index] input_size = Wa.shape[len(Wa.shape) - 1] - W_t = np.transpose(Wa[direction_index]) # [input_size, 3*hidden_size] - R_t = np.transpose(Ra[direction_index]) # [hidden_size, 3*hidden_size] - Rzr_t, Rh_t = np.hsplit(R_t, [2*hidden_size]) # [hidden_size, 2*hidden_size] and [hidden_size, hidden_size] - Bzr, Bh = np.hsplit(Ba[direction_index].reshape(2, 3*hidden_size), [2*hidden_size]) - Bzr = Bzr.sum(axis=0) # [2*hidden_size] + W_t = np.transpose(Wa[direction_index]) # [input_size, 3*hidden_size] + R_t = np.transpose(Ra[direction_index]) # [hidden_size, 3*hidden_size] + Rzr_t, Rh_t = np.hsplit( + R_t, [2 * hidden_size] + ) # [hidden_size, 2*hidden_size] and [hidden_size, hidden_size] + Bzr, Bh = np.hsplit(Ba[direction_index].reshape(2, 3 * hidden_size), [2 * hidden_size]) + Bzr = Bzr.sum(axis=0) # [2*hidden_size] Wbh = Bh[0] Rbh = Bh[1] - X_proj = nf.make_node('Add', [nf.make_node('MatMul', [X, W_t]), np.concatenate((Bzr, Wbh))]) #[seq_len, batch_size, 3*hidden_size] + X_proj = nf.make_node( + "Add", [nf.make_node("MatMul", [X, W_t]), np.concatenate((Bzr, Wbh))] + ) # [seq_len, batch_size, 3*hidden_size] if num_directions == 1: - is_backward = 0 if direction == 'forward' else 1 + is_backward = 0 if direction == "forward" else 1 else: is_backward = direction_index scan_body = onnx.GraphProto() - scan_body.name = name_prefix + '_subgraph' + scan_body.name = name_prefix + "_subgraph" nf_body = NodeFactory(out_main_graph, scan_body) with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix: # subgraph inputs - X_proj_subgraph = X_proj.name + '_subgraph' - prev_h_subgraph = name_prefix + '_h_subgraph' + X_proj_subgraph = X_proj.name + "_subgraph" + prev_h_subgraph = name_prefix + "_h_subgraph" seq_len_subgraph = declare_seq_len_in_subgraph(seq_len, nf_body, X_proj.name, batch_size) - nf_body.make_value_info(prev_h_subgraph, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size), - usage=NodeFactory.ValueInfoType.input) + nf_body.make_value_info( + prev_h_subgraph, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, hidden_size), + usage=NodeFactory.ValueInfoType.input, + ) - nf_body.make_value_info(X_proj_subgraph, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, 3*hidden_size), - usage=NodeFactory.ValueInfoType.input) + nf_body.make_value_info( + X_proj_subgraph, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, 3 * hidden_size), + usage=NodeFactory.ValueInfoType.input, + ) # subgraph nodes # zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) @@ -489,70 +534,114 @@ def convert_gru_to_scan(node, out_main_graph, onnx_opset_ver): # ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0 # Ht = (1 - zt) (.) ht + zt (.) Ht-1 - split_X_outputs = ['split_Xzr', 'split_Xh'] - nf_body.make_split_node(X_proj_subgraph, [2*hidden_size, hidden_size], onnx_opset_ver, attributes={"axis":1}, output_names=split_X_outputs) - nf_body.make_value_info('split_Xzr', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, 2*hidden_size)) - nf_body.make_value_info('split_Xh', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - - activation_f, activation_g = activations[direction_index*2:(direction_index+1)*2] + split_X_outputs = ["split_Xzr", "split_Xh"] + nf_body.make_split_node( + X_proj_subgraph, + [2 * hidden_size, hidden_size], + onnx_opset_ver, + attributes={"axis": 1}, + output_names=split_X_outputs, + ) + nf_body.make_value_info( + "split_Xzr", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, 2 * hidden_size) + ) + nf_body.make_value_info("split_Xh", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size)) + + activation_f, activation_g = activations[direction_index * 2 : (direction_index + 1) * 2] if linear_before_reset: - prev_h_proj = nf_body.make_node('Add', [nf_body.make_node('MatMul', [prev_h_subgraph, R_t]), np.concatenate((np.zeros(2*hidden_size).astype(np.float32), Rbh))]) - split_prev_h_outputs = ['split_Hzr', 'split_Hh'] - nf_body.make_split_node(prev_h_proj, [2*hidden_size, hidden_size], onnx_opset_ver, {"axis":1}, output_names=split_prev_h_outputs) - nf_body.make_value_info('split_Hzr', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, 2*hidden_size)) - nf_body.make_value_info('split_Hh', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - ztrt = nf_body.make_node(activation_f, nf_body.make_node('Add', ['split_Hzr', 'split_Xzr'])) - split_ztrt_outputs = ['split_zt', 'split_rt'] - nf_body.make_split_node(ztrt, [hidden_size, hidden_size], onnx_opset_ver, {"axis":1}, output_names=split_ztrt_outputs) - nf_body.make_value_info('split_zt', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - nf_body.make_value_info('split_rt', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - ht = nf_body.make_node(activation_g, nf_body.make_node('Add', [nf_body.make_node('Mul', ['split_rt', 'split_Hh']), 'split_Xh'])) + prev_h_proj = nf_body.make_node( + "Add", + [ + nf_body.make_node("MatMul", [prev_h_subgraph, R_t]), + np.concatenate((np.zeros(2 * hidden_size).astype(np.float32), Rbh)), + ], + ) + split_prev_h_outputs = ["split_Hzr", "split_Hh"] + nf_body.make_split_node( + prev_h_proj, + [2 * hidden_size, hidden_size], + onnx_opset_ver, + {"axis": 1}, + output_names=split_prev_h_outputs, + ) + nf_body.make_value_info( + "split_Hzr", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, 2 * hidden_size) + ) + nf_body.make_value_info( + "split_Hh", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size) + ) + ztrt = nf_body.make_node(activation_f, nf_body.make_node("Add", ["split_Hzr", "split_Xzr"])) + split_ztrt_outputs = ["split_zt", "split_rt"] + nf_body.make_split_node( + ztrt, [hidden_size, hidden_size], onnx_opset_ver, {"axis": 1}, output_names=split_ztrt_outputs + ) + nf_body.make_value_info( + "split_zt", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size) + ) + nf_body.make_value_info( + "split_rt", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size) + ) + ht = nf_body.make_node( + activation_g, + nf_body.make_node("Add", [nf_body.make_node("Mul", ["split_rt", "split_Hh"]), "split_Xh"]), + ) else: - ztrt = nf_body.make_node(activation_f, nf_body.make_node('Add', [nf_body.make_node('MatMul', [prev_h_subgraph, Rzr_t]), 'split_Xzr'])) - split_ztrt_outputs = ['split_zt', 'split_rt'] - nf_body.make_split_node(ztrt, [hidden_size, hidden_size], onnx_opset_ver, {"axis":1}, output_names=split_ztrt_outputs) - nf_body.make_value_info('split_zt', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - nf_body.make_value_info('split_rt', - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size)) - ht = nf_body.make_node(activation_g, nf_body.make_node('Add', [nf_body.make_node('MatMul', [nf_body.make_node('Mul', [prev_h_subgraph, 'split_rt']), Rh_t]), 'split_Xh'])) - - Ht = nf_body.make_node('Add', [nf_body.make_node('Mul', [nf_body.make_node('Sub', [np.asarray([1]).astype(np.float32), - 'split_zt']), - ht]), - nf_body.make_node('Mul', ['split_zt', prev_h_subgraph])]) - - subgraph_outputs = handle_subgraph_outputs(nf_body, - seq_len_subgraph, - batch_size, - hidden_size, - [(Ht, prev_h_subgraph)] + - ([(Ht, np.zeros(shape=(), dtype=np.float32))] if node.output[0] else []), - onnx_opset_ver) - - scan_attribs = {'body':scan_body, - 'scan_input_directions':[is_backward], - 'num_scan_inputs':1} + ztrt = nf_body.make_node( + activation_f, + nf_body.make_node("Add", [nf_body.make_node("MatMul", [prev_h_subgraph, Rzr_t]), "split_Xzr"]), + ) + split_ztrt_outputs = ["split_zt", "split_rt"] + nf_body.make_split_node( + ztrt, [hidden_size, hidden_size], onnx_opset_ver, {"axis": 1}, output_names=split_ztrt_outputs + ) + nf_body.make_value_info( + "split_zt", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size) + ) + nf_body.make_value_info( + "split_rt", data_type=onnx.TensorProto.FLOAT, shape=(batch_size, hidden_size) + ) + ht = nf_body.make_node( + activation_g, + nf_body.make_node( + "Add", + [ + nf_body.make_node( + "MatMul", [nf_body.make_node("Mul", [prev_h_subgraph, "split_rt"]), Rh_t] + ), + "split_Xh", + ], + ), + ) + + Ht = nf_body.make_node( + "Add", + [ + nf_body.make_node( + "Mul", [nf_body.make_node("Sub", [np.asarray([1]).astype(np.float32), "split_zt"]), ht] + ), + nf_body.make_node("Mul", ["split_zt", prev_h_subgraph]), + ], + ) + + subgraph_outputs = handle_subgraph_outputs( + nf_body, + seq_len_subgraph, + batch_size, + hidden_size, + [(Ht, prev_h_subgraph)] + ([(Ht, np.zeros(shape=(), dtype=np.float32))] if node.output[0] else []), + onnx_opset_ver, + ) + + scan_attribs = {"body": scan_body, "scan_input_directions": [is_backward], "num_scan_inputs": 1} if node.output[0]: - scan_attribs.update({'scan_output_directions':[is_backward]}) - scan = nf.make_node('Scan', ([seq_len] if seq_len else []) + [init_h, X_proj], - scan_attribs, - output_names=[o.name for o in subgraph_outputs[(0 if seq_len else 1):]]) + scan_attribs.update({"scan_output_directions": [is_backward]}) + scan = nf.make_node( + "Scan", + ([seq_len] if seq_len else []) + [init_h, X_proj], + scan_attribs, + output_names=[o.name for o in subgraph_outputs[(0 if seq_len else 1) :]], + ) scan_h_outputs.append(subgraph_outputs[1]) if node.output[0]: @@ -569,8 +658,9 @@ def convert_gru_to_scan(node, out_main_graph, onnx_opset_ver): nf.remove_initializer(node.input[5], allow_empty=True) return True + def convert_rnn_to_scan(node, out_main_graph, onnx_opset_ver): - assert node.op_type == 'RNN' + assert node.op_type == "RNN" nf = NodeFactory(out_main_graph) with nf.scoped_prefix(node.output[0]) as scoped_prefix: X = node.input[0] @@ -581,9 +671,9 @@ def convert_rnn_to_scan(node, out_main_graph, onnx_opset_ver): seq_len = node.input[4] if num_inputs > 4 else None InitHa = node.input[5] if num_inputs > 5 else None - direction, num_directions, activations = handle_common_attributes(node, ['Tanh']) + direction, num_directions, activations = handle_common_attributes(node, ["Tanh"]) - hidden_size = NodeFactory.get_attribute(node, 'hidden_size') + hidden_size = NodeFactory.get_attribute(node, "hidden_size") InitHa = handle_init_state(InitHa, nf, num_directions, onnx_opset_ver) @@ -602,7 +692,7 @@ def convert_rnn_to_scan(node, out_main_graph, onnx_opset_ver): # seq_len [batch_size] # init_h [batch_size, hidden_size] - name_prefix = node.output[0] + '_' + str(direction_index) + '_' + name_prefix = node.output[0] + "_" + str(direction_index) + "_" if InitHa is None: init_h = zero_init_state @@ -610,57 +700,66 @@ def convert_rnn_to_scan(node, out_main_graph, onnx_opset_ver): init_h = InitHa[direction_index] input_size = Wa.shape[len(Wa.shape) - 1] - W_t = np.transpose(Wa[direction_index]) # [input_size, hidden_size] - R_t = np.transpose(Ra[direction_index]) # [hidden_size, hidden_size] - B = Ba[direction_index].reshape(2, hidden_size).sum(axis=0) # [hidden_size] - X_proj = nf.make_node('Add', [nf.make_node('MatMul', [X, W_t]), B]) #[seq_len, batch_size, hidden_size] + W_t = np.transpose(Wa[direction_index]) # [input_size, hidden_size] + R_t = np.transpose(Ra[direction_index]) # [hidden_size, hidden_size] + B = Ba[direction_index].reshape(2, hidden_size).sum(axis=0) # [hidden_size] + X_proj = nf.make_node("Add", [nf.make_node("MatMul", [X, W_t]), B]) # [seq_len, batch_size, hidden_size] if num_directions == 1: - is_backward = 0 if direction == 'forward' else 1 + is_backward = 0 if direction == "forward" else 1 else: is_backward = direction_index scan_body = onnx.GraphProto() - scan_body.name = name_prefix + '_subgraph' + scan_body.name = name_prefix + "_subgraph" nf_body = NodeFactory(out_main_graph, scan_body) with nf_body.scoped_prefix(name_prefix) as body_scoped_prefix: # subgraph inputs - X_proj_subgraph = X_proj.name + '_subgraph' - prev_h_subgraph = name_prefix + '_h_subgraph' + X_proj_subgraph = X_proj.name + "_subgraph" + prev_h_subgraph = name_prefix + "_h_subgraph" seq_len_subgraph = declare_seq_len_in_subgraph(seq_len, nf_body, X_proj.name, batch_size) - nf_body.make_value_info(prev_h_subgraph, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size), - usage=NodeFactory.ValueInfoType.input) - - nf_body.make_value_info(X_proj_subgraph, - data_type=onnx.TensorProto.FLOAT, - shape=(batch_size, hidden_size), - usage=NodeFactory.ValueInfoType.input) + nf_body.make_value_info( + prev_h_subgraph, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, hidden_size), + usage=NodeFactory.ValueInfoType.input, + ) + + nf_body.make_value_info( + X_proj_subgraph, + data_type=onnx.TensorProto.FLOAT, + shape=(batch_size, hidden_size), + usage=NodeFactory.ValueInfoType.input, + ) # subgraph nodes # Ht = f(Xt*(W^T) + Ht-1*(R^T) + Wb + Rb) activation_f = activations[direction_index] - Ht = nf_body.make_node(activation_f, nf_body.make_node('Add', [nf_body.make_node('MatMul', [prev_h_subgraph, R_t]), X_proj_subgraph])) - - subgraph_outputs = handle_subgraph_outputs(nf_body, - seq_len_subgraph, - batch_size, - hidden_size, - [(Ht, prev_h_subgraph)] + - ([(Ht, np.zeros(shape=(), dtype=np.float32))] if node.output[0] else []), - onnx_opset_ver) - - scan_attribs = {'body':scan_body, - 'scan_input_directions':[is_backward], - 'num_scan_inputs':1} + Ht = nf_body.make_node( + activation_f, + nf_body.make_node("Add", [nf_body.make_node("MatMul", [prev_h_subgraph, R_t]), X_proj_subgraph]), + ) + + subgraph_outputs = handle_subgraph_outputs( + nf_body, + seq_len_subgraph, + batch_size, + hidden_size, + [(Ht, prev_h_subgraph)] + ([(Ht, np.zeros(shape=(), dtype=np.float32))] if node.output[0] else []), + onnx_opset_ver, + ) + + scan_attribs = {"body": scan_body, "scan_input_directions": [is_backward], "num_scan_inputs": 1} if node.output[0]: - scan_attribs.update({'scan_output_directions':[is_backward]}) - scan = nf.make_node('Scan', ([seq_len] if seq_len else []) + [init_h, X_proj], - scan_attribs, - output_names=[o.name for o in subgraph_outputs[(0 if seq_len else 1):]]) + scan_attribs.update({"scan_output_directions": [is_backward]}) + scan = nf.make_node( + "Scan", + ([seq_len] if seq_len else []) + [init_h, X_proj], + scan_attribs, + output_names=[o.name for o in subgraph_outputs[(0 if seq_len else 1) :]], + ) scan_h_outputs.append(subgraph_outputs[1]) if node.output[0]: @@ -677,25 +776,30 @@ def convert_rnn_to_scan(node, out_main_graph, onnx_opset_ver): nf.remove_initializer(node.input[5]) return True + def convert_loop_to_scan_model(input_model, output_model, keep_unconvertible_loop_ops=None): in_mp = onnx.load(input_model) out_mp = onnx.ModelProto() out_mp.CopyFrom(in_mp) - out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input - onnx_opset_ver = ensure_opset(out_mp, 9) # bump up to ONNX opset 9, which is required for Scan - out_mp.graph.ClearField('node') + out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input + onnx_opset_ver = ensure_opset(out_mp, 9) # bump up to ONNX opset 9, which is required for Scan + out_mp.graph.ClearField("node") cast_node_to_remove = [] loop_cond_initializer_to_remove = [] loop_cond_const_node_to_remove = [] for in_n in in_mp.graph.node: - if in_n.op_type == 'Loop': + if in_n.op_type == "Loop": cast_node = None cond_initializer = None cond_const_node = None for n in in_mp.graph.node: if n.op_type == "Cast" and n.output[0] == in_n.input[1]: - cond_initializers = [initializer for initializer in in_mp.graph.initializer if initializer.name == n.input[0]] - cond_const_nodes = [n_c for n_c in in_mp.graph.node if n_c.op_type == "Constant" and n_c.output[0] == n.input[0]] + cond_initializers = [ + initializer for initializer in in_mp.graph.initializer if initializer.name == n.input[0] + ] + cond_const_nodes = [ + n_c for n_c in in_mp.graph.node if n_c.op_type == "Constant" and n_c.output[0] == n.input[0] + ] if len(cond_initializers) == 1: # TODO: assert the the initializer raw data is not 0 (False) cast_node = n @@ -705,15 +809,15 @@ def convert_loop_to_scan_model(input_model, output_model, keep_unconvertible_loo cast_node = n cond_const_node = cond_const_nodes[0] break - + if cast_node: cast_node_to_remove = [*cast_node_to_remove, cast_node] if cond_initializer: loop_cond_initializer_to_remove = [*loop_cond_initializer_to_remove, cond_initializer] elif cond_const_node: loop_cond_const_node_to_remove = [*loop_cond_const_node_to_remove, cond_const_node] - - # at this point, it looks like that this Loop op can be converted to Scan. + + # at this point, it looks like that this Loop op can be converted to Scan. # however, convert_loop_to_scan may still fail when looking at the Loop's subgraph. scan_op = convert_loop_to_scan(in_n, out_mp.graph, keep_unconvertible_loop_ops) if scan_op: @@ -726,10 +830,9 @@ def convert_loop_to_scan_model(input_model, output_model, keep_unconvertible_loo else: raise RuntimeError("Cannot convert a Loop op to Scan: " + reason) - out_n = out_mp.graph.node.add() out_n.CopyFrom(in_n) - + for cast_n in cast_node_to_remove: out_mp.graph.node.remove(cast_n) for value_info in out_mp.graph.value_info: @@ -749,23 +852,24 @@ def convert_loop_to_scan_model(input_model, output_model, keep_unconvertible_loo onnx.save(out_mp, output_model) + def convert_to_scan_model(input_model, output_model): in_mp = onnx.load(input_model) out_mp = onnx.ModelProto() out_mp.CopyFrom(in_mp) - out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input - onnx_opset_ver = ensure_opset(out_mp, 9) # bump up to ONNX opset 9, which is required for Scan - out_mp.graph.ClearField('node') + out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input + onnx_opset_ver = ensure_opset(out_mp, 9) # bump up to ONNX opset 9, which is required for Scan + out_mp.graph.ClearField("node") for in_n in in_mp.graph.node: - if in_n.op_type in ['LSTM', 'GRU', 'RNN']: + if in_n.op_type in ["LSTM", "GRU", "RNN"]: in_n = trim_unused_outputs(in_n, in_mp.graph) - if in_n.op_type == 'LSTM': + if in_n.op_type == "LSTM": if convert_lstm_to_scan(in_n, out_mp.graph, onnx_opset_ver): continue - if in_n.op_type == 'GRU': + if in_n.op_type == "GRU": if convert_gru_to_scan(in_n, out_mp.graph, onnx_opset_ver): continue - if in_n.op_type == 'RNN': + if in_n.op_type == "RNN": if convert_rnn_to_scan(in_n, out_mp.graph, onnx_opset_ver): continue out_n = out_mp.graph.node.add() @@ -773,13 +877,14 @@ def convert_to_scan_model(input_model, output_model): onnx.save(out_mp, output_model) + def gemm_to_matmul(node, nf, converted_initializers): - assert node.op_type == 'Gemm' + assert node.op_type == "Gemm" - alpha = NodeFactory.get_attribute(node, 'alpha', 1.0) - beta = NodeFactory.get_attribute(node, 'beta', 1.0) - transA = NodeFactory.get_attribute(node, 'transA', 0) - transB = NodeFactory.get_attribute(node, 'transB', 0) + alpha = NodeFactory.get_attribute(node, "alpha", 1.0) + beta = NodeFactory.get_attribute(node, "beta", 1.0) + transA = NodeFactory.get_attribute(node, "transA", 0) + transB = NodeFactory.get_attribute(node, "transB", 0) A = node.input[0] B = node.input[1] @@ -787,9 +892,9 @@ def gemm_to_matmul(node, nf, converted_initializers): with nf.scoped_prefix(node.name) as scoped_prefix: if alpha != 1.0: - alpha_name = node.name + '_Const_alpha' + alpha_name = node.name + "_Const_alpha" nf.make_initializer(np.full((), alpha, dtype=np.float32), alpha_name) - alpha_A = nf.make_node('Mul', [alpha_name, A]) + alpha_A = nf.make_node("Mul", [alpha_name, A]) A = alpha_A.name if transA: @@ -799,13 +904,13 @@ def gemm_to_matmul(node, nf, converted_initializers): A_initializer = nf.get_initializer(A) # A is an initializer if A_initializer is not None: - new_A = A + '_trans' + new_A = A + "_trans" converted_initializers[A] = new_A nf.make_initializer(np.transpose(A_initializer), new_A, in_main_graph=True) nf.remove_initializer(A) A = new_A else: - A = nf.make_node('Transpose', A) + A = nf.make_node("Transpose", A) if transB: if B in converted_initializers: B = converted_initializers[B] @@ -813,31 +918,32 @@ def gemm_to_matmul(node, nf, converted_initializers): B_initializer = nf.get_initializer(B) # B is an initializer if B_initializer is not None: - new_B = B + '_trans' + new_B = B + "_trans" converted_initializers[B] = new_B nf.make_initializer(np.transpose(B_initializer), new_B, in_main_graph=True) nf.remove_initializer(B) B = new_B else: - B = nf.make_node('Transpose', B) + B = nf.make_node("Transpose", B) if len(node.input) != 3 or beta == 0.0: - nf.make_node('MatMul', [A, B], output_names=Y) + nf.make_node("MatMul", [A, B], output_names=Y) else: - AB = nf.make_node('MatMul', [A, B]) + AB = nf.make_node("MatMul", [A, B]) C = node.input[2] if beta != 1.0: - beta_name = node.name + '_Const_beta' + beta_name = node.name + "_Const_beta" nf.make_initializer(np.full((), beta, dtype=np.float32), beta_name) - C = nf.make_node('Mul', [beta_name, C]) - nf.make_node('Add', [AB, C], output_names=Y) + C = nf.make_node("Mul", [beta_name, C]) + nf.make_node("Add", [AB, C], output_names=Y) + def convert_gemm_to_matmul(input_model, output_model): in_mp = onnx.load(input_model) out_mp = onnx.ModelProto() out_mp.CopyFrom(in_mp) - out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input - out_mp.graph.ClearField('node') + out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input + out_mp.graph.ClearField("node") nf = NodeFactory(out_mp.graph) # gemm_to_matmul will generate transposed weights if the corresponding input # comes from initializer. We keep a map between the original and converted @@ -845,20 +951,20 @@ def convert_gemm_to_matmul(input_model, output_model): converted_initializers = {} for in_n in in_mp.graph.node: - if in_n.op_type == 'Gemm': + if in_n.op_type == "Gemm": gemm_to_matmul(in_n, nf, converted_initializers) continue out_n = out_mp.graph.node.add() out_n.CopyFrom(in_n) - if in_n.op_type == 'Scan' or in_n.op_type == 'Loop': - in_subgraph = NodeFactory.get_attribute(in_n, 'body') - out_subgraph = NodeFactory.get_attribute(out_n, 'body') - out_subgraph.ClearField('node') + if in_n.op_type == "Scan" or in_n.op_type == "Loop": + in_subgraph = NodeFactory.get_attribute(in_n, "body") + out_subgraph = NodeFactory.get_attribute(out_n, "body") + out_subgraph.ClearField("node") scan_nf = NodeFactory(out_mp.graph, out_subgraph) for in_sn in in_subgraph.node: - if in_sn.op_type == 'Gemm': + if in_sn.op_type == "Gemm": gemm_to_matmul(in_sn, scan_nf, converted_initializers) continue out_sn = out_subgraph.node.add() @@ -866,6 +972,7 @@ def convert_gemm_to_matmul(input_model, output_model): onnx.save(out_mp, output_model) + # Old models (ir_version < 4) is required to initializers in graph inputs # This is optional for ir_version >= 4 def remove_initializers_from_inputs(input_model, output_model, remain_inputs=[]): @@ -874,27 +981,28 @@ def remove_initializers_from_inputs(input_model, output_model, remain_inputs=[]) def _append_initializer_from_graph(graph): initializers = [i.name for i in graph.initializer] for node in graph.node: - if node.op_type == 'Scan': # currently only handle Scan - subgraph = NodeFactory.get_attribute(node, 'body') + if node.op_type == "Scan": # currently only handle Scan + subgraph = NodeFactory.get_attribute(node, "body") initializers += _append_initializer_from_graph(subgraph) return initializers all_initializer_names = [n for n in _append_initializer_from_graph(mp.graph) if n not in remain_inputs] new_inputs = [vi for vi in mp.graph.input if not vi.name in all_initializer_names] - mp.graph.ClearField('input') + mp.graph.ClearField("input") mp.graph.input.extend(new_inputs) onnx.save(mp, output_model) + def optimize_input_projection(input_model, output_model): in_mp = onnx.load(input_model) out_mp = onnx.ModelProto() out_mp.CopyFrom(in_mp) - out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input - out_mp.graph.ClearField('node') - nf = NodeFactory(out_mp.graph, prefix='opt_inproj_') + out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input + out_mp.graph.ClearField("node") + nf = NodeFactory(out_mp.graph, prefix="opt_inproj_") initializers = dict([(i.name, i) for i in in_mp.graph.initializer]) # first find possible fused SVD and do constant folding on MatMul of initializers - const_matmuls = [n for n in in_mp.graph.node if n.op_type == 'MatMul' and all([i in initializers for i in n.input])] + const_matmuls = [n for n in in_mp.graph.node if n.op_type == "MatMul" and all([i in initializers for i in n.input])] for mm in const_matmuls: lhs = numpy_helper.to_array(initializers[mm.input[0]]) rhs = numpy_helper.to_array(initializers[mm.input[1]]) @@ -906,11 +1014,11 @@ def optimize_input_projection(input_model, output_model): if not [n for n in in_mp.graph.node if n != mm and mm.input[1] in n.input]: nf.remove_initializer(mm.input[1]) - initializers = dict([(i.name,i) for i in out_mp.graph.initializer]) + initializers = dict([(i.name, i) for i in out_mp.graph.initializer]) # remove const_matmul output from graph outputs new_outputs = [i for i in out_mp.graph.output if not [m for m in const_matmuls if m.output[0] == i.name]] - out_mp.graph.ClearField('output') + out_mp.graph.ClearField("output") out_mp.graph.output.extend(new_outputs) for in_n in in_mp.graph.node: @@ -918,9 +1026,9 @@ def optimize_input_projection(input_model, output_model): continue optimize_scan = False - if in_n.op_type == 'Scan': - in_sg = NodeFactory.get_attribute(in_n, 'body') - num_scan_inputs = NodeFactory.get_attribute(in_n, 'num_scan_inputs') + if in_n.op_type == "Scan": + in_sg = NodeFactory.get_attribute(in_n, "body") + num_scan_inputs = NodeFactory.get_attribute(in_n, "num_scan_inputs") # only support 1 scan input if num_scan_inputs == 1: optimize_scan = True @@ -931,21 +1039,21 @@ def optimize_input_projection(input_model, output_model): out_n.CopyFrom(in_n) continue - scan_input_directions = NodeFactory.get_attribute(in_n, 'scan_input_directions') - scan_output_directions = NodeFactory.get_attribute(in_n, 'scan_output_directions') + scan_input_directions = NodeFactory.get_attribute(in_n, "scan_input_directions") + scan_output_directions = NodeFactory.get_attribute(in_n, "scan_output_directions") out_sg = onnx.GraphProto() out_sg.CopyFrom(in_sg) - out_sg.ClearField('node') - nf_subgraph = NodeFactory(out_mp.graph, out_sg, prefix='opt_inproj_sg_' + in_n.name + '_') + out_sg.ClearField("node") + nf_subgraph = NodeFactory(out_mp.graph, out_sg, prefix="opt_inproj_sg_" + in_n.name + "_") new_inputs = list(in_n.input) in_sg_inputs = [i.name for i in in_sg.input] replaced_matmul = None for in_sn in in_sg.node: - if in_sn.op_type == 'Concat' and len(in_sn.input) == 2 and all([i in in_sg_inputs for i in in_sn.input]): + if in_sn.op_type == "Concat" and len(in_sn.input) == 2 and all([i in in_sg_inputs for i in in_sn.input]): # make sure the concat's inputs are scan input and scan state - if NodeFactory.get_attribute(in_sn, 'axis') != len(in_sg.input[-1].type.tensor_type.shape.dim) - 1: - continue # must concat last dim - matmul_node = [nn for nn in in_sg.node if nn.op_type == 'MatMul' and in_sn.output[0] in nn.input] + if NodeFactory.get_attribute(in_sn, "axis") != len(in_sg.input[-1].type.tensor_type.shape.dim) - 1: + continue # must concat last dim + matmul_node = [nn for nn in in_sg.node if nn.op_type == "MatMul" and in_sn.output[0] in nn.input] if not matmul_node: continue replaced_matmul = matmul_node[0] @@ -959,70 +1067,82 @@ def optimize_input_projection(input_model, output_model): hidden_idx = 0 hidden_proj_weights, input_proj_weights = np.vsplit(aa, [aa.shape[-1] - input_size]) # add matmul for input_proj outside of Scan - input_proj = nf.make_node('MatMul', [new_inputs[-1], input_proj_weights]) + input_proj = nf.make_node("MatMul", [new_inputs[-1], input_proj_weights]) input_proj.doc_string = replaced_matmul.doc_string new_inputs[-1] = input_proj.name out_sg.input[-1].type.tensor_type.shape.dim[-1].dim_value = input_proj_weights.shape[-1] # add matmul for hidden_proj inside Scan - hidden_proj = nf_subgraph.make_node('MatMul', [in_sn.input[hidden_idx], hidden_proj_weights]) + hidden_proj = nf_subgraph.make_node("MatMul", [in_sn.input[hidden_idx], hidden_proj_weights]) hidden_proj.doc_string = replaced_matmul.doc_string - nf_subgraph.make_node('Add', [out_sg.input[-1].name, hidden_proj], output_names=replaced_matmul.output[0]) + nf_subgraph.make_node( + "Add", [out_sg.input[-1].name, hidden_proj], output_names=replaced_matmul.output[0] + ) # remove initializer of concat matmul if not [n for n in in_mp.graph.node if n != in_n and replaced_matmul.input[1] in n.input]: nf.remove_initializer(replaced_matmul.input[1]) elif in_sn != replaced_matmul: out_sg.node.add().CopyFrom(in_sn) - scan = nf.make_node('Scan', new_inputs, - {'body':out_sg, - 'scan_input_directions':scan_input_directions, - 'scan_output_directions':scan_output_directions, - 'num_scan_inputs':num_scan_inputs}, - output_names=list(in_n.output)) + scan = nf.make_node( + "Scan", + new_inputs, + { + "body": out_sg, + "scan_input_directions": scan_input_directions, + "scan_output_directions": scan_output_directions, + "num_scan_inputs": num_scan_inputs, + }, + output_names=list(in_n.output), + ) scan.name = in_n.name scan.doc_string = in_n.doc_string onnx.save(out_mp, output_model) + def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--mode', help='The modification mode', - choices=['to_scan', - 'opt_inproj', - 'gemm_to_matmul', - 'remove_initializers_from_inputs', - 'loop_to_scan']) - parser.add_argument('--input', help='The input model file', default=None) - parser.add_argument('--output', help='The output model file', default=None) - parser.add_argument('--keep_unconvertible_loop_ops', help='Whether to keep unconvertible (to Scan) Loops. \ + parser.add_argument( + "--mode", + help="The modification mode", + choices=["to_scan", "opt_inproj", "gemm_to_matmul", "remove_initializers_from_inputs", "loop_to_scan"], + ) + parser.add_argument("--input", help="The input model file", default=None) + parser.add_argument("--output", help="The output model file", default=None) + parser.add_argument( + "--keep_unconvertible_loop_ops", + help="Whether to keep unconvertible (to Scan) Loops. \ If set, model editing will keep unconvertible (to Scan) Loops. \ - If not set, it will fail the editing when there is any Loop that is unconvertible to Scan op.', - default=None, action='store_true') + If not set, it will fail the editing when there is any Loop that is unconvertible to Scan op.", + default=None, + action="store_true", + ) return parser.parse_args() -if __name__ == '__main__': + +if __name__ == "__main__": args = parse_arguments() - print('input model: ' + args.input) - print('output model ' + args.output) - if args.mode == 'to_scan': - print('Convert LSTM/GRU/RNN to Scan...') + print("input model: " + args.input) + print("output model " + args.output) + if args.mode == "to_scan": + print("Convert LSTM/GRU/RNN to Scan...") convert_to_scan_model(args.input, args.output) - elif args.mode == 'gemm_to_matmul': - print('Convert Gemm to MatMul') + elif args.mode == "gemm_to_matmul": + print("Convert Gemm to MatMul") convert_gemm_to_matmul(args.input, args.output) - elif args.mode == 'opt_inproj': - print('Optimize input projection in Scan...') + elif args.mode == "opt_inproj": + print("Optimize input projection in Scan...") optimize_input_projection(args.input, args.output) - elif args.mode == 'remove_initializers_from_inputs': - print('Remove all initializers from input for model with IR version >= 4...') + elif args.mode == "remove_initializers_from_inputs": + print("Remove all initializers from input for model with IR version >= 4...") remove_initializers_from_inputs(args.input, args.output) - elif args.mode == 'loop_to_scan': - print('Convert Loop to Scan') + elif args.mode == "loop_to_scan": + print("Convert Loop to Scan") convert_loop_to_scan_model(args.input, args.output, args.keep_unconvertible_loop_ops) else: - raise NotImplementedError('Unknown mode') - print('Running symbolic shape inference on output model') + raise NotImplementedError("Unknown mode") + print("Running symbolic shape inference on output model") mp = onnx.load(args.output) mp = SymbolicShapeInference.infer_shapes(mp, auto_merge=True) onnx.save(mp, args.output) - print('Done!') + print("Done!") diff --git a/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py b/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py index 4c1f84d34ef62..e0caf3ef57209 100644 --- a/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py +++ b/onnxruntime/core/providers/nuphar/scripts/model_quantizer.py @@ -3,13 +3,16 @@ # -*- coding: UTF-8 -*- import argparse -from enum import Enum import json +from enum import Enum + import numpy as np import onnx from onnx import helper, numpy_helper + from .node_factory import NodeFactory, ensure_opset + class QuantizeConfig: def __init__(self, signed, reserved_bits, type_bits): self.sign_bit_ = 1 if signed else 0 @@ -18,9 +21,9 @@ def __init__(self, signed, reserved_bits, type_bits): @staticmethod def from_dict(qcfg_dict): - return QuantizeConfig(1 if qcfg_dict['QuantizationType'] == 'Signed' else 0, - qcfg_dict['ReservedBit'], - qcfg_dict['QuantizeBit']) + return QuantizeConfig( + 1 if qcfg_dict["QuantizationType"] == "Signed" else 0, qcfg_dict["ReservedBit"], qcfg_dict["QuantizeBit"] + ) def signed(self): return self.sign_bit_ == 1 @@ -47,10 +50,15 @@ def q_type(self): def q_type_bits(self): return self.type_bits_ - def __iter__(self): # need this to make dict for json - return iter([('QuantizeBit', self.type_bits_), - ('QuantizationType', 'Signed' if self.sign_bit_ else 'Unsigned'), - ('ReservedBit', self.reserved_bits_)]) + def __iter__(self): # need this to make dict for json + return iter( + [ + ("QuantizeBit", self.type_bits_), + ("QuantizationType", "Signed" if self.sign_bit_ else "Unsigned"), + ("ReservedBit", self.reserved_bits_), + ] + ) + def parse_custom_attributes(in_node): if in_node.doc_string: @@ -67,40 +75,56 @@ def parse_custom_attributes(in_node): # "ReservedBitOfMatrix":0}} qcfg_str = in_node.doc_string # make sure it's the string we can parse - if 'custom_attributes' in qcfg_str: + if "custom_attributes" in qcfg_str: # some fixes to make it a valid JSON string, when model keys are not string - if qcfg_str[1] == 'c': - qcfg_str = qcfg_str.replace('{', '{"') - qcfg_str = qcfg_str.replace(',', ',"') - qcfg_str = qcfg_str.replace(':', '":') - qcfg_str = qcfg_str.replace('{"}', '{}') - qcfg = json.loads(qcfg_str)['custom_attributes'] + if qcfg_str[1] == "c": + qcfg_str = qcfg_str.replace("{", '{"') + qcfg_str = qcfg_str.replace(",", ',"') + qcfg_str = qcfg_str.replace(":", '":') + qcfg_str = qcfg_str.replace('{"}', "{}") + qcfg = json.loads(qcfg_str)["custom_attributes"] if qcfg: return qcfg return None + def parse_node_description(in_node): if not in_node.doc_string: return None custom_qcfg = parse_custom_attributes(in_node) if custom_qcfg: - assert custom_qcfg['IntermediateBit'] == 32 - assert custom_qcfg['PerRowQuantization'] - assert custom_qcfg['QuantizeBitOfVector'] == custom_qcfg['QuantizeBitOfMatrix'] - qbits = custom_qcfg['QuantizeBitOfVector'] - assert ("Asymmetric" in custom_qcfg['VectorQuantizationType']) == ("Asymmetric" in custom_qcfg['MatrixQuantizationType']) - symmetric = 0 if "Asymmetric" in custom_qcfg['VectorQuantizationType'] else 1 - x_signed = 0 if "Unsigned" in custom_qcfg['VectorQuantizationType'] else 1 - w_signed = 0 if "Unsigned" in custom_qcfg['MatrixQuantizationType'] else 1 - x_reserved_bits = custom_qcfg['ReservedBitOfVector'] - w_reserved_bits = custom_qcfg['ReservedBitOfMatrix'] - return {'W' : dict(QuantizeConfig(signed=w_signed, reserved_bits=w_reserved_bits, type_bits=qbits)), - 'X' : dict(QuantizeConfig(signed=x_signed, reserved_bits=x_reserved_bits, type_bits=qbits)), - 'Symmetric' : symmetric} + assert custom_qcfg["IntermediateBit"] == 32 + assert custom_qcfg["PerRowQuantization"] + assert custom_qcfg["QuantizeBitOfVector"] == custom_qcfg["QuantizeBitOfMatrix"] + qbits = custom_qcfg["QuantizeBitOfVector"] + assert ("Asymmetric" in custom_qcfg["VectorQuantizationType"]) == ( + "Asymmetric" in custom_qcfg["MatrixQuantizationType"] + ) + symmetric = 0 if "Asymmetric" in custom_qcfg["VectorQuantizationType"] else 1 + x_signed = 0 if "Unsigned" in custom_qcfg["VectorQuantizationType"] else 1 + w_signed = 0 if "Unsigned" in custom_qcfg["MatrixQuantizationType"] else 1 + x_reserved_bits = custom_qcfg["ReservedBitOfVector"] + w_reserved_bits = custom_qcfg["ReservedBitOfMatrix"] + return { + "W": dict(QuantizeConfig(signed=w_signed, reserved_bits=w_reserved_bits, type_bits=qbits)), + "X": dict(QuantizeConfig(signed=x_signed, reserved_bits=x_reserved_bits, type_bits=qbits)), + "Symmetric": symmetric, + } return None -def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, quantized_inputs, qcfg_dict, update_qcfg_dict, default_qcfg, onnx_opset_ver): - assert in_node.op_type == 'MatMul' + +def quantize_matmul_2d_with_weight( + in_node, + in_graph, + nf, + converted_weights, + quantized_inputs, + qcfg_dict, + update_qcfg_dict, + default_qcfg, + onnx_opset_ver, +): + assert in_node.op_type == "MatMul" # quantize weight # only handles weight being inputs[1] of MatMul/Gemm node @@ -108,7 +132,7 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua # skip if weights shared by other nodes that's not MatMul # TODO: support GEMM op if needed - other_nodes = [n for n in in_graph.node if n != in_node and fparam_name in n.input and n.op_type != 'MatMul'] + other_nodes = [n for n in in_graph.node if n != in_node and fparam_name in n.input and n.op_type != "MatMul"] if other_nodes: return False @@ -119,12 +143,16 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua if not node_qcfg: if not update_qcfg_dict and qcfg_dict: # when qcfg_dict is readonly, raise warning if qcfg is not found for this node - print("Warning: qcfg is not found for node with output: " + in_node.output[0] + ", fall back to default qcfg.") + print( + "Warning: qcfg is not found for node with output: " + + in_node.output[0] + + ", fall back to default qcfg." + ) node_qcfg = default_qcfg - w_qcfg = QuantizeConfig.from_dict(node_qcfg['W']) - x_qcfg = QuantizeConfig.from_dict(node_qcfg['X']) - symmetric = node_qcfg['Symmetric'] + w_qcfg = QuantizeConfig.from_dict(node_qcfg["W"]) + x_qcfg = QuantizeConfig.from_dict(node_qcfg["X"]) + symmetric = node_qcfg["Symmetric"] # for symmetric quantization, both weight and input should be quantized to signed assert not symmetric or (w_qcfg.signed() and x_qcfg.signed()) @@ -149,32 +177,34 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua else: fmin = np.amin(fparam, axis=0) fmax = np.amax(fparam, axis=0) - fscale = (fmax - fmin)/(2 if w_qcfg.signed() else 1) # signed would be normalized to [-1, 1], and unsigned to [0, 1] + fscale = (fmax - fmin) / ( + 2 if w_qcfg.signed() else 1 + ) # signed would be normalized to [-1, 1], and unsigned to [0, 1] step = fscale / q_range base = (fmax + fmin + step) * 0.5 if w_qcfg.signed() else fmin fparam_norm = np.zeros_like(fparam) - expand_fscale = np.expand_dims(fscale,0) - np.divide((fparam - np.expand_dims(base,0)), expand_fscale, out=fparam_norm, where=expand_fscale!=0) + expand_fscale = np.expand_dims(fscale, 0) + np.divide((fparam - np.expand_dims(base, 0)), expand_fscale, out=fparam_norm, where=expand_fscale != 0) qparam = np.round(fparam_norm * q_range) qparam = np.clip(qparam, w_qcfg.q_min(), w_qcfg.q_max()) qparam_rowsum = np.sum(qparam, axis=0) qparam = qparam.astype(w_qcfg.q_type()) # create new weights in main graph in case other Scans share via converted_weights - nf.make_initializer(step, fparam_name + '_step', in_main_graph=True) - nf.make_initializer(qparam, fparam_name + '_qparam', in_main_graph=True) - step = fparam_name + '_step' - qparam = fparam_name + '_qparam' + nf.make_initializer(step, fparam_name + "_step", in_main_graph=True) + nf.make_initializer(qparam, fparam_name + "_qparam", in_main_graph=True) + step = fparam_name + "_step" + qparam = fparam_name + "_qparam" if symmetric: # no need to compute qparam_rowsum and base for symmetric quantization base = None qparam_rowsum = None else: - nf.make_initializer(base, fparam_name + '_base', in_main_graph=True) - base = fparam_name + '_base' - nf.make_initializer(qparam_rowsum, fparam_name + '_qparam_rowsum', in_main_graph=True) - qparam_rowsum = fparam_name + '_qparam_rowsum' + nf.make_initializer(base, fparam_name + "_base", in_main_graph=True) + base = fparam_name + "_base" + nf.make_initializer(qparam_rowsum, fparam_name + "_qparam_rowsum", in_main_graph=True) + qparam_rowsum = fparam_name + "_qparam_rowsum" converted_weights[fparam_name] = (step, base, qparam_rowsum, qparam, w_qcfg, symmetric) nf.remove_initializer(fparam_name) @@ -183,136 +213,216 @@ def quantize_matmul_2d_with_weight(in_node, in_graph, nf, converted_weights, qua input_dim = nf.get_initializer(qparam).shape[0] X = in_node.input[0] if quantized_inputs is not None: - quantized_inputs_key = '{}_{}_{}'.format(X, symmetric, '|'.join(['{}:{}'.format(k,v) for (k, v) in x_qcfg])) + quantized_inputs_key = "{}_{}_{}".format( + X, symmetric, "|".join(["{}:{}".format(k, v) for (k, v) in x_qcfg]) + ) if quantized_inputs is not None and quantized_inputs_key in quantized_inputs: scale_X, bias_X, Q_X, Q_X_sum_int32 = quantized_inputs[quantized_inputs_key] else: if symmetric: - delta_X = nf.make_node('ReduceMax', nf.make_node('Abs', X), {'axes':[-1]}) # keepdims = 1 - inv_delta_X = nf.make_node('Reciprocal', delta_X) - norm_X = nf.make_node('Mul', [X, inv_delta_X]) + delta_X = nf.make_node("ReduceMax", nf.make_node("Abs", X), {"axes": [-1]}) # keepdims = 1 + inv_delta_X = nf.make_node("Reciprocal", delta_X) + norm_X = nf.make_node("Mul", [X, inv_delta_X]) bias_X = None assert x_qcfg.signed() else: - reduce_max_X = nf.make_node('ReduceMax', X, {'axes':[-1]}) # keepdims = 1 - bias_X = nf.make_node('ReduceMin', X, {'axes':[-1]}) - delta_X = nf.make_node('Sub', [reduce_max_X, bias_X]) - inv_delta_X = nf.make_node('Reciprocal', delta_X) - norm_X = nf.make_node('Mul', [nf.make_node('Sub', [X, bias_X]), inv_delta_X]) - - scale_X = nf.make_node('Mul', [delta_X, np.asarray(1.0 / x_qcfg.q_range()).astype(np.float32)]) - Q_Xf = nf.make_node('Mul', [norm_X, np.asarray(x_qcfg.q_range()).astype(np.float32)]) - Q_Xf = nf.make_node('Add', [Q_Xf, np.asarray(0.5).astype(np.float32)]) - Q_Xf = nf.make_node('Floor', Q_Xf) + reduce_max_X = nf.make_node("ReduceMax", X, {"axes": [-1]}) # keepdims = 1 + bias_X = nf.make_node("ReduceMin", X, {"axes": [-1]}) + delta_X = nf.make_node("Sub", [reduce_max_X, bias_X]) + inv_delta_X = nf.make_node("Reciprocal", delta_X) + norm_X = nf.make_node("Mul", [nf.make_node("Sub", [X, bias_X]), inv_delta_X]) + + scale_X = nf.make_node("Mul", [delta_X, np.asarray(1.0 / x_qcfg.q_range()).astype(np.float32)]) + Q_Xf = nf.make_node("Mul", [norm_X, np.asarray(x_qcfg.q_range()).astype(np.float32)]) + Q_Xf = nf.make_node("Add", [Q_Xf, np.asarray(0.5).astype(np.float32)]) + Q_Xf = nf.make_node("Floor", Q_Xf) if onnx_opset_ver < 11: - Q_Xf = nf.make_node('Clip', Q_Xf, {'max':x_qcfg.q_max(), 'min':x_qcfg.q_min()}) + Q_Xf = nf.make_node("Clip", Q_Xf, {"max": x_qcfg.q_max(), "min": x_qcfg.q_min()}) else: # Clip changed min max to inputs in opset 11 - Q_Xf = nf.make_node('Clip', [Q_Xf, np.asarray(x_qcfg.q_min()).astype(np.float32), np.asarray(x_qcfg.q_max()).astype(np.float32)]) - Q_X = nf.make_node('Cast', Q_Xf, {'to':int({np.uint8 : onnx.TensorProto.UINT8, - np.int8 : onnx.TensorProto.INT8, - np.uint16 : onnx.TensorProto.UINT16, - np.int16 : onnx.TensorProto.INT16}[x_qcfg.q_type()])}) + Q_Xf = nf.make_node( + "Clip", + [ + Q_Xf, + np.asarray(x_qcfg.q_min()).astype(np.float32), + np.asarray(x_qcfg.q_max()).astype(np.float32), + ], + ) + Q_X = nf.make_node( + "Cast", + Q_Xf, + { + "to": int( + { + np.uint8: onnx.TensorProto.UINT8, + np.int8: onnx.TensorProto.INT8, + np.uint16: onnx.TensorProto.UINT16, + np.int16: onnx.TensorProto.INT16, + }[x_qcfg.q_type()] + ) + }, + ) if symmetric: Q_X_sum_int32 = None else: - Q_X_sum_int32 = nf.make_node_with_axes('ReduceSum', nf.make_node('Cast', Q_X, {'to':int(onnx.TensorProto.INT32)}), [-1], onnx_opset_ver) + Q_X_sum_int32 = nf.make_node_with_axes( + "ReduceSum", nf.make_node("Cast", Q_X, {"to": int(onnx.TensorProto.INT32)}), [-1], onnx_opset_ver + ) if quantized_inputs is not None: quantized_inputs[quantized_inputs_key] = (scale_X, bias_X, Q_X, Q_X_sum_int32) # MatMulInteger if x_qcfg.q_type_bits() == 8: - Q_Y = nf.make_node('MatMulInteger', [Q_X, qparam]) + Q_Y = nf.make_node("MatMulInteger", [Q_X, qparam]) else: - Q_Y = nf.make_node('MatMulInteger16', [Q_X, qparam]) + Q_Y = nf.make_node("MatMulInteger16", [Q_X, qparam]) Q_Y.domain = "com.microsoft" # Dequantize Y = in_node.output[0] if symmetric: - nf.make_node('Mul', - [nf.make_node('Mul', [step, scale_X]), - nf.make_node('Cast', Q_Y, {'to': int(onnx.TensorProto.FLOAT)})], - output_names=Y) + nf.make_node( + "Mul", + [nf.make_node("Mul", [step, scale_X]), nf.make_node("Cast", Q_Y, {"to": int(onnx.TensorProto.FLOAT)})], + output_names=Y, + ) else: - o0 = nf.make_node('Mul', [nf.make_node('Mul', [step, scale_X]), - nf.make_node('Cast', Q_Y, {'to': int(onnx.TensorProto.FLOAT)})]) - o1 = nf.make_node('Mul', [nf.make_node('Mul', [step, bias_X]), qparam_rowsum]) - o2 = nf.make_node('Mul', [base, nf.make_node('Mul', [scale_X, nf.make_node('Cast', Q_X_sum_int32, {'to':int(onnx.TensorProto.FLOAT)})])]) - o3 = nf.make_node('Mul', [base, nf.make_node('Mul', [bias_X, np.asarray(float(input_dim)).astype(np.float32)])]) - nf.make_node('Sum', [o3, o2, o1, o0], output_names=Y) + o0 = nf.make_node( + "Mul", + [nf.make_node("Mul", [step, scale_X]), nf.make_node("Cast", Q_Y, {"to": int(onnx.TensorProto.FLOAT)})], + ) + o1 = nf.make_node("Mul", [nf.make_node("Mul", [step, bias_X]), qparam_rowsum]) + o2 = nf.make_node( + "Mul", + [ + base, + nf.make_node( + "Mul", [scale_X, nf.make_node("Cast", Q_X_sum_int32, {"to": int(onnx.TensorProto.FLOAT)})] + ), + ], + ) + o3 = nf.make_node( + "Mul", [base, nf.make_node("Mul", [bias_X, np.asarray(float(input_dim)).astype(np.float32)])] + ) + nf.make_node("Sum", [o3, o2, o1, o0], output_names=Y) if update_qcfg_dict: qcfg_dict[in_node.output[0]] = node_qcfg return True + def upgrade_op(nf, in_n): - if in_n.op_type == 'Slice' and len(in_n.input) == 1: + if in_n.op_type == "Slice" and len(in_n.input) == 1: # convert opset9 Slice to opset10 with nf.scoped_prefix(in_n.name) as scoped_prefix: - slice_inputs = [in_n.input[0], - np.asarray(NodeFactory.get_attribute(in_n,'starts')).astype(np.int64), - np.asarray(NodeFactory.get_attribute(in_n,'ends')).astype(np.int64), - np.asarray(NodeFactory.get_attribute(in_n,'axes')).astype(np.int64)] - nf.make_node('Slice', slice_inputs, output_names=list(in_n.output)) + slice_inputs = [ + in_n.input[0], + np.asarray(NodeFactory.get_attribute(in_n, "starts")).astype(np.int64), + np.asarray(NodeFactory.get_attribute(in_n, "ends")).astype(np.int64), + np.asarray(NodeFactory.get_attribute(in_n, "axes")).astype(np.int64), + ] + nf.make_node("Slice", slice_inputs, output_names=list(in_n.output)) return True - elif in_n.op_type == 'TopK' and len(in_n.input) == 1: + elif in_n.op_type == "TopK" and len(in_n.input) == 1: # convert opset1 TopK to opset10 with nf.scoped_prefix(in_n.name) as scoped_prefix: - topk_inputs = [in_n.input[0], - np.asarray([NodeFactory.get_attribute(in_n,'k')]).astype(np.int64)] - nf.make_node('TopK', topk_inputs, {'axis':NodeFactory.get_attribute(in_n,'axis',-1)}, output_names=list(in_n.output)) + topk_inputs = [in_n.input[0], np.asarray([NodeFactory.get_attribute(in_n, "k")]).astype(np.int64)] + nf.make_node( + "TopK", + topk_inputs, + {"axis": NodeFactory.get_attribute(in_n, "axis", -1)}, + output_names=list(in_n.output), + ) return True else: return False + # quantize matmul to MatMulInteger using asymm uint8 -def convert_matmul_model(input_model, output_model, only_for_scan=False, share_input_quantization=False, preset_str='asymm8_param0_input1', qcfg_json=None, export_qcfg_json=None): - preset_qcfgs = {'asymm8_param0_input1' : {'W' : dict(QuantizeConfig(signed=1, reserved_bits=0, type_bits=8)), - 'X' : dict(QuantizeConfig(signed=0, reserved_bits=1, type_bits=8)), - 'Symmetric' : 0}, - 'symm16_param3_input3' : {'W' : dict(QuantizeConfig(signed=1, reserved_bits=3, type_bits=16)), - 'X' : dict(QuantizeConfig(signed=1, reserved_bits=3, type_bits=16)), - 'Symmetric' : 1}} +def convert_matmul_model( + input_model, + output_model, + only_for_scan=False, + share_input_quantization=False, + preset_str="asymm8_param0_input1", + qcfg_json=None, + export_qcfg_json=None, +): + preset_qcfgs = { + "asymm8_param0_input1": { + "W": dict(QuantizeConfig(signed=1, reserved_bits=0, type_bits=8)), + "X": dict(QuantizeConfig(signed=0, reserved_bits=1, type_bits=8)), + "Symmetric": 0, + }, + "symm16_param3_input3": { + "W": dict(QuantizeConfig(signed=1, reserved_bits=3, type_bits=16)), + "X": dict(QuantizeConfig(signed=1, reserved_bits=3, type_bits=16)), + "Symmetric": 1, + }, + } default_qcfg = preset_qcfgs[preset_str] in_mp = onnx.load(input_model) qcfg_dict = {} if qcfg_json and not export_qcfg_json: - with open(qcfg_json, 'r') as f: + with open(qcfg_json, "r") as f: qcfg_dict = json.load(f) out_mp = onnx.ModelProto() out_mp.CopyFrom(in_mp) - out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input - onnx_opset_ver = ensure_opset(out_mp, 10) # bump up to ONNX opset 10, which is required for MatMulInteger - ensure_opset(out_mp, 1, 'com.microsoft') # add MS domain for MatMulInteger16 - out_mp.graph.ClearField('node') + out_mp.ir_version = 5 # update ir version to avoid requirement of initializer in graph input + onnx_opset_ver = ensure_opset(out_mp, 10) # bump up to ONNX opset 10, which is required for MatMulInteger + ensure_opset(out_mp, 1, "com.microsoft") # add MS domain for MatMulInteger16 + out_mp.graph.ClearField("node") nf = NodeFactory(out_mp.graph) - converted_weights = {} # remember MatMul weights that have been converted, in case of sharing - quantized_inputs = {} if share_input_quantization else None # remember quantized inputs that might be able to share between MatMuls + converted_weights = {} # remember MatMul weights that have been converted, in case of sharing + quantized_inputs = ( + {} if share_input_quantization else None + ) # remember quantized inputs that might be able to share between MatMuls for in_n in in_mp.graph.node: if upgrade_op(nf, in_n): continue - if in_n.op_type == 'MatMul' and not only_for_scan: - if quantize_matmul_2d_with_weight(in_n, in_mp.graph, nf, converted_weights, quantized_inputs, qcfg_dict, export_qcfg_json, default_qcfg, onnx_opset_ver): + if in_n.op_type == "MatMul" and not only_for_scan: + if quantize_matmul_2d_with_weight( + in_n, + in_mp.graph, + nf, + converted_weights, + quantized_inputs, + qcfg_dict, + export_qcfg_json, + default_qcfg, + onnx_opset_ver, + ): continue out_n = out_mp.graph.node.add() out_n.CopyFrom(in_n) - if in_n.op_type == 'Scan' or in_n.op_type == 'Loop': - in_subgraph = NodeFactory.get_attribute(in_n, 'body') - out_subgraph = NodeFactory.get_attribute(out_n, 'body') - out_subgraph.ClearField('node') + if in_n.op_type == "Scan" or in_n.op_type == "Loop": + in_subgraph = NodeFactory.get_attribute(in_n, "body") + out_subgraph = NodeFactory.get_attribute(out_n, "body") + out_subgraph.ClearField("node") scan_nf = NodeFactory(out_mp.graph, out_subgraph) - subgraph_quantized_inputs = {} if share_input_quantization else None # remember quantized inputs that might be able to share between MatMuls + subgraph_quantized_inputs = ( + {} if share_input_quantization else None + ) # remember quantized inputs that might be able to share between MatMuls for in_sn in in_subgraph.node: - if in_sn.op_type == 'MatMul': - if quantize_matmul_2d_with_weight(in_sn, in_subgraph, scan_nf, converted_weights, subgraph_quantized_inputs, qcfg_dict, export_qcfg_json, default_qcfg, onnx_opset_ver): + if in_sn.op_type == "MatMul": + if quantize_matmul_2d_with_weight( + in_sn, + in_subgraph, + scan_nf, + converted_weights, + subgraph_quantized_inputs, + qcfg_dict, + export_qcfg_json, + default_qcfg, + onnx_opset_ver, + ): continue if upgrade_op(scan_nf, in_sn): @@ -323,25 +433,55 @@ def convert_matmul_model(input_model, output_model, only_for_scan=False, share_i onnx.save(out_mp, output_model) if export_qcfg_json: - with open(qcfg_json, 'w') as f: + with open(qcfg_json, "w") as f: f.write(json.dumps(qcfg_dict, indent=2)) + def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, help='The input model file') - parser.add_argument('--output', required=True, help='The output model file') - parser.add_argument('--default_qcfg', help='The preset of quantization of _param_input', choices=['asymm8_param0_input1', 'symm16_param3_input3'], default='asymm8_param0_input1') - parser.add_argument('--qcfg_json', help='The quantization config json file for read or write.', default=None) - parser.add_argument('--export_qcfg_json', help='If set, write default quantization config to qcfg_json file.', action='store_true', default=False) - parser.add_argument('--only_for_scan', help='If set, apply quantization of MatMul only inside scan', action='store_true', default=False) - parser.add_argument('--share_input_quantization', help='If set, allow input quantization to be shared if the same input is used in multiple MatMul', action='store_true', default=False) - return parser.parse_args() - -if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument("--input", required=True, help="The input model file") + parser.add_argument("--output", required=True, help="The output model file") + parser.add_argument( + "--default_qcfg", + help="The preset of quantization of _param_input", + choices=["asymm8_param0_input1", "symm16_param3_input3"], + default="asymm8_param0_input1", + ) + parser.add_argument("--qcfg_json", help="The quantization config json file for read or write.", default=None) + parser.add_argument( + "--export_qcfg_json", + help="If set, write default quantization config to qcfg_json file.", + action="store_true", + default=False, + ) + parser.add_argument( + "--only_for_scan", + help="If set, apply quantization of MatMul only inside scan", + action="store_true", + default=False, + ) + parser.add_argument( + "--share_input_quantization", + help="If set, allow input quantization to be shared if the same input is used in multiple MatMul", + action="store_true", + default=False, + ) + return parser.parse_args() + + +if __name__ == "__main__": args = parse_arguments() - print('input model: ' + args.input) - print('output model ' + args.output) - print('Quantize MatMul to MatMulInteger...') + print("input model: " + args.input) + print("output model " + args.output) + print("Quantize MatMul to MatMulInteger...") assert not args.export_qcfg_json or args.qcfg_json, "--qcfg_json must be specified when --export_qcfg_json is used" - convert_matmul_model(args.input, args.output, args.only_for_scan, args.share_input_quantization, args.default_qcfg, args.qcfg_json, args.export_qcfg_json) - print('Done!') + convert_matmul_model( + args.input, + args.output, + args.only_for_scan, + args.share_input_quantization, + args.default_qcfg, + args.qcfg_json, + args.export_qcfg_json, + ) + print("Done!") diff --git a/onnxruntime/core/providers/nuphar/scripts/model_tools.py b/onnxruntime/core/providers/nuphar/scripts/model_tools.py index 013a7e971fec8..de2bfe7a6b361 100644 --- a/onnxruntime/core/providers/nuphar/scripts/model_tools.py +++ b/onnxruntime/core/providers/nuphar/scripts/model_tools.py @@ -1,21 +1,25 @@ +import argparse +import copy +import os + import numpy as np -from numpy.testing import assert_array_equal -import onnxruntime as ort import onnx +from numpy.testing import assert_array_equal from onnx import helper -from onnxruntime.nuphar.node_factory import ensure_opset + +import onnxruntime as ort +import onnxruntime.tools.onnxruntime_test as ort_test from onnxruntime.nuphar.model_editor import convert_loop_to_scan_model +from onnxruntime.nuphar.node_factory import ensure_opset from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto -import onnxruntime.tools.onnxruntime_test as ort_test -import argparse -import copy -import os + def run_shape_inference(input_model, output_model): in_mp = onnx.load(input_model) in_mp = SymbolicShapeInference.infer_shapes(in_mp, auto_merge=True) onnx.save(in_mp, output_model) + # use this function to make a loop op's output as model output. # it helps to debug data issues when edited model outputs do not match the original model. def extract_loop_outputs_as_model_outputs(model): @@ -29,13 +33,15 @@ def set_op_output_as_model_output(node, graph): break for node in model.graph.node: - if node.op_type == 'Loop': + if node.op_type == "Loop": # for debugging to make scan output as model graph output set_op_output_as_model_output(node, model.graph) + def run_with_ort(model_path, symbolic_dims={}, feeds=None, ort_test_case_dir=None): - _, feeds, outputs = ort_test.run_model(model_path, symbolic_dims=symbolic_dims, - feeds=feeds, override_initializers=False) + _, feeds, outputs = ort_test.run_model( + model_path, symbolic_dims=symbolic_dims, feeds=feeds, override_initializers=False + ) if ort_test_case_dir: model = onnx.load(model_path) @@ -44,61 +50,73 @@ def save_ort_test_case(ort_test_case_dir): if not os.path.exists(ort_test_case_dir): os.makedirs(ort_test_case_dir) - test_data_set_dir = os.path.join(ort_test_case_dir, 'test_data_set_0') + test_data_set_dir = os.path.join(ort_test_case_dir, "test_data_set_0") if not os.path.exists(test_data_set_dir): os.makedirs(test_data_set_dir) - onnx.save(model, os.path.join(ort_test_case_dir, 'model.onnx')) + onnx.save(model, os.path.join(ort_test_case_dir, "model.onnx")) for i, (input_name, input) in enumerate(feeds.items()): - onnx.save_tensor(onnx.numpy_helper.from_array(input, input_name), - os.path.join(test_data_set_dir, 'input_{0}.pb'.format(i))) + onnx.save_tensor( + onnx.numpy_helper.from_array(input, input_name), + os.path.join(test_data_set_dir, "input_{0}.pb".format(i)), + ) output_names = [output.name for output in model.graph.output] output_dict = dict(zip(output_names, outputs)) for i, (output_name, output) in enumerate(output_dict.items()): - onnx.save_tensor(onnx.numpy_helper.from_array(output, output_name), - os.path.join(test_data_set_dir, 'output_{0}.pb'.format(i))) + onnx.save_tensor( + onnx.numpy_helper.from_array(output, output_name), + os.path.join(test_data_set_dir, "output_{0}.pb".format(i)), + ) save_ort_test_case(ort_test_case_dir) return feeds, outputs + def validate_with_ort(input_filename, output_filename, symbolic_dims={}): feeds, loop_output = run_with_ort(input_filename, symbolic_dims=symbolic_dims) _, scan_output = run_with_ort(output_filename, symbolic_dims=symbolic_dims, feeds=feeds) - assert(len(loop_output) == len(scan_output)) + assert len(loop_output) == len(scan_output) for index in range(0, len(loop_output)): assert_array_equal(loop_output[index], scan_output[index]) + def convert_loop_to_scan_and_validate(input_filename, output_filename, symbolic_dims={}): convert_loop_to_scan_model(args.input, args.output) validate_with_ort(args.input, args.output, symbolic_dims=symbolic_dims) + def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--tool', help='what to do', - choices=['run_shape_inference', - 'run_with_ort', - 'validate_with_ort', - 'convert_loop_to_scan_and_validate']) - - parser.add_argument('--input', help='The input model file', default=None) - parser.add_argument('--output', help='The output model file', default=None) - parser.add_argument('--symbolic_dims', default={}, type=lambda s: dict(x.split("=") for x in s.split(",")), - help='Comma separated name=value pairs for any symbolic dimensions in the model input. ' - 'e.g. --symbolic_dims batch=1,seqlen=5. ' - 'If not provided, the value of 1 will be used for all symbolic dimensions.') - parser.add_argument('--ort_test_case_dir', help='ort test case dir', default=None) + parser.add_argument( + "--tool", + help="what to do", + choices=["run_shape_inference", "run_with_ort", "validate_with_ort", "convert_loop_to_scan_and_validate"], + ) + + parser.add_argument("--input", help="The input model file", default=None) + parser.add_argument("--output", help="The output model file", default=None) + parser.add_argument( + "--symbolic_dims", + default={}, + type=lambda s: dict(x.split("=") for x in s.split(",")), + help="Comma separated name=value pairs for any symbolic dimensions in the model input. " + "e.g. --symbolic_dims batch=1,seqlen=5. " + "If not provided, the value of 1 will be used for all symbolic dimensions.", + ) + parser.add_argument("--ort_test_case_dir", help="ort test case dir", default=None) return parser.parse_args() -if __name__ == '__main__': + +if __name__ == "__main__": args = parse_arguments() - if args.tool == 'run_shape_inference': + if args.tool == "run_shape_inference": run_shape_inference(args.input, args.output) - elif args.tool == 'run_with_ort': + elif args.tool == "run_with_ort": run_with_ort(args.input, symbolic_dims=args.symbolic_dims, ort_test_case_dir=args.ort_test_case_dir) - elif args.tool == 'validate_with_ort': + elif args.tool == "validate_with_ort": validate_with_ort(args.input, args.output, symbolic_dims=args.symbolic_dims) - elif args.tool == 'convert_loop_to_scan_and_validate': + elif args.tool == "convert_loop_to_scan_and_validate": convert_loop_to_scan_and_validate(args.input, args.output, symbolic_dims=args.symbolic_dims) diff --git a/onnxruntime/core/providers/nuphar/scripts/node_factory.py b/onnxruntime/core/providers/nuphar/scripts/node_factory.py index eb6160c79974d..4a4fe13e186ed 100644 --- a/onnxruntime/core/providers/nuphar/scripts/node_factory.py +++ b/onnxruntime/core/providers/nuphar/scripts/node_factory.py @@ -1,19 +1,22 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import json +import re + # -*- coding: UTF-8 -*- from enum import Enum -import json + import numpy as np import onnx from onnx import helper, numpy_helper -import re + class NodeFactory: node_count_ = 0 const_count_ = 0 - def __init__(self, main_graph, sub_graph=None, prefix=''): + def __init__(self, main_graph, sub_graph=None, prefix=""): self.graph_ = sub_graph if sub_graph else main_graph self.main_graph_ = main_graph self.name_prefix_ = prefix @@ -91,15 +94,17 @@ def make_value_info(self, node_or_name, data_type, shape=None, usage=None): value_info.CopyFrom(helper.make_tensor_value_info(name, data_type, shape)) - def make_initializer(self, ndarray, name='', in_main_graph=False): + def make_initializer(self, ndarray, name="", in_main_graph=False): new_initializer = (self.main_graph_ if in_main_graph else self.graph_).initializer.add() new_name = name if len(new_name) == 0: already_existed = True while already_existed: - new_name = self.name_prefix_ + '_Const_' + str(NodeFactory.const_count_) + new_name = self.name_prefix_ + "_Const_" + str(NodeFactory.const_count_) NodeFactory.const_count_ = NodeFactory.const_count_ + 1 - already_existed = new_name in [i.name for i in list(self.main_graph_.initializer) + list(self.graph_.initializer)] + already_existed = new_name in [ + i.name for i in list(self.main_graph_.initializer) + list(self.graph_.initializer) + ] new_initializer.CopyFrom(numpy_helper.from_array(ndarray, new_name)) return new_initializer @@ -118,12 +123,12 @@ def make_node(self, op_type, inputs, attributes={}, output_names=None, node=None new_initializer = self.make_initializer(i) input_names.append(new_initializer.name) else: - assert False # unexpected type in input + assert False # unexpected type in input if not node: node = self.graph_.node.add() - name = self.name_prefix_ + op_type + '_' + str(NodeFactory.node_count_) + name = self.name_prefix_ + op_type + "_" + str(NodeFactory.node_count_) NodeFactory.node_count_ = NodeFactory.node_count_ + 1 if not output_names: @@ -134,9 +139,9 @@ def make_node(self, op_type, inputs, attributes={}, output_names=None, node=None # Squeeze/Unsqueeze/ReduceSum changed axes to input[1] in opset 13 def make_node_with_axes(self, op_type, input, axes, onnx_opset_ver, attributes={}, output_names=None): - assert op_type in ['Squeeze', 'Unsqueeze', 'ReduceSum'] + assert op_type in ["Squeeze", "Unsqueeze", "ReduceSum"] if onnx_opset_ver < 13: - attributes.update({'axes':axes}) + attributes.update({"axes": axes}) return self.make_node(op_type, input, attributes=attributes, output_names=output_names) else: axes = np.asarray(axes).astype(np.int64) @@ -149,13 +154,14 @@ def make_node_with_axes(self, op_type, input, axes, onnx_opset_ver, attributes={ # Split changed split to input[1] in opset 13 def make_split_node(self, input, split, onnx_opset_ver, attributes, output_names=None): if onnx_opset_ver < 13: - attributes.update({'split':split}) - return self.make_node('Split', input, attributes=attributes, output_names=output_names) + attributes.update({"split": split}) + return self.make_node("Split", input, attributes=attributes, output_names=output_names) else: split = np.asarray(split).astype(np.int64) - return self.make_node('Split', [input, split], attributes=attributes, output_names=output_names) + return self.make_node("Split", [input, split], attributes=attributes, output_names=output_names) + -def ensure_opset(mp, ver, domains=['onnx', '']): +def ensure_opset(mp, ver, domains=["onnx", ""]): if type(domains) == str: domains = [domains] assert type(domains) == list diff --git a/onnxruntime/core/providers/nuphar/scripts/rnn_benchmark.py b/onnxruntime/core/providers/nuphar/scripts/rnn_benchmark.py index 821a02cbb3dc2..5868ca7ad346a 100644 --- a/onnxruntime/core/providers/nuphar/scripts/rnn_benchmark.py +++ b/onnxruntime/core/providers/nuphar/scripts/rnn_benchmark.py @@ -4,54 +4,107 @@ # -*- coding: UTF-8 -*- import argparse import multiprocessing +import os +from timeit import default_timer as timer + import numpy as np import onnx +from onnx import IR_VERSION, helper, numpy_helper, shape_inference + # use lines below when building ONNX Runtime from source with --enable_pybind -#import sys -#sys.path.append(r'X:\Repos\Lotus\build\Windows\Release\Release') -#sys.path.append('/repos/Lotus/build/Linux/Release') +# import sys +# sys.path.append(r'X:\Repos\Lotus\build\Windows\Release\Release') +# sys.path.append('/repos/Lotus/build/Linux/Release') import onnxruntime -from onnx import helper, numpy_helper -from onnx import shape_inference -from onnx import IR_VERSION -import os -from timeit import default_timer as timer -def generate_model(rnn_type, input_dim, hidden_dim, bidirectional, layers, model_name, batch_one=True, has_seq_len=False, onnx_opset_ver=7): + +def generate_model( + rnn_type, + input_dim, + hidden_dim, + bidirectional, + layers, + model_name, + batch_one=True, + has_seq_len=False, + onnx_opset_ver=7, +): model = onnx.ModelProto() model.ir_version = IR_VERSION - + opset = model.opset_import.add() - opset.domain == 'onnx' + opset.domain == "onnx" opset.version = onnx_opset_ver num_directions = 2 if bidirectional else 1 - X = 'input' - model.graph.input.add().CopyFrom(helper.make_tensor_value_info(X, onnx.TensorProto.FLOAT, ['s', 1 if batch_one else 'b', input_dim])) - model.graph.initializer.add().CopyFrom(numpy_helper.from_array(np.asarray([0, 0, -1], dtype=np.int64), 'shape')) + X = "input" + model.graph.input.add().CopyFrom( + helper.make_tensor_value_info(X, onnx.TensorProto.FLOAT, ["s", 1 if batch_one else "b", input_dim]) + ) + model.graph.initializer.add().CopyFrom(numpy_helper.from_array(np.asarray([0, 0, -1], dtype=np.int64), "shape")) if has_seq_len: - seq_len = 'seq_len' - model.graph.input.add().CopyFrom(helper.make_tensor_value_info(seq_len, onnx.TensorProto.INT32, [1 if batch_one else 'b',])) + seq_len = "seq_len" + model.graph.input.add().CopyFrom( + helper.make_tensor_value_info( + seq_len, + onnx.TensorProto.INT32, + [ + 1 if batch_one else "b", + ], + ) + ) - gates = {'lstm':4, 'gru':3, 'rnn':1}[rnn_type] + gates = {"lstm": 4, "gru": 3, "rnn": 1}[rnn_type] for i in range(layers): - layer_input_dim = (input_dim if i == 0 else hidden_dim * num_directions) - model.graph.initializer.add().CopyFrom(numpy_helper.from_array(np.random.rand(num_directions, gates*hidden_dim, layer_input_dim).astype(np.float32), 'W'+str(i))) - model.graph.initializer.add().CopyFrom(numpy_helper.from_array(np.random.rand(num_directions, gates*hidden_dim, hidden_dim).astype(np.float32), 'R'+str(i))) - model.graph.initializer.add().CopyFrom(numpy_helper.from_array(np.random.rand(num_directions, 2*gates*hidden_dim).astype(np.float32), 'B'+str(i))) - layer_inputs = [X, 'W'+str(i), 'R'+str(i), 'B'+str(i)] + layer_input_dim = input_dim if i == 0 else hidden_dim * num_directions + model.graph.initializer.add().CopyFrom( + numpy_helper.from_array( + np.random.rand(num_directions, gates * hidden_dim, layer_input_dim).astype(np.float32), "W" + str(i) + ) + ) + model.graph.initializer.add().CopyFrom( + numpy_helper.from_array( + np.random.rand(num_directions, gates * hidden_dim, hidden_dim).astype(np.float32), "R" + str(i) + ) + ) + model.graph.initializer.add().CopyFrom( + numpy_helper.from_array( + np.random.rand(num_directions, 2 * gates * hidden_dim).astype(np.float32), "B" + str(i) + ) + ) + layer_inputs = [X, "W" + str(i), "R" + str(i), "B" + str(i)] if has_seq_len: layer_inputs += [seq_len] - layer_outputs = ['layer_output_'+str(i)] - model.graph.node.add().CopyFrom(helper.make_node(rnn_type.upper(), layer_inputs, layer_outputs, rnn_type+str(i), hidden_size=hidden_dim, direction='bidirectional' if bidirectional else 'forward')) - model.graph.node.add().CopyFrom(helper.make_node('Transpose', layer_outputs, ['transposed_output_'+str(i)], 'transpose'+str(i), perm=[0,2,1,3])) - model.graph.node.add().CopyFrom(helper.make_node('Reshape', ['transposed_output_'+str(i), 'shape'], ['reshaped_output_'+str(i)], 'reshape'+str(i))) - X = 'reshaped_output_'+str(i) - model.graph.output.add().CopyFrom(helper.make_tensor_value_info(X, onnx.TensorProto.FLOAT, ['s', 'b', hidden_dim * num_directions])) + layer_outputs = ["layer_output_" + str(i)] + model.graph.node.add().CopyFrom( + helper.make_node( + rnn_type.upper(), + layer_inputs, + layer_outputs, + rnn_type + str(i), + hidden_size=hidden_dim, + direction="bidirectional" if bidirectional else "forward", + ) + ) + model.graph.node.add().CopyFrom( + helper.make_node( + "Transpose", layer_outputs, ["transposed_output_" + str(i)], "transpose" + str(i), perm=[0, 2, 1, 3] + ) + ) + model.graph.node.add().CopyFrom( + helper.make_node( + "Reshape", ["transposed_output_" + str(i), "shape"], ["reshaped_output_" + str(i)], "reshape" + str(i) + ) + ) + X = "reshaped_output_" + str(i) + model.graph.output.add().CopyFrom( + helper.make_tensor_value_info(X, onnx.TensorProto.FLOAT, ["s", "b", hidden_dim * num_directions]) + ) model = shape_inference.infer_shapes(model) onnx.save(model, model_name) + def perf_run(sess, feeds, min_counts=5, min_duration_seconds=10): # warm up sess.run([], feeds) @@ -70,19 +123,23 @@ def perf_run(sess, feeds, min_counts=5, min_duration_seconds=10): run = False return count, (end - start), per_iter_cost + def top_n_avg(per_iter_cost, n): # following the perf test methodology in [timeit](https://docs.python.org/3/library/timeit.html#timeit.Timer.repeat) per_iter_cost.sort() return sum(per_iter_cost[:n]) * 1000 / n + def get_num_threads(): - return os.environ['OMP_NUM_THREADS'] if 'OMP_NUM_THREADS' in os.environ else None + return os.environ["OMP_NUM_THREADS"] if "OMP_NUM_THREADS" in os.environ else None + def set_num_threads(num_threads): if num_threads: - os.environ['OMP_NUM_THREADS'] = str(num_threads) + os.environ["OMP_NUM_THREADS"] = str(num_threads) else: - del os.environ['OMP_NUM_THREADS'] + del os.environ["OMP_NUM_THREADS"] + class ScopedSetNumThreads: def __init__(self, num_threads): @@ -95,117 +152,222 @@ def __enter__(self): def __exit__(self, type, value, tb): set_num_threads(self.saved_num_threads_) -def perf_test(rnn_type, num_threads, input_dim, hidden_dim, bidirectional, layers, seq_len, batch_size, top_n=5, min_duration_seconds=10): - model_name = '{}_i{}_h{}_{}_l{}_{}.onnx'.format(rnn_type, input_dim, hidden_dim, - 'bi' if bidirectional else '', - layers, - 'batched' if batch_size > 1 else 'no_batch') + +def perf_test( + rnn_type, + num_threads, + input_dim, + hidden_dim, + bidirectional, + layers, + seq_len, + batch_size, + top_n=5, + min_duration_seconds=10, +): + model_name = "{}_i{}_h{}_{}_l{}_{}.onnx".format( + rnn_type, + input_dim, + hidden_dim, + "bi" if bidirectional else "", + layers, + "batched" if batch_size > 1 else "no_batch", + ) generate_model(rnn_type, input_dim, hidden_dim, bidirectional, layers, model_name, batch_size == 1) - feeds = {'input':np.random.rand(seq_len, batch_size, input_dim).astype(np.float32)} + feeds = {"input": np.random.rand(seq_len, batch_size, input_dim).astype(np.float32)} # run original model in CPU provider, using all threads # there are some local thread pool inside LSTM/GRU CPU kernel # that cannot be controlled by OMP or intra_op_num_threads - sess = onnxruntime.InferenceSession(model_name, providers=['CPUExecutionProvider']) + sess = onnxruntime.InferenceSession(model_name, providers=["CPUExecutionProvider"]) count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds) avg_rnn = top_n_avg(per_iter_cost, top_n) - print('perf_rnn (with default threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(model_name, count, top_n, avg_rnn)) + print( + "perf_rnn (with default threads) {}: run for {} iterations, top {} avg {:.3f} ms".format( + model_name, count, top_n, avg_rnn + ) + ) # run converted model in Nuphar, using specified threads with ScopedSetNumThreads(num_threads) as scoped_set_num_threads: # run Scan model converted from original in Nuphar - from .model_editor import convert_to_scan_model from ..tools.symbolic_shape_infer import SymbolicShapeInference - scan_model_name = os.path.splitext(model_name)[0] + '_scan.onnx' + from .model_editor import convert_to_scan_model + + scan_model_name = os.path.splitext(model_name)[0] + "_scan.onnx" convert_to_scan_model(model_name, scan_model_name) # note that symbolic shape inference is needed because model has symbolic batch dim, thus init_state is ConstantOfShape onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(scan_model_name)), scan_model_name) sess = onnxruntime.InferenceSession(scan_model_name, providers=onnxruntime.get_available_providers()) - count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds) + count, duration, per_iter_cost = perf_run( + sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds + ) avg_scan = top_n_avg(per_iter_cost, top_n) - print('perf_scan (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(num_threads, scan_model_name, count, top_n, avg_scan)) + print( + "perf_scan (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms".format( + num_threads, scan_model_name, count, top_n, avg_scan + ) + ) # quantize Scan model to int8 and run in Nuphar from .model_quantizer import convert_matmul_model - int8_model_name = os.path.splitext(model_name)[0] + '_int8.onnx' + + int8_model_name = os.path.splitext(model_name)[0] + "_int8.onnx" convert_matmul_model(scan_model_name, int8_model_name) onnx.save(SymbolicShapeInference.infer_shapes(onnx.load(int8_model_name)), int8_model_name) sess = onnxruntime.InferenceSession(int8_model_name, providers=onnxruntime.get_available_providers()) - count, duration, per_iter_cost = perf_run(sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds) + count, duration, per_iter_cost = perf_run( + sess, feeds, min_counts=top_n, min_duration_seconds=min_duration_seconds + ) avg_int8 = top_n_avg(per_iter_cost, top_n) - print('perf_int8 (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms'.format(num_threads, int8_model_name, count, top_n, avg_int8)) + print( + "perf_int8 (with {} threads) {}: run for {} iterations, top {} avg {:.3f} ms".format( + num_threads, int8_model_name, count, top_n, avg_int8 + ) + ) return avg_rnn, avg_scan, avg_int8 + def perf_test_auto(auto_file): # generate reports in csv format - with open('single_thread_' + auto_file + '.csv', 'w') as f: - print('single thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 batch_size=1', file=f) - print('rnn_type,hidden,seq_len,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8', file=f) - for rnn_type in ['lstm', 'gru', 'rnn']: + with open("single_thread_" + auto_file + ".csv", "w") as f: + print("single thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 batch_size=1", file=f) + print("rnn_type,hidden,seq_len,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8", file=f) + for rnn_type in ["lstm", "gru", "rnn"]: for hidden_dim in [32, 128, 1024, 2048]: for seq_len in [1, 16, 32, 64]: avg_rnn, avg_scan, avg_int8 = perf_test(rnn_type, 1, 128, hidden_dim, False, 4, seq_len, 1) - print('{},{},{},{},{},{},{},{}'.format(rnn_type,hidden_dim, seq_len, avg_rnn, avg_scan, avg_int8, avg_rnn/avg_scan, avg_rnn/avg_int8), file=f) + print( + "{},{},{},{},{},{},{},{}".format( + rnn_type, + hidden_dim, + seq_len, + avg_rnn, + avg_scan, + avg_int8, + avg_rnn / avg_scan, + avg_rnn / avg_int8, + ), + file=f, + ) - with open('multi_thread_' + auto_file + '.csv', 'w') as f: - print('multi-thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 seq_len=32 batch_size=1', file=f) - print('rnn_type,threads,hidden,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8', file=f) - for rnn_type in ['lstm', 'gru', 'rnn']: + with open("multi_thread_" + auto_file + ".csv", "w") as f: + print("multi-thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 seq_len=32 batch_size=1", file=f) + print("rnn_type,threads,hidden,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8", file=f) + for rnn_type in ["lstm", "gru", "rnn"]: for num_threads in [1, 2, 4]: for hidden_dim in [32, 128, 1024, 2048]: - avg_rnn, avg_scan, avg_int8 = perf_test(rnn_type, num_threads, 128, hidden_dim, False, 4, seq_len, 1) - print('{},{},{},{},{},{},{},{}'.format(rnn_type,num_threads, hidden_dim, avg_rnn, avg_scan, avg_int8, avg_rnn/avg_scan, avg_rnn/avg_int8), file=f) + avg_rnn, avg_scan, avg_int8 = perf_test( + rnn_type, num_threads, 128, hidden_dim, False, 4, seq_len, 1 + ) + print( + "{},{},{},{},{},{},{},{}".format( + rnn_type, + num_threads, + hidden_dim, + avg_rnn, + avg_scan, + avg_int8, + avg_rnn / avg_scan, + avg_rnn / avg_int8, + ), + file=f, + ) - with open('batch_single_thread_' + auto_file + '.csv', 'w') as f: - print('single thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 hidden_dim=1024', file=f) - print('rnn_type,seq_len,batch_size,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8', file=f) - for rnn_type in ['lstm', 'gru', 'rnn']: + with open("batch_single_thread_" + auto_file + ".csv", "w") as f: + print("single thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 hidden_dim=1024", file=f) + print("rnn_type,seq_len,batch_size,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8", file=f) + for rnn_type in ["lstm", "gru", "rnn"]: for seq_len in [1, 16, 32, 64]: for batch_size in [1, 4, 16, 64]: avg_rnn, avg_scan, avg_int8 = perf_test(rnn_type, 1, 128, 1024, False, 4, seq_len, batch_size) - print('{},{},{},{},{},{},{},{}'.format(rnn_type,seq_len, batch_size, avg_rnn, avg_scan, avg_int8, avg_rnn/avg_scan, avg_rnn/avg_int8), file=f) + print( + "{},{},{},{},{},{},{},{}".format( + rnn_type, + seq_len, + batch_size, + avg_rnn, + avg_scan, + avg_int8, + avg_rnn / avg_scan, + avg_rnn / avg_int8, + ), + file=f, + ) - with open('batch_multi_thread_' + auto_file + '.csv', 'w') as f: - print('batch thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 hidden_dim=1024 seq_len=32', file=f) - print('rnn_type,threads,batch_size,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8', file=f) - for rnn_type in ['lstm', 'gru', 'rnn']: + with open("batch_multi_thread_" + auto_file + ".csv", "w") as f: + print( + "batch thread test: unidirection 4-layer lstm/gru/rnn with input_dim=128 hidden_dim=1024 seq_len=32", file=f + ) + print("rnn_type,threads,batch_size,avg_rnn,avg_nuphar_fp,avg_nuphar_int8,speedup_fp,speedup_int8", file=f) + for rnn_type in ["lstm", "gru", "rnn"]: for num_threads in [1, 2, 4]: for batch_size in [1, 4, 16, 64]: avg_rnn, avg_scan, avg_int8 = perf_test(rnn_type, num_threads, 128, 1024, False, 4, 32, batch_size) - print('{},{},{},{},{},{},{},{}'.format(rnn_type,num_threads, batch_size, avg_rnn, avg_scan, avg_int8, avg_rnn/avg_scan, avg_rnn/avg_int8), file=f) + print( + "{},{},{},{},{},{},{},{}".format( + rnn_type, + num_threads, + batch_size, + avg_rnn, + avg_scan, + avg_int8, + avg_rnn / avg_scan, + avg_rnn / avg_int8, + ), + file=f, + ) + def parse_arguments(): - parser = argparse.ArgumentParser() - parser.add_argument('--rnn_type', help='Type of rnn, one of lstm/gru/rnn', choices=['lstm', 'gru', 'rnn'], default='lstm') - parser.add_argument('--input_dim', help='Input size of lstm/gru/rnn', type=int, default=128) - parser.add_argument('--hidden_dim', help='Hidden size of lstm/gru/rnn', type=int, default=1024) - parser.add_argument('--bidirectional', help='Use bidirectional', action='store_true', default=False) - parser.add_argument('--layers', help='Number of layers', type=int, default=4) - parser.add_argument('--seq_len', help='Sequence length', type=int, default=32) - parser.add_argument('--batch_size', help='Batch size', type=int, default=1) - parser.add_argument('--num_threads', help='Number of MKL threads', type=int, default=multiprocessing.cpu_count()) - parser.add_argument('--top_n', help='Fastest N samples to compute average time', type=int, default=5) - parser.add_argument('--auto', help='Auto_name (usually CPU type) for auto test to generate (batch_)single|multithread_.csv files', default=None) - return parser.parse_args() - -if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + "--rnn_type", help="Type of rnn, one of lstm/gru/rnn", choices=["lstm", "gru", "rnn"], default="lstm" + ) + parser.add_argument("--input_dim", help="Input size of lstm/gru/rnn", type=int, default=128) + parser.add_argument("--hidden_dim", help="Hidden size of lstm/gru/rnn", type=int, default=1024) + parser.add_argument("--bidirectional", help="Use bidirectional", action="store_true", default=False) + parser.add_argument("--layers", help="Number of layers", type=int, default=4) + parser.add_argument("--seq_len", help="Sequence length", type=int, default=32) + parser.add_argument("--batch_size", help="Batch size", type=int, default=1) + parser.add_argument("--num_threads", help="Number of MKL threads", type=int, default=multiprocessing.cpu_count()) + parser.add_argument("--top_n", help="Fastest N samples to compute average time", type=int, default=5) + parser.add_argument( + "--auto", + help="Auto_name (usually CPU type) for auto test to generate (batch_)single|multithread_.csv files", + default=None, + ) + return parser.parse_args() + + +if __name__ == "__main__": args = parse_arguments() if args.auto: perf_test_auto(args.auto) else: - print('Testing model: ', args.rnn_type.upper()) - print(' input_dim: ', args.input_dim) - print(' hidden_dim: ', args.hidden_dim) + print("Testing model: ", args.rnn_type.upper()) + print(" input_dim: ", args.input_dim) + print(" hidden_dim: ", args.hidden_dim) if args.bidirectional: - print(' bidirectional') - print(' layers: ', args.layers) + print(" bidirectional") + print(" layers: ", args.layers) cpu_count = multiprocessing.cpu_count() num_threads = max(min(args.num_threads, cpu_count), 1) - print('Test setup') - print(' cpu_count: ', cpu_count) - print(' num_threads: ', num_threads) - print(' seq_len: ', args.seq_len) - print(' batch_size: ', args.batch_size) - perf_test(args.rnn_type, num_threads, args.input_dim, args.hidden_dim, args.bidirectional, args.layers, args.seq_len, args.batch_size, args.top_n) + print("Test setup") + print(" cpu_count: ", cpu_count) + print(" num_threads: ", num_threads) + print(" seq_len: ", args.seq_len) + print(" batch_size: ", args.batch_size) + perf_test( + args.rnn_type, + num_threads, + args.input_dim, + args.hidden_dim, + args.bidirectional, + args.layers, + args.seq_len, + args.batch_size, + args.top_n, + ) diff --git a/onnxruntime/python/backend/backend.py b/onnxruntime/python/backend/backend.py index fffbd51f0f144..cbe01a630f7ec 100644 --- a/onnxruntime/python/backend/backend.py +++ b/onnxruntime/python/backend/backend.py @@ -5,15 +5,15 @@ """ Implements ONNX's backend API. """ -from onnx import ModelProto -from onnx import helper -from onnx import version -from onnx.checker import check_model +import os +import unittest + +from onnx import ModelProto, helper, version from onnx.backend.base import Backend -from onnxruntime import InferenceSession, SessionOptions, get_device, get_available_providers +from onnx.checker import check_model + +from onnxruntime import InferenceSession, SessionOptions, get_available_providers, get_device from onnxruntime.backend.backend_rep import OnnxRuntimeBackendRep -import unittest -import os class OnnxRuntimeBackend(Backend): @@ -28,7 +28,7 @@ class OnnxRuntimeBackend(Backend): Note: This is not the official Python API. """ # noqa: E501 - allowReleasedOpsetsOnly = bool(os.getenv('ALLOW_RELEASED_ONNX_OPSET_ONLY', '1') == '1') + allowReleasedOpsetsOnly = bool(os.getenv("ALLOW_RELEASED_ONNX_OPSET_ONLY", "1") == "1") @classmethod def is_compatible(cls, model, device=None, **kwargs): @@ -55,22 +55,26 @@ def is_opset_supported(cls, model): """ if cls.allowReleasedOpsetsOnly: for opset in model.opset_import: - domain = opset.domain if opset.domain else 'ai.onnx' + domain = opset.domain if opset.domain else "ai.onnx" try: key = (domain, opset.version) if not (key in helper.OP_SET_ID_VERSION_MAP): - error_message = ("Skipping this test as only released onnx opsets are supported." - "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0." - " Got Domain '{0}' version '{1}'.".format(domain, opset.version)) + error_message = ( + "Skipping this test as only released onnx opsets are supported." + "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0." + " Got Domain '{0}' version '{1}'.".format(domain, opset.version) + ) return False, error_message except AttributeError: # for some CI pipelines accessing helper.OP_SET_ID_VERSION_MAP # is generating attribute error. TODO investigate the pipelines to # fix this error. Falling back to a simple version check when this error is encountered - if (domain == 'ai.onnx' and opset.version > 12) or (domain == 'ai.ommx.ml' and opset.version > 2): - error_message = ("Skipping this test as only released onnx opsets are supported." - "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0." - " Got Domain '{0}' version '{1}'.".format(domain, opset.version)) + if (domain == "ai.onnx" and opset.version > 12) or (domain == "ai.ommx.ml" and opset.version > 2): + error_message = ( + "Skipping this test as only released onnx opsets are supported." + "To run this test set env variable ALLOW_RELEASED_ONNX_OPSET_ONLY to 0." + " Got Domain '{0}' version '{1}'.".format(domain, opset.version) + ) return False, error_message return True, "" @@ -80,8 +84,8 @@ def supports_device(cls, device): Check whether the backend is compiled with particular device support. In particular it's used in the testing suite. """ - if device == 'CUDA': - device = 'GPU' + if device == "CUDA": + device = "GPU" return device in get_device() @classmethod @@ -108,7 +112,7 @@ def prepare(cls, model, device=None, **kwargs): if hasattr(options, k): setattr(options, k, v) - excluded_providers = os.getenv('ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS', default="").split(',') + excluded_providers = os.getenv("ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS", default="").split(",") providers = [x for x in get_available_providers() if (x not in excluded_providers)] inf = InferenceSession(model, sess_options=options, providers=providers) @@ -156,10 +160,10 @@ def run_model(cls, model, inputs, device=None, **kwargs): @classmethod def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): - ''' + """ This method is not implemented as it is much more efficient to run a whole model than every node independently. - ''' + """ raise NotImplementedError("It is much more efficient to run a whole model than every node independently.") diff --git a/onnxruntime/python/backend/backend_rep.py b/onnxruntime/python/backend/backend_rep.py index 0041838f9aa26..6dced3aba7f80 100644 --- a/onnxruntime/python/backend/backend_rep.py +++ b/onnxruntime/python/backend/backend_rep.py @@ -5,10 +5,12 @@ """ Implements ONNX's backend API. """ -from onnxruntime import RunOptions -from onnx.backend.base import BackendRep from typing import Any, Tuple +from onnx.backend.base import BackendRep + +from onnxruntime import RunOptions + class OnnxRuntimeBackendRep(BackendRep): """ diff --git a/onnxruntime/python/onnxruntime_collect_build_info.py b/onnxruntime/python/onnxruntime_collect_build_info.py index 4445fb03593af..6cd67938dd0ba 100644 --- a/onnxruntime/python/onnxruntime_collect_build_info.py +++ b/onnxruntime/python/onnxruntime_collect_build_info.py @@ -2,9 +2,9 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -------------------------------------------------------------------------- -import warnings import ctypes import sys +import warnings def find_cudart_versions(build_env=False, build_cuda_version=None): @@ -16,16 +16,16 @@ def find_cudart_versions(build_env=False, build_cuda_version=None): # for the above reason, we need find all versions in the environment and # only give warnings if the expected cuda version is not found. # in onnxruntime build environment, we expected only one Cuda version. - if not sys.platform.startswith('linux'): - warnings.warn('find_cudart_versions only works on Linux') + if not sys.platform.startswith("linux"): + warnings.warn("find_cudart_versions only works on Linux") return None cudart_possible_versions = {None, build_cuda_version} def get_cudart_version(find_cudart_version=None): - cudart_lib_filename = 'libcudart.so' + cudart_lib_filename = "libcudart.so" if find_cudart_version: - cudart_lib_filename = cudart_lib_filename + '.' + find_cudart_version + cudart_lib_filename = cudart_lib_filename + "." + find_cudart_version try: cudart = ctypes.CDLL(cudart_lib_filename) @@ -35,14 +35,13 @@ def get_cudart_version(find_cudart_version=None): status = cudart.cudaRuntimeGetVersion(ctypes.byref(version)) if status != 0: return None - except: # noqa + except: # noqa return None return version.value # use set to avoid duplications - cudart_found_versions = { - get_cudart_version(cudart_version) for cudart_version in cudart_possible_versions} + cudart_found_versions = {get_cudart_version(cudart_version) for cudart_version in cudart_possible_versions} # convert to list and remove None return [ver for ver in cudart_found_versions if ver] @@ -50,27 +49,42 @@ def get_cudart_version(find_cudart_version=None): def find_cudnn_supported_cuda_versions(build_env=False): # comments in get_cudart_version apply here - if not sys.platform.startswith('linux'): - warnings.warn('find_cudnn_versions only works on Linux') + if not sys.platform.startswith("linux"): + warnings.warn("find_cudnn_versions only works on Linux") cudnn_possible_versions = {None} if not build_env: # if not in a build environment, there may be more than one installed cudnn. # https://developer.nvidia.com/rdp/cudnn-archive to include all that may support Cuda 10+. - cudnn_possible_versions.update({ - '8.2', - '8.1.1', '8.1.0', - '8.0.5', '8.0.4', '8.0.3', '8.0.2', '8.0.1', - '7.6.5', '7.6.4', '7.6.3', '7.6.2', '7.6.1', '7.6.0', - '7.5.1', '7.5.0', - '7.4.2', '7.4.1', - '7.3.1', '7.3.0', - }) + cudnn_possible_versions.update( + { + "8.2", + "8.1.1", + "8.1.0", + "8.0.5", + "8.0.4", + "8.0.3", + "8.0.2", + "8.0.1", + "7.6.5", + "7.6.4", + "7.6.3", + "7.6.2", + "7.6.1", + "7.6.0", + "7.5.1", + "7.5.0", + "7.4.2", + "7.4.1", + "7.3.1", + "7.3.0", + } + ) def get_cudnn_supported_cuda_version(find_cudnn_version=None): - cudnn_lib_filename = 'libcudnn.so' + cudnn_lib_filename = "libcudnn.so" if find_cudnn_version: - cudnn_lib_filename = cudnn_lib_filename + '.' + find_cudnn_version + cudnn_lib_filename = cudnn_lib_filename + "." + find_cudnn_version # in cudnn.h cudnn version are calculated as: # #define CUDNN_VERSION (CUDNN_MAJOR * 1000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL) @@ -79,7 +93,7 @@ def get_cudnn_supported_cuda_version(find_cudnn_version=None): # cudnn_ver = cudnn.cudnnGetVersion() cuda_ver = cudnn.cudnnGetCudartVersion() return cuda_ver - except: # noqa + except: # noqa return None # use set to avoid duplications diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index dcdaa4300c776..eff95f95b0aad 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -12,14 +12,14 @@ def get_ort_device_type(device): device_type = device if type(device) is str else device.type.lower() - if device_type == 'cuda': + if device_type == "cuda": return C.OrtDevice.cuda() - elif device_type == 'cpu': + elif device_type == "cpu": return C.OrtDevice.cpu() - elif device_type == 'ort': + elif device_type == "ort": return C.get_ort_device(device.index).device_type() else: - raise Exception('Unsupported device type: ' + device_type) + raise Exception("Unsupported device type: " + device_type) def check_and_normalize_provider_args(providers, provider_options, available_provider_names): @@ -52,8 +52,10 @@ def check_and_normalize_provider_args(providers, provider_options, available_pro def set_provider_options(name, options): if name not in available_provider_names: - warnings.warn("Specified provider '{}' is not in available provider names." - "Available providers: '{}'".format(name, ", ".join(available_provider_names))) + warnings.warn( + "Specified provider '{}' is not in available provider names." + "Available providers: '{}'".format(name, ", ".join(available_provider_names)) + ) if name in provider_name_to_options: warnings.warn("Duplicate provider '{}' encountered, ignoring.".format(name)) @@ -85,8 +87,12 @@ def set_provider_options(name, options): for provider in providers: if isinstance(provider, str): set_provider_options(provider, dict()) - elif isinstance(provider, tuple) and len(provider) == 2 and \ - isinstance(provider[0], str) and isinstance(provider[1], dict): + elif ( + isinstance(provider, tuple) + and len(provider) == 2 + and isinstance(provider[0], str) + and isinstance(provider[1], dict) + ): set_provider_options(provider[0], provider[1]) else: raise ValueError("'providers' values must be either strings or (string, dict) tuples.") @@ -98,6 +104,7 @@ class Session: """ This is the main class used to run a model. """ + def __init__(self): # self._sess is managed by the derived class and relies on bindings from C.InferenceSession @@ -216,6 +223,7 @@ def run_with_ort_values(self, output_names, input_dict_ort_values, run_options=N sess.run([output_name], {input_name: x}) """ + def invoke(sess, output_names, input_dict_ort_values, run_options): input_dict = {} for n, v in input_dict_ort_values.items(): @@ -268,10 +276,10 @@ def io_binding(self): def run_with_iobinding(self, iobinding, run_options=None): """ - Compute the predictions. + Compute the predictions. - :param iobinding: the iobinding object that has graph inputs/outputs bind. - :param run_options: See :class:`onnxruntime.RunOptions`. + :param iobinding: the iobinding object that has graph inputs/outputs bind. + :param run_options: See :class:`onnxruntime.RunOptions`. """ self._sess.run_with_iobinding(iobinding._iobinding, run_options) @@ -280,6 +288,7 @@ class InferenceSession(Session): """ This is the main class used to run a model. """ + def __init__(self, path_or_bytes, sess_options=None, providers=None, provider_options=None, **kwargs): """ :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string @@ -326,10 +335,10 @@ def __init__(self, path_or_bytes, sess_options=None, providers=None, provider_op self._sess_options = sess_options self._sess_options_initial = sess_options self._enable_fallback = True - self._read_config_from_model = os.environ.get('ORT_LOAD_CONFIG_FROM_MODEL') == '1' + self._read_config_from_model = os.environ.get("ORT_LOAD_CONFIG_FROM_MODEL") == "1" # internal parameters that we don't expect to be used in general so aren't documented - disabled_optimizers = kwargs['disabled_optimizers'] if 'disabled_optimizers' in kwargs else None + disabled_optimizers = kwargs["disabled_optimizers"] if "disabled_optimizers" in kwargs else None try: self._create_inference_session(providers, provider_options, disabled_optimizers) @@ -347,23 +356,25 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi available_providers = C.get_available_providers() # Tensorrt can fall back to CUDA. All others fall back to CPU. - if 'TensorrtExecutionProvider' in available_providers: - self._fallback_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] - elif 'MIGraphXExecutionProvider' in available_providers: - self._fallback_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider'] + if "TensorrtExecutionProvider" in available_providers: + self._fallback_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + elif "MIGraphXExecutionProvider" in available_providers: + self._fallback_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] else: - self._fallback_providers = ['CPUExecutionProvider'] + self._fallback_providers = ["CPUExecutionProvider"] # validate providers and provider_options before other initialization - providers, provider_options = check_and_normalize_provider_args(providers, - provider_options, - available_providers) + providers, provider_options = check_and_normalize_provider_args( + providers, provider_options, available_providers + ) if providers == [] and len(available_providers) > 1: self.disable_fallback() - raise ValueError("This ORT build has {} enabled. ".format(available_providers) + - "Since ORT 1.9, you are required to explicitly set " + - "the providers parameter when instantiating InferenceSession. For example, " - "onnxruntime.InferenceSession(..., providers={}, ...)".format(available_providers)) + raise ValueError( + "This ORT build has {} enabled. ".format(available_providers) + + "Since ORT 1.9, you are required to explicitly set " + + "the providers parameter when instantiating InferenceSession. For example, " + "onnxruntime.InferenceSession(..., providers={}, ...)".format(available_providers) + ) session_options = self._sess_options if self._sess_options else C.get_default_session_options() if self._model_path: @@ -410,19 +421,20 @@ def _reset_session(self, providers, provider_options): class IOBinding: - ''' + """ This class provides API to bind input/output to a specified device, e.g. GPU. - ''' + """ + def __init__(self, session): self._iobinding = C.SessionIOBinding(session._sess) self._numpy_obj_references = {} def bind_cpu_input(self, name, arr_on_cpu): - ''' + """ bind an input to array on CPU :param name: input name :param arr_on_cpu: input values as a python array on CPU - ''' + """ # Hold a reference to the numpy object as the bound OrtValue is backed # directly by the data buffer of the numpy object and so the numpy object # must be around until this IOBinding instance is around @@ -430,38 +442,53 @@ def bind_cpu_input(self, name, arr_on_cpu): self._iobinding.bind_input(name, arr_on_cpu) def bind_input(self, name, device_type, device_id, element_type, shape, buffer_ptr): - ''' + """ :param name: input name :param device_type: e.g. cpu, cuda :param device_id: device id, e.g. 0 :param element_type: input element type :param shape: input shape :param buffer_ptr: memory pointer to input data - ''' - self._iobinding.bind_input(name, - C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), - device_id), - element_type, shape, buffer_ptr) + """ + self._iobinding.bind_input( + name, + C.OrtDevice( + get_ort_device_type(device_type), + C.OrtDevice.default_memory(), + device_id, + ), + element_type, + shape, + buffer_ptr, + ) def bind_ortvalue_input(self, name, ortvalue): - ''' + """ :param name: input name :param ortvalue: OrtValue instance to bind - ''' + """ self._iobinding.bind_ortvalue_input(name, ortvalue._ortvalue) def synchronize_inputs(self): self._iobinding.synchronize_inputs() - def bind_output(self, name, device_type='cpu', device_id=0, element_type=None, shape=None, buffer_ptr=None): - ''' + def bind_output( + self, + name, + device_type="cpu", + device_id=0, + element_type=None, + shape=None, + buffer_ptr=None, + ): + """ :param name: output name :param device_type: e.g. cpu, cuda, cpu by default :param device_id: device id, e.g. 0 :param element_type: output element type :param shape: output shape :param buffer_ptr: memory pointer to output data - ''' + """ # Follow the `if` path when the user has not provided any pre-allocated buffer but still # would like to bind an output to a specific device (e.g. cuda). @@ -470,32 +497,44 @@ def bind_output(self, name, device_type='cpu', device_id=0, element_type=None, s # in which case ORT will allocate the memory for the user # (2) The output has a dynamic shape and hence the size of the buffer may not be fixed across runs if buffer_ptr is None: - self._iobinding.bind_output(name, - C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), - device_id)) + self._iobinding.bind_output( + name, + C.OrtDevice( + get_ort_device_type(device_type), + C.OrtDevice.default_memory(), + device_id, + ), + ) else: if element_type is None or shape is None: raise ValueError("`element_type` and `shape` are to be provided if pre-allocated memory is provided") - self._iobinding.bind_output(name, - C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), - device_id), - element_type, shape, buffer_ptr) + self._iobinding.bind_output( + name, + C.OrtDevice( + get_ort_device_type(device_type), + C.OrtDevice.default_memory(), + device_id, + ), + element_type, + shape, + buffer_ptr, + ) def bind_ortvalue_output(self, name, ortvalue): - ''' + """ :param name: output name :param ortvalue: OrtValue instance to bind - ''' + """ self._iobinding.bind_ortvalue_output(name, ortvalue._ortvalue) def synchronize_outputs(self): self._iobinding.synchronize_outputs() def get_outputs(self): - ''' + """ Returns the output OrtValues from the Run() that preceded the call. The data buffer of the obtained OrtValues may not reside on CPU memory - ''' + """ returned_ortvalues = [] for ortvalue in self._iobinding.get_outputs(): @@ -504,7 +543,7 @@ def get_outputs(self): return returned_ortvalues def copy_outputs_to_cpu(self): - '''Copy output contents to CPU (if on another device). No-op if already on the CPU.''' + """Copy output contents to CPU (if on another device). No-op if already on the CPU.""" return self._iobinding.copy_outputs_to_cpu() def clear_binding_inputs(self): @@ -515,11 +554,12 @@ def clear_binding_outputs(self): class OrtValue: - ''' + """ A data structure that supports all ONNX data formats (tensors and non-tensors) that allows users to place the data backing these on a device, for example, on a CUDA supported device. This class provides APIs to construct and deal with OrtValues. - ''' + """ + def __init__(self, ortvalue, numpy_obj=None): if isinstance(ortvalue, C.OrtValue): self._ortvalue = ortvalue @@ -528,157 +568,183 @@ def __init__(self, ortvalue, numpy_obj=None): self._numpy_obj = numpy_obj else: # An end user won't hit this error - raise ValueError("`Provided ortvalue` needs to be of type " + - "`onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`") + raise ValueError( + "`Provided ortvalue` needs to be of type " + "`onnxruntime.capi.onnxruntime_pybind11_state.OrtValue`" + ) def _get_c_value(self): return self._ortvalue @staticmethod - def ortvalue_from_numpy(numpy_obj, device_type='cpu', device_id=0): - ''' + def ortvalue_from_numpy(numpy_obj, device_type="cpu", device_id=0): + """ Factory method to construct an OrtValue (which holds a Tensor) from a given Numpy object A copy of the data in the Numpy object is held by the OrtValue only if the device is NOT cpu :param numpy_obj: The Numpy object to construct the OrtValue from :param device_type: e.g. cpu, cuda, cpu by default :param device_id: device id, e.g. 0 - ''' + """ # Hold a reference to the numpy object (if device_type is 'cpu') as the OrtValue # is backed directly by the data buffer of the numpy object and so the numpy object # must be around until this OrtValue instance is around - return OrtValue(C.OrtValue.ortvalue_from_numpy(numpy_obj, C.OrtDevice(get_ort_device_type(device_type), - C.OrtDevice.default_memory(), device_id)), numpy_obj if device_type.lower() == 'cpu' else None) + return OrtValue( + C.OrtValue.ortvalue_from_numpy( + numpy_obj, + C.OrtDevice( + get_ort_device_type(device_type), + C.OrtDevice.default_memory(), + device_id, + ), + ), + numpy_obj if device_type.lower() == "cpu" else None, + ) @staticmethod - def ortvalue_from_shape_and_type(shape=None, element_type=None, device_type='cpu', device_id=0): - ''' + def ortvalue_from_shape_and_type(shape=None, element_type=None, device_type="cpu", device_id=0): + """ Factory method to construct an OrtValue (which holds a Tensor) from given shape and element_type :param shape: List of integers indicating the shape of the OrtValue :param element_type: The data type of the elements in the OrtValue (numpy type) :param device_type: e.g. cpu, cuda, cpu by default :param device_id: device id, e.g. 0 - ''' + """ if shape is None or element_type is None: raise ValueError("`element_type` and `shape` are to be provided if pre-allocated memory is provided") - return OrtValue(C.OrtValue.ortvalue_from_shape_and_type(shape, element_type, - C.OrtDevice(get_ort_device_type(device_type), C.OrtDevice.default_memory(), device_id))) + return OrtValue( + C.OrtValue.ortvalue_from_shape_and_type( + shape, + element_type, + C.OrtDevice( + get_ort_device_type(device_type), + C.OrtDevice.default_memory(), + device_id, + ), + ) + ) @staticmethod def ort_value_from_sparse_tensor(sparse_tensor): - ''' + """ The function will construct an OrtValue instance from a valid SparseTensor The new instance of OrtValue will assume the ownership of sparse_tensor - ''' + """ return OrtValue(C.OrtValue.ort_value_from_sparse_tensor(sparse_tensor._get_c_tensor())) def as_sparse_tensor(self): - ''' + """ The function will return SparseTensor contained in this OrtValue - ''' + """ return SparseTensor(self._ortvalue.as_sparse_tensor()) def data_ptr(self): - ''' + """ Returns the address of the first element in the OrtValue's data buffer - ''' + """ return self._ortvalue.data_ptr() def device_name(self): - ''' + """ Returns the name of the device where the OrtValue's data buffer resides e.g. cpu, cuda - ''' + """ return self._ortvalue.device_name().lower() def shape(self): - ''' + """ Returns the shape of the data in the OrtValue - ''' + """ return self._ortvalue.shape() def data_type(self): - ''' + """ Returns the data type of the data in the OrtValue - ''' + """ return self._ortvalue.data_type() def element_type(self): - ''' + """ Returns the proto type of the data in the OrtValue if the OrtValue is a tensor. - ''' + """ return self._ortvalue.element_type() def has_value(self): - ''' + """ Returns True if the OrtValue corresponding to an optional type contains data, else returns False - ''' + """ return self._ortvalue.has_value() def is_tensor(self): - ''' + """ Returns True if the OrtValue contains a Tensor, else returns False - ''' + """ return self._ortvalue.is_tensor() def is_sparse_tensor(self): - ''' + """ Returns True if the OrtValue contains a SparseTensor, else returns False - ''' + """ return self._ortvalue.is_sparse_tensor() def is_tensor_sequence(self): - ''' + """ Returns True if the OrtValue contains a Tensor Sequence, else returns False - ''' + """ return self._ortvalue.is_tensor_sequence() def numpy(self): - ''' + """ Returns a Numpy object from the OrtValue. Valid only for OrtValues holding Tensors. Throws for OrtValues holding non-Tensors. Use accessors to gain a reference to non-Tensor objects such as SparseTensor - ''' + """ return self._ortvalue.numpy() def update_inplace(self, np_arr): - ''' + """ Update the OrtValue in place with a new Numpy array. The numpy contents are copied over to the device memory backing the OrtValue. It can be used to update the input valuess for an InferenceSession with CUDA graph enabled or other scenarios where the OrtValue needs to be updated while the memory address can not be changed. - ''' + """ self._ortvalue.update_inplace(np_arr) class OrtDevice: - ''' + """ A data structure that exposes the underlying C++ OrtDevice - ''' + """ + def __init__(self, c_ort_device): - ''' + """ Internal constructor - ''' + """ if isinstance(c_ort_device, C.OrtDevice): self._ort_device = c_ort_device else: - raise ValueError("`Provided object` needs to be of type " + - "`onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`") + raise ValueError( + "`Provided object` needs to be of type " + "`onnxruntime.capi.onnxruntime_pybind11_state.OrtDevice`" + ) def _get_c_device(self): - ''' + """ Internal accessor to underlying object - ''' + """ return self._ort_device @staticmethod def make(ort_device_name, device_id): - return OrtDevice(C.OrtDevice(get_ort_device_type(ort_device_name), - C.OrtDevice.default_memory(), device_id)) + return OrtDevice( + C.OrtDevice( + get_ort_device_type(ort_device_name), + C.OrtDevice.default_memory(), + device_id, + ) + ) def device_id(self): return self._ort_device.device_id() @@ -688,29 +754,31 @@ def device_type(self): class SparseTensor: - ''' + """ A data structure that project the C++ SparseTensor object The class provides API to work with the object. Depending on the format, the class will hold more than one buffer depending on the format - ''' + """ + def __init__(self, sparse_tensor): - ''' + """ Internal constructor - ''' + """ if isinstance(sparse_tensor, C.SparseTensor): self._tensor = sparse_tensor else: # An end user won't hit this error - raise ValueError("`Provided object` needs to be of type " + - "`onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`") + raise ValueError( + "`Provided object` needs to be of type " + "`onnxruntime.capi.onnxruntime_pybind11_state.SparseTensor`" + ) def _get_c_tensor(self): return self._tensor @staticmethod def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device): - ''' + """ Factory method to construct a SparseTensor in COO format from given arguments :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the sparse tensor @@ -729,13 +797,14 @@ def sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device): on GC. The buffers may reside in any storage either CPU or GPU. For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those on other devices and their memory can not be mapped. - ''' - return SparseTensor(C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, - ort_device._get_c_device())) + """ + return SparseTensor( + C.SparseTensor.sparse_coo_from_numpy(dense_shape, values, coo_indices, ort_device._get_c_device()) + ) @staticmethod def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort_device): - ''' + """ Factory method to construct a SparseTensor in CSR format from given arguments :param dense_shape: 1-D numpy array(int64) or a python list that contains a dense_shape of the @@ -754,20 +823,27 @@ def sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, ort The buffers may reside in any storage either CPU or GPU. For strings and objects, it will create a copy of the arrays in CPU memory as ORT does not support those on other devices and their memory can not be mapped. - ''' - return SparseTensor(C.SparseTensor.sparse_csr_from_numpy(dense_shape, values, inner_indices, outer_indices, - ort_device._get_c_device())) + """ + return SparseTensor( + C.SparseTensor.sparse_csr_from_numpy( + dense_shape, + values, + inner_indices, + outer_indices, + ort_device._get_c_device(), + ) + ) def values(self): - ''' + """ The method returns a numpy array that is backed by the native memory if the data type is numeric. Otherwise, the returned numpy array that contains copies of the strings. - ''' + """ return self._tensor.values() def as_coo_view(self): - ''' + """ The method will return coo representation of the sparse tensor which will enable querying COO indices. If the instance did not contain COO format, it would throw. You can query coo indices as: @@ -777,11 +853,11 @@ def as_coo_view(self): coo_indices = sparse_tensor.as_coo_view().indices() which will return a numpy array that is backed by the native memory. - ''' + """ return self._tensor.get_coo_data() def as_csrc_view(self): - ''' + """ The method will return CSR(C) representation of the sparse tensor which will enable querying CRS(C) indices. If the instance dit not contain CSR(C) format, it would throw. You can query indices as: @@ -792,11 +868,11 @@ def as_csrc_view(self): outer_ndices = sparse_tensor.as_csrc_view().outer() returning numpy arrays backed by the native memory. - ''' + """ return self._tensor.get_csrc_data() def as_blocksparse_view(self): - ''' + """ The method will return coo representation of the sparse tensor which will enable querying BlockSparse indices. If the instance did not contain BlockSparse format, it would throw. You can query coo indices as: @@ -806,11 +882,11 @@ def as_blocksparse_view(self): block_sparse_indices = sparse_tensor.as_blocksparse_view().indices() which will return a numpy array that is backed by the native memory - ''' + """ return self._tensor.get_blocksparse_data() def to_cuda(self, ort_device): - ''' + """ Returns a copy of this instance on the specified cuda device :param ort_device: with name 'cuda' and valid gpu device id @@ -821,29 +897,29 @@ def to_cuda(self, ort_device): - this instance is already on GPU. Cross GPU copy is not supported - CUDA is not present in this build - if the specified device is not valid - ''' + """ return SparseTensor(self._tensor.to_cuda(ort_device._get_c_device())) def format(self): - ''' + """ Returns a OrtSparseFormat enumeration - ''' + """ return self._tensor.format def dense_shape(self): - ''' + """ Returns a numpy array(int64) containing a dense shape of a sparse tensor - ''' + """ return self._tensor.dense_shape() def data_type(self): - ''' + """ Returns a string data type of the data in the OrtValue - ''' + """ return self._tensor.data_type() def device_name(self): - ''' + """ Returns the name of the device where the SparseTensor data buffers reside e.g. cpu, cuda - ''' + """ return self._tensor.device_name().lower() diff --git a/onnxruntime/python/onnxruntime_validation.py b/onnxruntime/python/onnxruntime_validation.py index f22d20a1b3796..8b313635527ac 100644 --- a/onnxruntime/python/onnxruntime_validation.py +++ b/onnxruntime/python/onnxruntime_validation.py @@ -5,34 +5,36 @@ """ Check OS requirements for ONNX Runtime Python Bindings. """ -import platform import linecache +import platform import warnings def check_distro_info(): - __my_distro__ = '' - __my_distro_ver__ = '' + __my_distro__ = "" + __my_distro_ver__ = "" __my_system__ = platform.system().lower() - __OS_RELEASE_FILE__ = '/etc/os-release' - __LSB_RELEASE_FILE__ = '/etc/lsb-release' + __OS_RELEASE_FILE__ = "/etc/os-release" + __LSB_RELEASE_FILE__ = "/etc/lsb-release" - if __my_system__ == 'windows': + if __my_system__ == "windows": __my_distro__ = __my_system__ __my_distro_ver__ = platform.release().lower() - if __my_distro_ver__ != '10': - warnings.warn('Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only.' % - __my_distro_ver__) - elif __my_system__ == 'linux': - ''' Although the 'platform' python module for getting Distro information works well on standard OS images + if __my_distro_ver__ != "10": + warnings.warn( + "Unsupported Windows version (%s). ONNX Runtime supports Windows 10 and above, only." + % __my_distro_ver__ + ) + elif __my_system__ == "linux": + """Although the 'platform' python module for getting Distro information works well on standard OS images running on real hardware, it is not accurate when running on Azure VMs, Git Bash, Cygwin, etc. The returned values for release and version are unpredictable for virtualized or emulated environments. /etc/os-release and /etc/lsb_release files, on the other hand, are guaranteed to exist and have standard values in all OSes supported by onnxruntime. The former is the current standard file to check OS info and the latter is its predecessor. - ''' + """ # Newer systems have /etc/os-release with relevant distro info __my_distro__ = linecache.getline(__OS_RELEASE_FILE__, 3)[3:-1] __my_distro_ver__ = linecache.getline(__OS_RELEASE_FILE__, 6)[12:-2] @@ -46,16 +48,18 @@ def check_distro_info(): # warn the user ONNX Runtime may not work out of the box __my_distro__ = __my_distro__.lower() __my_distro_ver__ = __my_distro_ver__.lower() - elif __my_system__ == 'darwin': + elif __my_system__ == "darwin": __my_distro__ = __my_system__ __my_distro_ver__ = platform.release().lower() - if int(__my_distro_ver__.split('.')[0]) < 11: - warnings.warn('Unsupported macOS version (%s). ONNX Runtime supports macOS 11.0 or later.' % - (__my_distro_ver__)) + if int(__my_distro_ver__.split(".")[0]) < 11: + warnings.warn( + "Unsupported macOS version (%s). ONNX Runtime supports macOS 11.0 or later." % (__my_distro_ver__) + ) else: - warnings.warn('Unsupported platform (%s). ONNX Runtime supports Linux, macOS and Windows platforms, only.' % - __my_system__) + warnings.warn( + "Unsupported platform (%s). ONNX Runtime supports Linux, macOS and Windows platforms, only." % __my_system__ + ) def validate_build_package_info(): @@ -63,7 +67,8 @@ def validate_build_package_info(): has_ortmodule = False try: - from onnxruntime.training.ortmodule import ORTModule # noqa + from onnxruntime.training.ortmodule import ORTModule # noqa + has_ortmodule = True except ImportError: # ORTModule not present @@ -74,6 +79,7 @@ def validate_build_package_info(): # device version validation and raise the exception after. try: from onnxruntime.training.ortmodule._fallback import ORTModuleInitException + if isinstance(e, ORTModuleInitException): # ORTModule is present but not ready to run yet has_ortmodule = True @@ -84,19 +90,19 @@ def validate_build_package_info(): if not has_ortmodule: import_ortmodule_exception = e - package_name = '' - version = '' - cuda_version = '' + package_name = "" + version = "" + cuda_version = "" if has_ortmodule: try: # collect onnxruntime package name, version, and cuda version - from .build_and_package_info import package_name from .build_and_package_info import __version__ as version + from .build_and_package_info import package_name try: from .build_and_package_info import cuda_version - except: # noqa + except: # noqa pass if cuda_version: @@ -104,29 +110,30 @@ def validate_build_package_info(): # when the build environment has none or multiple libraries installed try: from .build_and_package_info import cudart_version - except: # noqa - warnings.warn('WARNING: failed to get cudart_version from onnxruntime build info.') + except: # noqa + warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.") cudart_version = None def print_build_package_info(): - warnings.warn('onnxruntime training package info: package_name: %s' % package_name) - warnings.warn('onnxruntime training package info: __version__: %s' % version) - warnings.warn('onnxruntime training package info: cuda_version: %s' % cuda_version) - warnings.warn('onnxruntime build info: cudart_version: %s' % cudart_version) + warnings.warn("onnxruntime training package info: package_name: %s" % package_name) + warnings.warn("onnxruntime training package info: __version__: %s" % version) + warnings.warn("onnxruntime training package info: cuda_version: %s" % cuda_version) + warnings.warn("onnxruntime build info: cudart_version: %s" % cudart_version) # collection cuda library info from current environment. from onnxruntime.capi.onnxruntime_collect_build_info import find_cudart_versions + local_cudart_versions = find_cudart_versions(build_env=False, build_cuda_version=cuda_version) if cudart_version and local_cudart_versions and cudart_version not in local_cudart_versions: print_build_package_info() - warnings.warn('WARNING: failed to find cudart version that matches onnxruntime build info') - warnings.warn('WARNING: found cudart versions: %s' % local_cudart_versions) + warnings.warn("WARNING: failed to find cudart version that matches onnxruntime build info") + warnings.warn("WARNING: found cudart versions: %s" % local_cudart_versions) else: # TODO: rcom pass - except Exception as e: # noqa - warnings.warn('WARNING: failed to collect onnxruntime version and build info') + except Exception as e: # noqa + warnings.warn("WARNING: failed to collect onnxruntime version and build info") print(e) if import_ortmodule_exception: diff --git a/onnxruntime/python/providers/tvm/extend_python_file.py b/onnxruntime/python/providers/tvm/extend_python_file.py index 96beb113d69ec..65902619f8150 100644 --- a/onnxruntime/python/providers/tvm/extend_python_file.py +++ b/onnxruntime/python/providers/tvm/extend_python_file.py @@ -9,9 +9,10 @@ def rewrite_target_file(target): - with open(target, 'a') as f: - f.write(textwrap.dedent( - """ + with open(target, "a") as f: + f.write( + textwrap.dedent( + """ import warnings try: @@ -33,15 +34,21 @@ def rewrite_target_file(target): f"WARNING: Failed to register python functions to work with TVM EP. More details: {e}" ) """ - )) + ) + ) def main(): parser = argparse.ArgumentParser() - parser.add_argument("--target_file", type=str, required=True, help="Path to the file to be expanded.") + parser.add_argument( + "--target_file", + type=str, + required=True, + help="Path to the file to be expanded.", + ) args = parser.parse_args() rewrite_target_file(args.target_file) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/onnxruntime/python/providers/tvm/ort.py b/onnxruntime/python/providers/tvm/ort.py index 25d0b6cd07d0a..d2b690ddc6c35 100644 --- a/onnxruntime/python/providers/tvm/ort.py +++ b/onnxruntime/python/providers/tvm/ort.py @@ -4,17 +4,16 @@ # license information. # -------------------------------------------------------------------------- -import os import collections import copy import logging +import os import onnx import tvm -from tvm import relay, auto_scheduler -from tvm.relay import vm +from tvm import auto_scheduler, autotvm, relay from tvm.contrib import graph_executor -from tvm import autotvm +from tvm.relay import vm log = logging.getLogger("tvm_ep") @@ -23,18 +22,20 @@ @tvm.register_func("tvm_onnx_import_and_compile") -def onnx_compile(model_string, - model_path, - executor, - target, - target_host, - opt_level, - opset, - freeze_params, - input_shapes, - nhwc=False, - tuning_logfile="", - tuning_type=AUTO_TVM_TYPE): +def onnx_compile( + model_string, + model_path, + executor, + target, + target_host, + opt_level, + opset, + freeze_params, + input_shapes, + nhwc=False, + tuning_logfile="", + tuning_type=AUTO_TVM_TYPE, +): def get_tvm_executor(irmod, executor, target, params): if executor == "vm": log.info("Build TVM virtual machine") @@ -47,8 +48,9 @@ def get_tvm_executor(irmod, executor, target, params): log.info("Build TVM graph executor") lib = relay.build(irmod, target=target, params=params) else: - log.error("Executor type {} is unsupported. ".format(executor) + - "Only \"vm\" and \"graph\" types are supported") + log.error( + "Executor type {} is unsupported. ".format(executor) + 'Only "vm" and "graph" types are supported' + ) return None return lib @@ -94,7 +96,7 @@ def get_tvm_executor(irmod, executor, target, params): config={ "relay.backend.use_auto_scheduler": True, "relay.FuseOps.max_depth": 30, - } + }, ): if nhwc: seq = tvm.transform.Sequential( @@ -113,8 +115,10 @@ def get_tvm_executor(irmod, executor, target, params): with autotvm.apply_history_best(tuning_logfile): lib = get_tvm_executor(irmod, executor, tvm_target, params) else: - log.error("Tuning log type {} is unsupported. ".format(tuning_type) + - "Only {} and {} types are supported".format(ANSOR_TYPE, AUTO_TVM_TYPE)) + log.error( + "Tuning log type {} is unsupported. ".format(tuning_type) + + "Only {} and {} types are supported".format(ANSOR_TYPE, AUTO_TVM_TYPE) + ) return None else: with tvm.transform.PassContext(opt_level=opt_level): @@ -129,8 +133,10 @@ def get_tvm_executor(irmod, executor, target, params): elif executor == "graph": m = graph_executor.GraphModule(lib["default"](ctx)) else: - print("ERROR: Executor type {} is unsupported. ".format(executor), - "Only \"vm\" and \"graph\" types are supported") + print( + "ERROR: Executor type {} is unsupported. ".format(executor), + 'Only "vm" and "graph" types are supported', + ) return None return m.module diff --git a/onnxruntime/python/tools/microbench/attention.py b/onnxruntime/python/tools/microbench/attention.py index bc9daae4455c5..dc8291309fc72 100644 --- a/onnxruntime/python/tools/microbench/attention.py +++ b/onnxruntime/python/tools/microbench/attention.py @@ -1,10 +1,11 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import argparse from dataclasses import dataclass + import numpy as np from benchmark import BenchmarkOp, add_arguments @@ -24,12 +25,21 @@ def __init__(self, args): def create_inputs_outputs(cls, op_param): np.random.seed(0) - input_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype(op_param.data_type) + input_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype( + op_param.data_type + ) weight = np.random.rand(op_param.hidden_size, op_param.length).astype(op_param.data_type) bias = np.random.rand(op_param.length).astype(op_param.data_type) mask_index = np.random.rand(op_param.batch_size, op_param.seq_len).astype(np.int32) - output_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype(op_param.data_type) - inputs = {"INPUT": input_data, "WEIGHT": weight, "BIAS": bias, "MASK_INDEX": mask_index} + output_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype( + op_param.data_type + ) + inputs = { + "INPUT": input_data, + "WEIGHT": weight, + "BIAS": bias, + "MASK_INDEX": mask_index, + } outputs = {"return_val": output_data} return inputs, outputs diff --git a/onnxruntime/python/tools/microbench/benchmark.py b/onnxruntime/python/tools/microbench/benchmark.py index 86fa98e153146..16ec187136bbb 100644 --- a/onnxruntime/python/tools/microbench/benchmark.py +++ b/onnxruntime/python/tools/microbench/benchmark.py @@ -1,44 +1,67 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- +import logging +import time from abc import ABC, abstractmethod from argparse import ArgumentParser -import logging + import numpy -import onnxruntime as ort -import time import torch +import onnxruntime as ort + logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def numpy_type(torch_type): - type_map = {torch.float32: numpy.float32, - torch.float16: numpy.float16, - torch.int32: numpy.int32} + type_map = { + torch.float32: numpy.float32, + torch.float16: numpy.float16, + torch.int32: numpy.int32, + } return type_map[torch_type] def add_arguments(parser: ArgumentParser): - parser.add_argument("--provider", required=False, type=str, - choices=["cuda", "rocm", "cpu", None], default=None, - help=("Execution provider to use. By default, a " - "provider is selected in the priority order " - "(cuda|rocm, cpu) depending on availability.")) - parser.add_argument("--precision", required=False, type=str, - choices=["fp16", "fp32"], default="fp16", - help="Number format to use") - parser.add_argument('--profiling', required=False, type=bool, - default=False, help='If enable profiling') + parser.add_argument( + "--provider", + required=False, + type=str, + choices=["cuda", "rocm", "cpu", None], + default=None, + help=( + "Execution provider to use. By default, a " + "provider is selected in the priority order " + "(cuda|rocm, cpu) depending on availability." + ), + ) + parser.add_argument( + "--precision", + required=False, + type=str, + choices=["fp16", "fp32"], + default="fp16", + help="Number format to use", + ) + parser.add_argument( + "--profiling", + required=False, + type=bool, + default=False, + help="If enable profiling", + ) def provider_name(name): - provider_map = {"cuda": "CUDAExecutionProvider", - "rocm": "ROCMExecutionProvider", - "cpu": "CPUExecutionProvider"} + provider_map = { + "cuda": "CUDAExecutionProvider", + "rocm": "ROCMExecutionProvider", + "cpu": "CPUExecutionProvider", + } return provider_map[name] @@ -52,8 +75,7 @@ def get_default_provider(): class Benchmark: def __init__(self, model, inputs, outputs, args): - self.provider = (get_default_provider() if args.provider == None - else provider_name(args.provider)) + self.provider = get_default_provider() if args.provider == None else provider_name(args.provider) logger.info(f"Execution provider: {self.provider}") self.profiling = args.profiling self.model = model @@ -62,43 +84,49 @@ def __init__(self, model, inputs, outputs, args): self.outputs = outputs def create_input_output_tensors(self): - on_gpu = (self.provider == "CUDAExecutionProvider" - or self.provider == "ROCMExecutionProvider") + on_gpu = self.provider == "CUDAExecutionProvider" or self.provider == "ROCMExecutionProvider" device = "cuda" if on_gpu else "cpu" - input_tensors = {name: torch.from_numpy(array).to(device) - for name, array in self.inputs.items()} - output_tensors = {name: torch.from_numpy(array).to(device) - for name, array in self.outputs.items()} + input_tensors = {name: torch.from_numpy(array).to(device) for name, array in self.inputs.items()} + output_tensors = {name: torch.from_numpy(array).to(device) for name, array in self.outputs.items()} return input_tensors, output_tensors @classmethod def create_io_binding(cls, sess, input_tensors, output_tensors): io_binding = sess.io_binding() for name, tensor in input_tensors.items(): - io_binding.bind_input(name, tensor.device.type, 0, - numpy_type(tensor.dtype), tensor.shape, - tensor.data_ptr()) + io_binding.bind_input( + name, + tensor.device.type, + 0, + numpy_type(tensor.dtype), + tensor.shape, + tensor.data_ptr(), + ) for name, tensor in output_tensors.items(): - io_binding.bind_output(name, tensor.device.type, 0, - numpy_type(tensor.dtype), tensor.shape, - tensor.data_ptr()) + io_binding.bind_output( + name, + tensor.device.type, + 0, + numpy_type(tensor.dtype), + tensor.shape, + tensor.data_ptr(), + ) return io_binding def create_session(self): sess_opt = ort.SessionOptions() sess_opt.enable_profiling = self.profiling - sess = ort.InferenceSession(self.model, sess_options=sess_opt, - providers=[self.provider]) + sess = ort.InferenceSession(self.model, sess_options=sess_opt, providers=[self.provider]) return sess def benchmark(self): sess = self.create_session() input_tensors, output_tensors = self.create_input_output_tensors() io_binding = self.create_io_binding(sess, input_tensors, output_tensors) - + # warm up for iter in range(10): - sess.run_with_iobinding(io_binding) + sess.run_with_iobinding(io_binding) # measure max_iters = 100 diff --git a/onnxruntime/python/tools/microbench/cast.py b/onnxruntime/python/tools/microbench/cast.py index d6ae83a236c85..86219a99ac5df 100644 --- a/onnxruntime/python/tools/microbench/cast.py +++ b/onnxruntime/python/tools/microbench/cast.py @@ -1,28 +1,29 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import argparse from dataclasses import dataclass + import numpy as np from benchmark import BenchmarkOp, add_arguments @dataclass class OpParam: - x : int - y : int - m : int - n : int - input_data_type : type - output_data_type : type + x: int + y: int + m: int + n: int + input_data_type: type + output_data_type: type @dataclass class ModelParam: - token_type_ids_dim0 : int - input_ids_dim1 : int + token_type_ids_dim0: int + input_ids_dim1: int class BenchmarkCast(BenchmarkOp): @@ -38,9 +39,39 @@ def create_inputs_outputs(cls, op_param): return inputs, outputs def add_model_cases(self, mp, model, input_data_type, output_data_type): - self.add_case(OpParam(1, mp.token_type_ids_dim0, mp.input_ids_dim1, 1024, input_data_type, output_data_type), model) - self.add_case(OpParam(1, mp.token_type_ids_dim0, mp.input_ids_dim1, 1, input_data_type, output_data_type), model) - self.add_case(OpParam(16, mp.token_type_ids_dim0, mp.input_ids_dim1, mp.input_ids_dim1, input_data_type, output_data_type), model) + self.add_case( + OpParam( + 1, + mp.token_type_ids_dim0, + mp.input_ids_dim1, + 1024, + input_data_type, + output_data_type, + ), + model, + ) + self.add_case( + OpParam( + 1, + mp.token_type_ids_dim0, + mp.input_ids_dim1, + 1, + input_data_type, + output_data_type, + ), + model, + ) + self.add_case( + OpParam( + 16, + mp.token_type_ids_dim0, + mp.input_ids_dim1, + mp.input_ids_dim1, + input_data_type, + output_data_type, + ), + model, + ) def create_cases(self): model = "models/cast_fp16tofp32.onnx" if self.args.precision == "fp16" else "models/cast_fp32tofp16.onnx" @@ -61,7 +92,7 @@ def create_cases(self): def case_profile(cls, op_param, time): profile = f"(x y m n input_data_type) = ({op_param.x} {op_param.y} {op_param.m} {op_param.n} {op_param.input_data_type}), {time:7.4f} ms" return profile - + def main(): parser = argparse.ArgumentParser() diff --git a/onnxruntime/python/tools/microbench/fast_gelu.py b/onnxruntime/python/tools/microbench/fast_gelu.py index 2d50e256a0642..82f86020a6afd 100644 --- a/onnxruntime/python/tools/microbench/fast_gelu.py +++ b/onnxruntime/python/tools/microbench/fast_gelu.py @@ -1,10 +1,11 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import argparse from dataclasses import dataclass + import numpy as np from benchmark import BenchmarkOp, add_arguments @@ -43,7 +44,12 @@ def create_cases(self): data_type = np.float16 if self.args.precision == "fp16" else np.float32 # bert-large model_param = ModelParam(1, 384, 1024 * 4, data_type) - op_param = OpParam(model_param.batch_size, model_param.seq_len, model_param.inter_dim, model_param.data_type) + op_param = OpParam( + model_param.batch_size, + model_param.seq_len, + model_param.inter_dim, + model_param.data_type, + ) self.add_case(op_param, model) def case_profile(cls, op_param, time): diff --git a/onnxruntime/python/tools/microbench/matmul.py b/onnxruntime/python/tools/microbench/matmul.py index 1de45ee5c75b3..cdac59cbbf7e2 100644 --- a/onnxruntime/python/tools/microbench/matmul.py +++ b/onnxruntime/python/tools/microbench/matmul.py @@ -1,10 +1,11 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import argparse from dataclasses import dataclass + import numpy as np from benchmark import BenchmarkOp, add_arguments @@ -43,10 +44,36 @@ def create_inputs_outputs(cls, op_param): return inputs, outputs def add_model_cases(self, mp, model): - self.add_case(OpParam(1, mp.batch_size, mp.seq_len, mp.hidden_size, mp.hidden_size, mp.data_type), model) - self.add_case(OpParam(1, mp.batch_size, mp.seq_len, mp.inter_dim, mp.hidden_size, mp.data_type), model) - self.add_case(OpParam(1, mp.batch_size, mp.seq_len, mp.hidden_size, mp.inter_dim, mp.data_type), model) - self.add_case(OpParam(mp.batch_size, mp.num_heads, mp.seq_len, mp.seq_len, int(mp.hidden_size / mp.num_heads), mp.data_type), model) + self.add_case( + OpParam( + 1, + mp.batch_size, + mp.seq_len, + mp.hidden_size, + mp.hidden_size, + mp.data_type, + ), + model, + ) + self.add_case( + OpParam(1, mp.batch_size, mp.seq_len, mp.inter_dim, mp.hidden_size, mp.data_type), + model, + ) + self.add_case( + OpParam(1, mp.batch_size, mp.seq_len, mp.hidden_size, mp.inter_dim, mp.data_type), + model, + ) + self.add_case( + OpParam( + mp.batch_size, + mp.num_heads, + mp.seq_len, + mp.seq_len, + int(mp.hidden_size / mp.num_heads), + mp.data_type, + ), + model, + ) def create_cases(self): model = "models/matmul_fp16.onnx" if self.args.precision == "fp16" else "models/matmul_fp32.onnx" diff --git a/onnxruntime/python/tools/microbench/skip_layer_norm.py b/onnxruntime/python/tools/microbench/skip_layer_norm.py index b6f8c5f9e15e0..dbfda7ef301af 100644 --- a/onnxruntime/python/tools/microbench/skip_layer_norm.py +++ b/onnxruntime/python/tools/microbench/skip_layer_norm.py @@ -1,10 +1,11 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import argparse from dataclasses import dataclass + import numpy as np from benchmark import BenchmarkOp, add_arguments @@ -23,20 +24,32 @@ def __init__(self, args): def create_inputs_outputs(cls, op_param): np.random.seed(0) - input_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype(op_param.data_type) + input_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype( + op_param.data_type + ) skip = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype(op_param.data_type) gamma = np.random.rand(op_param.hidden_size).astype(op_param.data_type) beta = np.random.rand(op_param.hidden_size).astype(op_param.data_type) bias = np.random.rand(op_param.hidden_size).astype(op_param.data_type) - output_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype(op_param.data_type) - - inputs = {"INPUT": input_data, "SKIP": skip, "GAMMA": gamma, "BETA": beta, "BIAS": bias} + output_data = np.random.rand(op_param.batch_size, op_param.seq_len, op_param.hidden_size).astype( + op_param.data_type + ) + + inputs = { + "INPUT": input_data, + "SKIP": skip, + "GAMMA": gamma, + "BETA": beta, + "BIAS": bias, + } outputs = {"return_val": output_data} - + return inputs, outputs def create_cases(self): - model = "models/skip_layer_norm_fp16.onnx" if self.args.precision == "fp16" else "models/skip_layer_norm_fp32.onnx" + model = ( + "models/skip_layer_norm_fp16.onnx" if self.args.precision == "fp16" else "models/skip_layer_norm_fp32.onnx" + ) data_type = np.float16 if self.args.precision == "fp16" else np.float32 # bert-large op_param = OpParam(1, 384, 1024, data_type) diff --git a/onnxruntime/python/tools/onnx_randomizer.py b/onnxruntime/python/tools/onnx_randomizer.py index f4b736fa0a5be..08c39b9a72976 100644 --- a/onnxruntime/python/tools/onnx_randomizer.py +++ b/onnxruntime/python/tools/onnx_randomizer.py @@ -1,21 +1,23 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # An offline standalone script to declassify an ONNX model by randomizing the tensor data in initializers. # The ORT Performance may change especially on generative models. import argparse -import numpy as np -from onnx import onnx_pb, numpy_helper, save_model, load_model from pathlib import Path +import numpy as np +from onnx import load_model, numpy_helper, onnx_pb, save_model + # An experimental small value for differentiating shape data and weights. # The tensor data with larger size can't be shape data. # User may adjust this value as needed. SIZE_THRESHOLD = 10 + def graph_iterator(model, func): graph_queue = [model.graph] while graph_queue: @@ -24,11 +26,11 @@ def graph_iterator(model, func): for node in graph.node: for attr in node.attribute: if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPH: - assert (isinstance(attr.g, onnx_pb.GraphProto)) + assert isinstance(attr.g, onnx_pb.GraphProto) graph_queue.append(attr.g) if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPHS: for g in attr.graphs: - assert (isinstance(g, onnx_pb.GraphProto)) + assert isinstance(g, onnx_pb.GraphProto) graph_queue.append(g) @@ -37,54 +39,47 @@ def randomize_graph_initializer(graph): array = numpy_helper.to_array(i_tensor) # TODO: need to find a better way to differentiate shape data and weights. if array.size > SIZE_THRESHOLD: - random_array = np.random.uniform(array.min(), - array.max(), - size=array.shape).astype( - array.dtype) + random_array = np.random.uniform(array.min(), array.max(), size=array.shape).astype(array.dtype) o_tensor = numpy_helper.from_array(random_array, i_tensor.name) i_tensor.CopyFrom(o_tensor) def main(): - parser = argparse.ArgumentParser( - description='Randomize the weights of an ONNX model') - parser.add_argument('-m', - type=str, - required=True, - help='input onnx model path') - parser.add_argument('-o', - type=str, - required=True, - help='output onnx model path') - parser.add_argument("--use_external_data_format", - required=False, - action="store_true", - help="Store or Save in external data format") - parser.add_argument("--all_tensors_to_one_file", - required=False, - action="store_true", - help="Save all tensors to one file") + parser = argparse.ArgumentParser(description="Randomize the weights of an ONNX model") + parser.add_argument("-m", type=str, required=True, help="input onnx model path") + parser.add_argument("-o", type=str, required=True, help="output onnx model path") + parser.add_argument( + "--use_external_data_format", + required=False, + action="store_true", + help="Store or Save in external data format", + ) + parser.add_argument( + "--all_tensors_to_one_file", + required=False, + action="store_true", + help="Save all tensors to one file", + ) args = parser.parse_args() data_path = None if args.use_external_data_format: if Path(args.m).parent == Path(args.o).parent: - raise RuntimeError( - "Please specify output directory with different parent path to input directory." - ) + raise RuntimeError("Please specify output directory with different parent path to input directory.") if args.all_tensors_to_one_file: data_path = Path(args.o).name + ".data" Path(args.o).parent.mkdir(parents=True, exist_ok=True) - onnx_model = load_model(args.m, - load_external_data=args.use_external_data_format) + onnx_model = load_model(args.m, load_external_data=args.use_external_data_format) graph_iterator(onnx_model, randomize_graph_initializer) - save_model(onnx_model, - args.o, - save_as_external_data=args.use_external_data_format, - all_tensors_to_one_file=args.all_tensors_to_one_file, - location=data_path) + save_model( + onnx_model, + args.o, + save_as_external_data=args.use_external_data_format, + all_tensors_to_one_file=args.all_tensors_to_one_file, + location=data_path, + ) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/onnxruntime_test.py b/onnxruntime/python/tools/onnxruntime_test.py index 0d4cc22be3d96..11759f3ad17d5 100644 --- a/onnxruntime/python/tools/onnxruntime_test.py +++ b/onnxruntime/python/tools/onnxruntime_test.py @@ -1,27 +1,34 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import argparse -import onnxruntime as onnxrt -import numpy as np import os import sys from timeit import default_timer as timer -float_dict = {'tensor(float16)': 'float16', 'tensor(float)': 'float32', 'tensor(double)': 'float64'} +import numpy as np + +import onnxruntime as onnxrt + +float_dict = { + "tensor(float16)": "float16", + "tensor(float)": "float32", + "tensor(double)": "float64", +} integer_dict = { - 'tensor(int32)': 'int32', - 'tensor(int8)': 'int8', - 'tensor(uint8)': 'uint8', - 'tensor(int16)': 'int16', - 'tensor(uint16)': 'uint16', - 'tensor(int64)': 'int64', - 'tensor(uint64)': 'uint64' + "tensor(int32)": "int32", + "tensor(int8)": "int8", + "tensor(uint8)": "uint8", + "tensor(int16)": "int16", + "tensor(uint16)": "uint16", + "tensor(int64)": "int64", + "tensor(uint64)": "uint64", } + def generate_feeds(sess, symbolic_dims={}): feeds = {} for input_meta in sess.get_inputs(): @@ -43,23 +50,27 @@ def generate_feeds(sess, symbolic_dims={}): if input_meta.type in float_dict: feeds[input_meta.name] = np.random.rand(*shape).astype(float_dict[input_meta.type]) elif input_meta.type in integer_dict: - feeds[input_meta.name] = np.random.uniform(high=1000, - size=tuple(shape)).astype(integer_dict[input_meta.type]) - elif input_meta.type == 'tensor(bool)': - feeds[input_meta.name] = np.random.randint(2, size=tuple(shape)).astype('bool') + feeds[input_meta.name] = np.random.uniform(high=1000, size=tuple(shape)).astype( + integer_dict[input_meta.type] + ) + elif input_meta.type == "tensor(bool)": + feeds[input_meta.name] = np.random.randint(2, size=tuple(shape)).astype("bool") else: print("unsupported input type {} for input {}".format(input_meta.type, input_meta.name)) sys.exit(-1) return feeds + # simple test program for loading onnx model, feeding all inputs and running the model num_iters times. -def run_model(model_path, - num_iters=1, - debug=None, - profile=None, - symbolic_dims={}, - feeds=None, - override_initializers=True): +def run_model( + model_path, + num_iters=1, + debug=None, + profile=None, + symbolic_dims={}, + feeds=None, + override_initializers=True, +): if debug: print("Pausing execution ready for debugger to attach to pid: {}".format(os.getpid())) print("Press key to continue.") @@ -71,7 +82,11 @@ def run_model(model_path, sess_options.enable_profiling = True sess_options.profile_file_prefix = os.path.basename(model_path) - sess = onnxrt.InferenceSession(model_path, sess_options=sess_options, providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + model_path, + sess_options=sess_options, + providers=onnxrt.get_available_providers(), + ) meta = sess.get_modelmeta() if not feeds: @@ -86,10 +101,11 @@ def run_model(model_path, if initializer.type in float_dict: feeds[initializer.name] = np.random.rand(*shape).astype(float_dict[initializer.type]) elif initializer.type in integer_dict: - feeds[initializer.name] = np.random.uniform(high=1000, - size=tuple(shape)).astype(integer_dict[initializer.type]) - elif initializer.type == 'tensor(bool)': - feeds[initializer.name] = np.random.randint(2, size=tuple(shape)).astype('bool') + feeds[initializer.name] = np.random.uniform(high=1000, size=tuple(shape)).astype( + integer_dict[initializer.type] + ) + elif initializer.type == "tensor(bool)": + feeds[initializer.name] = np.random.randint(2, size=tuple(shape)).astype("bool") else: print("unsupported initializer type {} for initializer {}".format(initializer.type, initializer.name)) sys.exit(-1) @@ -112,15 +128,29 @@ def run_model(model_path, if __name__ == "__main__": - parser = argparse.ArgumentParser(description='Simple ONNX Runtime Test Tool.') - parser.add_argument('model_path', help='model path') - parser.add_argument('num_iters', nargs='?', type=int, default=1000, help='model run iterations. default=1000') - parser.add_argument('--debug', action='store_true', help='pause execution to allow attaching a debugger.') - parser.add_argument('--profile', action='store_true', help='enable chrome timeline trace profiling.') - parser.add_argument('--symbolic_dims', default={}, type=lambda s: dict(x.split("=") for x in s.split(",")), - help='Comma separated name=value pairs for any symbolic dimensions in the model input. ' - 'e.g. --symbolic_dims batch=1,seqlen=5. ' - 'If not provided, the value of 1 will be used for all symbolic dimensions.') + parser = argparse.ArgumentParser(description="Simple ONNX Runtime Test Tool.") + parser.add_argument("model_path", help="model path") + parser.add_argument( + "num_iters", + nargs="?", + type=int, + default=1000, + help="model run iterations. default=1000", + ) + parser.add_argument( + "--debug", + action="store_true", + help="pause execution to allow attaching a debugger.", + ) + parser.add_argument("--profile", action="store_true", help="enable chrome timeline trace profiling.") + parser.add_argument( + "--symbolic_dims", + default={}, + type=lambda s: dict(x.split("=") for x in s.split(",")), + help="Comma separated name=value pairs for any symbolic dimensions in the model input. " + "e.g. --symbolic_dims batch=1,seqlen=5. " + "If not provided, the value of 1 will be used for all symbolic dimensions.", + ) args = parser.parse_args() exit_code, _, _ = run_model(args.model_path, args.num_iters, args.debug, args.profile, args.symbolic_dims) diff --git a/onnxruntime/python/tools/pytorch_export_contrib_ops.py b/onnxruntime/python/tools/pytorch_export_contrib_ops.py index e3ac06c47d23a..744e90c032eb1 100644 --- a/onnxruntime/python/tools/pytorch_export_contrib_ops.py +++ b/onnxruntime/python/tools/pytorch_export_contrib_ops.py @@ -12,8 +12,8 @@ from torch.onnx import register_custom_op_symbolic except ModuleNotFoundError: raise ModuleNotFoundError( - "This module is only useful in combination with PyTorch. " - "To install PyTorch see https://pytorch.org/.") + "This module is only useful in combination with PyTorch. To install PyTorch see https://pytorch.org/." + ) import torch.onnx.symbolic_helper as sym_help import torch.onnx.symbolic_registry as sym_registry @@ -44,8 +44,8 @@ def grid_sampler(g, input, grid, mode, padding_mode, align_corners): # 'reflection' : onnx::Constant[value={2}] mode = sym_help._maybe_get_const(mode, "i") padding_mode = sym_help._maybe_get_const(padding_mode, "i") - mode_str = ['bilinear', 'nearest', 'bicubic'][mode] - padding_mode_str = ['zeros', 'border', 'reflection'][padding_mode] + mode_str = ["bilinear", "nearest", "bicubic"][mode] + padding_mode_str = ["zeros", "border", "reflection"][padding_mode] align_corners = int(sym_help._maybe_get_const(align_corners, "b")) # From opset v13 onward, the output shape can be specified with @@ -55,28 +55,36 @@ def grid_sampler(g, input, grid, mode, padding_mode, align_corners): # output_shape = input_shape[:2] + gird_shape[1:3] # g.op(...).setType(input.type().with_sizes(output_shape)) - return g.op("com.microsoft::GridSample", input, grid, - mode_s=mode_str, - padding_mode_s=padding_mode_str, - align_corners_i=align_corners) + return g.op( + "com.microsoft::GridSample", + input, + grid, + mode_s=mode_str, + padding_mode_s=padding_mode_str, + align_corners_i=align_corners, + ) + _reg(grid_sampler) def inverse(g, self): return g.op("com.microsoft::Inverse", self).setType(self.type()) + _reg(inverse) def gelu(g, self): return g.op("com.microsoft::Gelu", self).setType(self.type()) + _reg(gelu) def triu(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=1).setType(self.type()) + _reg(triu) def tril(g, self, diagonal): return g.op("com.microsoft::Trilu", self, diagonal, upper_i=0).setType(self.type()) - _reg(tril) + _reg(tril) def unregister(): @@ -86,6 +94,5 @@ def unregister(): for name in _registered_ops: ns, kind = name.split("::") for version in sym_help._onnx_stable_opsets: - if (version >= _OPSET_VERSION and - sym_registry.is_registered_op(kind, ns, version)): + if version >= _OPSET_VERSION and sym_registry.is_registered_op(kind, ns, version): del sym_registry._registry[(ns, version)][kind] diff --git a/onnxruntime/python/tools/quantization/CalTableFlatBuffers/KeyValue.py b/onnxruntime/python/tools/quantization/CalTableFlatBuffers/KeyValue.py index fc5d725db0fdd..ba846b17eecdc 100644 --- a/onnxruntime/python/tools/quantization/CalTableFlatBuffers/KeyValue.py +++ b/onnxruntime/python/tools/quantization/CalTableFlatBuffers/KeyValue.py @@ -4,10 +4,12 @@ import flatbuffers from flatbuffers.compat import import_numpy + np = import_numpy() + class KeyValue(object): - __slots__ = ['_tab'] + __slots__ = ["_tab"] @classmethod def GetRootAs(cls, buf, offset=0): @@ -20,6 +22,7 @@ def GetRootAs(cls, buf, offset=0): def GetRootAsKeyValue(cls, buf, offset=0): """This method is deprecated. Please switch to GetRootAs.""" return cls.GetRootAs(buf, offset) + # KeyValue def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) @@ -38,19 +41,38 @@ def Value(self): return self._tab.String(o + self._tab.Pos) return None -def Start(builder): builder.StartObject(2) + +def Start(builder): + builder.StartObject(2) + + def KeyValueStart(builder): """This method is deprecated. Please switch to Start.""" return Start(builder) -def AddKey(builder, key): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0) + + +def AddKey(builder, key): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(key), 0) + + def KeyValueAddKey(builder, key): """This method is deprecated. Please switch to AddKey.""" return AddKey(builder, key) -def AddValue(builder, value): builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0) + + +def AddValue(builder, value): + builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(value), 0) + + def KeyValueAddValue(builder, value): """This method is deprecated. Please switch to AddValue.""" return AddValue(builder, value) -def End(builder): return builder.EndObject() + + +def End(builder): + return builder.EndObject() + + def KeyValueEnd(builder): """This method is deprecated. Please switch to End.""" - return End(builder) \ No newline at end of file + return End(builder) diff --git a/onnxruntime/python/tools/quantization/CalTableFlatBuffers/TrtTable.py b/onnxruntime/python/tools/quantization/CalTableFlatBuffers/TrtTable.py index b31f4cac1d4c4..cf5202ee3b359 100644 --- a/onnxruntime/python/tools/quantization/CalTableFlatBuffers/TrtTable.py +++ b/onnxruntime/python/tools/quantization/CalTableFlatBuffers/TrtTable.py @@ -4,10 +4,12 @@ import flatbuffers from flatbuffers.compat import import_numpy + np = import_numpy() + class TrtTable(object): - __slots__ = ['_tab'] + __slots__ = ["_tab"] @classmethod def GetRootAs(cls, buf, offset=0): @@ -20,6 +22,7 @@ def GetRootAs(cls, buf, offset=0): def GetRootAsTrtTable(cls, buf, offset=0): """This method is deprecated. Please switch to GetRootAs.""" return cls.GetRootAs(buf, offset) + # TrtTable def Init(self, buf, pos): self._tab = flatbuffers.table.Table(buf, pos) @@ -32,6 +35,7 @@ def Dict(self, j): x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4 x = self._tab.Indirect(x) from onnxruntime.quantization.CalTableFlatBuffers.KeyValue import KeyValue + obj = KeyValue() obj.Init(self._tab.Bytes, x) return obj @@ -49,19 +53,38 @@ def DictIsNone(self): o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4)) return o == 0 -def Start(builder): builder.StartObject(1) + +def Start(builder): + builder.StartObject(1) + + def TrtTableStart(builder): """This method is deprecated. Please switch to Start.""" return Start(builder) -def AddDict(builder, dict): builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dict), 0) + + +def AddDict(builder, dict): + builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(dict), 0) + + def TrtTableAddDict(builder, dict): """This method is deprecated. Please switch to AddDict.""" return AddDict(builder, dict) -def StartDictVector(builder, numElems): return builder.StartVector(4, numElems, 4) + + +def StartDictVector(builder, numElems): + return builder.StartVector(4, numElems, 4) + + def TrtTableStartDictVector(builder, numElems): """This method is deprecated. Please switch to Start.""" return StartDictVector(builder, numElems) -def End(builder): return builder.EndObject() + + +def End(builder): + return builder.EndObject() + + def TrtTableEnd(builder): """This method is deprecated. Please switch to End.""" - return End(builder) \ No newline at end of file + return End(builder) diff --git a/onnxruntime/python/tools/quantization/__init__.py b/onnxruntime/python/tools/quantization/__init__.py index 0559b6c5f9e1e..c3a7b7e9abfa6 100644 --- a/onnxruntime/python/tools/quantization/__init__.py +++ b/onnxruntime/python/tools/quantization/__init__.py @@ -1,5 +1,4 @@ -from .quantize import quantize_static, quantize_dynamic -from .quantize import QuantizationMode -from .calibrate import CalibrationDataReader, CalibraterBase, MinMaxCalibrater, create_calibrator, CalibrationMethod -from .quant_utils import QuantType, QuantFormat, write_calibration_table +from .calibrate import CalibraterBase, CalibrationDataReader, CalibrationMethod, MinMaxCalibrater, create_calibrator from .qdq_quantizer import QDQQuantizer +from .quant_utils import QuantFormat, QuantType, write_calibration_table +from .quantize import QuantizationMode, quantize_dynamic, quantize_static diff --git a/onnxruntime/python/tools/quantization/calibrate.py b/onnxruntime/python/tools/quantization/calibrate.py index 9d2f1ec3e5c3d..894fbaa97a9da 100644 --- a/onnxruntime/python/tools/quantization/calibrate.py +++ b/onnxruntime/python/tools/quantization/calibrate.py @@ -7,27 +7,37 @@ # -------------------------------------------------------------------------- import abc import itertools -import numpy as np -import onnxruntime +from enum import Enum +from pathlib import Path +import numpy as np import onnx -from onnx import helper, TensorProto, ModelProto +from onnx import ModelProto, TensorProto, helper from onnx import onnx_pb as onnx_proto -from enum import Enum -from pathlib import Path -from .quant_utils import QuantType, model_has_infer_metadata, smooth_distribution, apply_plot, load_model, clone_model_with_shape_infer +import onnxruntime + +from .quant_utils import ( + QuantType, + apply_plot, + clone_model_with_shape_infer, + load_model, + model_has_infer_metadata, + smooth_distribution, +) from .registry import QLinearOpsRegistry + class CalibrationMethod(Enum): MinMax = 0 Entropy = 1 Percentile = 2 + class CalibrationDataReader(metaclass=abc.ABCMeta): @classmethod def __subclasshook__(cls, subclass): - return (hasattr(subclass, 'get_next') and callable(subclass.get_next) or NotImplemented) + return hasattr(subclass, "get_next") and callable(subclass.get_next) or NotImplemented @abc.abstractmethod def get_next(self) -> dict: @@ -36,14 +46,21 @@ def get_next(self) -> dict: class CalibraterBase: - def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmented_model.onnx', symmetric=False, use_external_data_format=False): - ''' + def __init__( + self, + model, + op_types_to_calibrate=[], + augmented_model_path="augmented_model.onnx", + symmetric=False, + use_external_data_format=False, + ): + """ :param model: ONNX model to calibrate. It can be a ModelProto or a model path :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors. :param augmented_model_path: save augmented model to this path. :param symmetric: make range of tensor symmetric (central point is 0). :param use_external_data_format: use external data format to store model which size is >= 2Gb - ''' + """ if isinstance(model, str): self.model = load_model(Path(model), False) elif isinstance(model, Path): @@ -51,7 +68,7 @@ def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmen elif isinstance(model, ModelProto): self.model = model else: - raise ValueError('model should be either model path or onnx.ModelProto.') + raise ValueError("model should be either model path or onnx.ModelProto.") self.op_types_to_calibrate = op_types_to_calibrate self.augmented_model_path = augmented_model_path @@ -64,33 +81,35 @@ def __init__(self, model, op_types_to_calibrate=[], augmented_model_path='augmen # Create InferenceSession self.infer_session = None - self.execution_providers = ['CPUExecutionProvider'] + self.execution_providers = ["CPUExecutionProvider"] self._create_inference_session() - def set_execution_providers(self, execution_providers=['CPUExecutionProvider']): - ''' + def set_execution_providers(self, execution_providers=["CPUExecutionProvider"]): + """ reset the execution providers to execute the collect_data. It triggers to re-creating inference session. - ''' + """ self.execution_providers = execution_providers self._create_inference_session() def _create_inference_session(self): - ''' + """ create an OnnxRuntime InferenceSession. - ''' + """ sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - self.infer_session = onnxruntime.InferenceSession(self.augmented_model_path, - sess_options=sess_options, - providers=self.execution_providers) + self.infer_session = onnxruntime.InferenceSession( + self.augmented_model_path, + sess_options=sess_options, + providers=self.execution_providers, + ) def select_tensors_to_calibrate(self, model): - ''' - select all quantization_candidates op type nodes' input/output tensors. + """ + select all quantization_candidates op type nodes' input/output tensors. returns: tensors (set): set of tensor name. value_infos (dict): tensor name to value info. - ''' + """ value_infos = {vi.name: vi for vi in model.graph.value_info} value_infos.update({ot.name: ot for ot in model.graph.output}) value_infos.update({it.name: it for it in model.graph.input}) @@ -104,50 +123,54 @@ def select_tensors_to_calibrate(self, model): for tensor_name in itertools.chain(node.input, node.output): if tensor_name in value_infos.keys(): vi = value_infos[tensor_name] - if vi.type.HasField('tensor_type') and ( - vi.type.tensor_type.elem_type in tensor_type_to_calibrate) and ( - tensor_name not in initializer): + if ( + vi.type.HasField("tensor_type") + and (vi.type.tensor_type.elem_type in tensor_type_to_calibrate) + and (tensor_name not in initializer) + ): tensors_to_calibrate.add(tensor_name) return tensors_to_calibrate, value_infos def get_augment_model(self): - ''' + """ return: augmented onnx model - ''' + """ return self.augment_model def augment_graph(self): - ''' + """ abstract method: augment the input model to prepare for collecting data. It will: 1. save augmented model to augmented_model_path. 2. set the self.augment_model - ''' + """ raise NotImplementedError def collect_data(self, data_reader: CalibrationDataReader): - ''' + """ abstract method: collect the tensors that will be used for range computation. It can be called multiple times. - ''' + """ raise NotImplementedError def compute_range(self, data_reader: CalibrationDataReader): - ''' + """ abstract method: compute the [min, max] range for the tensors to calibrate based on the collected data. - ''' + """ raise NotImplementedError class MinMaxCalibrater(CalibraterBase): - def __init__(self, - model, - op_types_to_calibrate=[], - augmented_model_path='augmented_model.onnx', - symmetric=False, - use_external_data_format=False, - moving_average=False, - averaging_constant=0.01): - ''' + def __init__( + self, + model, + op_types_to_calibrate=[], + augmented_model_path="augmented_model.onnx", + symmetric=False, + use_external_data_format=False, + moving_average=False, + averaging_constant=0.01, + ): + """ :param model: ONNX model to calibrate. It can be a ModelProto or a model path :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors. :param augmented_model_path: save augmented model to this path. @@ -155,8 +178,14 @@ def __init__(self, :param use_external_data_format: use external data format to store model which size is >= 2Gb :param moving_average: compute the moving average of the minimum and maximum values instead of the global minimum and maximum. :param averaging_constant: constant smoothing factor to use when computing the moving average. - ''' - super(MinMaxCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, symmetric, use_external_data_format) + """ + super(MinMaxCalibrater, self).__init__( + model, + op_types_to_calibrate, + augmented_model_path, + symmetric, + use_external_data_format, + ) self.intermediate_outputs = [] self.calibrate_tensors_range = None self.num_model_outputs = len(self.model.graph.output) @@ -167,16 +196,16 @@ def __init__(self, self.averaging_constant = averaging_constant def augment_graph(self): - ''' + """ Adds ReduceMin and ReduceMax nodes to all quantization_candidates op type nodes in model and ensures their outputs are stored as part of the graph output :return: augmented ONNX model - ''' + """ model = clone_model_with_shape_infer(self.model) added_nodes = [] added_outputs = [] - tensors, value_infos = self.select_tensors_to_calibrate(model) + tensors, value_infos = self.select_tensors_to_calibrate(model) for tensor in tensors: @@ -193,22 +222,38 @@ def augment_graph(self): shape = (1,) if len(dim) == 1 else tuple(1 for i in range(len(dim))) # Adding ReduceMin nodes - reduce_min_name = tensor + '_ReduceMin' - reduce_min_node = onnx.helper.make_node('ReduceMin', [tensor], [tensor + '_ReduceMin'], reduce_min_name, keepdims=keepdims) + reduce_min_name = tensor + "_ReduceMin" + reduce_min_node = onnx.helper.make_node( + "ReduceMin", + [tensor], + [tensor + "_ReduceMin"], + reduce_min_name, + keepdims=keepdims, + ) added_nodes.append(reduce_min_node) added_outputs.append(helper.make_tensor_value_info(reduce_min_node.output[0], TensorProto.FLOAT, shape)) # Adding ReduceMax nodes - reduce_max_name = tensor + '_ReduceMax' - reduce_max_node = onnx.helper.make_node('ReduceMax', [tensor], [tensor + '_ReduceMax'], reduce_max_name, keepdims=keepdims) + reduce_max_name = tensor + "_ReduceMax" + reduce_max_node = onnx.helper.make_node( + "ReduceMax", + [tensor], + [tensor + "_ReduceMax"], + reduce_max_name, + keepdims=keepdims, + ) added_nodes.append(reduce_max_node) added_outputs.append(helper.make_tensor_value_info(reduce_max_node.output[0], TensorProto.FLOAT, shape)) model.graph.node.extend(added_nodes) model.graph.output.extend(added_outputs) - onnx.save(model, self.augmented_model_path, save_as_external_data=self.use_external_data_format) + onnx.save( + model, + self.augmented_model_path, + save_as_external_data=self.use_external_data_format, + ) self.augment_model = model def clear_collected_data(self): @@ -231,7 +276,7 @@ def merge_range(self, old_range, new_range): if not old_range: return new_range - for key, value in old_range.items(): + for key, value in old_range.items(): if self.moving_average: min_value = value[0] + self.averaging_constant * (new_range[key][0] - value[0]) max_value = value[1] + self.averaging_constant * (new_range[key][1] - value[1]) @@ -243,10 +288,10 @@ def merge_range(self, old_range, new_range): return new_range def compute_range(self): - ''' + """ Compute the min-max range of tensor :return: dictionary mapping: {added node names: (ReduceMin, ReduceMax) pairs } - ''' + """ if len(self.intermediate_outputs) == 0: return self.calibrate_tensors_range @@ -260,21 +305,22 @@ def compute_range(self): for d in output_dicts_list: for k, v in d.items(): merged_output_dict.setdefault(k, []).append(v) - added_output_names = output_names[self.num_model_outputs:] + added_output_names = output_names[self.num_model_outputs :] calibrate_tensor_names = [ - added_output_names[i].rpartition('_')[0] for i in range(0, len(added_output_names), 2) - ] #output names + added_output_names[i].rpartition("_")[0] for i in range(0, len(added_output_names), 2) + ] # output names merged_added_output_dict = dict( - (i, merged_output_dict[i]) for i in merged_output_dict if i not in self.model_original_outputs) + (i, merged_output_dict[i]) for i in merged_output_dict if i not in self.model_original_outputs + ) pairs = [] for i in range(0, len(added_output_names), 2): min_value = 0 max_value = 0 if self.moving_average: - min_value_array = np.mean(merged_added_output_dict[added_output_names[i]], axis = 0) - max_value_array = np.mean(merged_added_output_dict[added_output_names[i + 1]], axis = 0) + min_value_array = np.mean(merged_added_output_dict[added_output_names[i]], axis=0) + max_value_array = np.mean(merged_added_output_dict[added_output_names[i + 1]], axis=0) else: min_value_array = min(merged_added_output_dict[added_output_names[i]]) max_value_array = max(merged_added_output_dict[added_output_names[i + 1]]) @@ -293,22 +339,25 @@ def compute_range(self): if self.calibrate_tensors_range: self.calibrate_tensors_range = self.merge_range(self.calibrate_tensors_range, new_calibrate_tensors_range) else: - self.calibrate_tensors_range = new_calibrate_tensors_range + self.calibrate_tensors_range = new_calibrate_tensors_range return self.calibrate_tensors_range + class HistogramCalibrater(CalibraterBase): - def __init__(self, - model, - op_types_to_calibrate=[], - augmented_model_path='augmented_model.onnx', - use_external_data_format=False, - method='percentile', - symmetric=False, - num_bins=128, - num_quantized_bins=2048, - percentile=99.999): - ''' + def __init__( + self, + model, + op_types_to_calibrate=[], + augmented_model_path="augmented_model.onnx", + use_external_data_format=False, + method="percentile", + symmetric=False, + num_bins=128, + num_quantized_bins=2048, + percentile=99.999, + ): + """ :param model: ONNX model to calibrate. It can be a ModelProto or a model path :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors. :param augmented_model_path: save augmented model to this path. @@ -318,8 +367,10 @@ def __init__(self, :param num_bins: number of bins to create a new histogram for collecting tensor values. :param num_quantized_bins: number of quantized bins. Default 128. :param percentile: A float number between [0, 100]. Default 99.99. - ''' - super(HistogramCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, use_external_data_format) + """ + super(HistogramCalibrater, self).__init__( + model, op_types_to_calibrate, augmented_model_path, use_external_data_format + ) self.intermediate_outputs = [] self.calibrate_tensors_range = None self.num_model_outputs = len(self.model.graph.output) @@ -332,31 +383,35 @@ def __init__(self, self.percentile = percentile def augment_graph(self): - ''' + """ make all quantization_candidates op type nodes as part of the graph output. :return: augmented ONNX model - ''' + """ model = clone_model_with_shape_infer(self.model) added_nodes = [] added_outputs = [] - tensors, value_infos = self.select_tensors_to_calibrate(model) + tensors, value_infos = self.select_tensors_to_calibrate(model) for tensor in tensors: added_outputs.append(value_infos[tensor]) model.graph.node.extend(added_nodes) model.graph.output.extend(added_outputs) - onnx.save(model, self.augmented_model_path, save_as_external_data=self.use_external_data_format) + onnx.save( + model, + self.augmented_model_path, + save_as_external_data=self.use_external_data_format, + ) self.augment_model = model def clear_collected_data(self): self.intermediate_outputs = [] def collect_data(self, data_reader: CalibrationDataReader): - ''' - Entropy Calibrator collects operators' tensors as well as generates tensor histogram for each operator. - ''' + """ + Entropy Calibrator collects operators' tensors as well as generates tensor histogram for each operator. + """ while True: inputs = data_reader.get_next() if not inputs: @@ -379,36 +434,41 @@ def collect_data(self, data_reader: CalibrationDataReader): clean_merged_dict = dict((i, merged_dict[i]) for i in merged_dict if i not in self.model_original_outputs) if not self.collector: - self.collector = HistogramCollector(method=self.method, - symmetric=self.symmetric, - num_bins=self.num_bins, - num_quantized_bins=self.num_quantized_bins, - percentile=self.percentile) + self.collector = HistogramCollector( + method=self.method, + symmetric=self.symmetric, + num_bins=self.num_bins, + num_quantized_bins=self.num_quantized_bins, + percentile=self.percentile, + ) self.collector.collect(clean_merged_dict) self.clear_collected_data() def compute_range(self): - ''' + """ Compute the min-max range of tensor :return: dictionary mapping: {tensor name: (min value, max value)} - ''' + """ if not self.collector: raise ValueError("No collector created and can't generate calibration data.") return self.collector.compute_collection_result() + class EntropyCalibrater(HistogramCalibrater): - def __init__(self, - model, - op_types_to_calibrate=[], - augmented_model_path='augmented_model.onnx', - use_external_data_format=False, - method='entropy', - symmetric=False, - num_bins=128, - num_quantized_bins=128): - ''' + def __init__( + self, + model, + op_types_to_calibrate=[], + augmented_model_path="augmented_model.onnx", + use_external_data_format=False, + method="entropy", + symmetric=False, + num_bins=128, + num_quantized_bins=128, + ): + """ :param model: ONNX model to calibrate. It can be a ModelProto or a model path :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors. :param augmented_model_path: save augmented model to this path. @@ -417,21 +477,32 @@ def __init__(self, :param symmetric: make range of tensor symmetric (central point is 0). :param num_bins: number of bins to create a new histogram for collecting tensor values. :param num_quantized_bins: number of quantized bins. Default 128. - ''' - super(EntropyCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, use_external_data_format, - method=method, symmetric=symmetric, num_bins=num_bins, num_quantized_bins=num_quantized_bins) + """ + super(EntropyCalibrater, self).__init__( + model, + op_types_to_calibrate, + augmented_model_path, + use_external_data_format, + method=method, + symmetric=symmetric, + num_bins=num_bins, + num_quantized_bins=num_quantized_bins, + ) + class PercentileCalibrater(HistogramCalibrater): - def __init__(self, - model, - op_types_to_calibrate=[], - augmented_model_path='augmented_model.onnx', - use_external_data_format=False, - method='percentile', - symmetric=False, - num_bins=2048, - percentile=99.999): - ''' + def __init__( + self, + model, + op_types_to_calibrate=[], + augmented_model_path="augmented_model.onnx", + use_external_data_format=False, + method="percentile", + symmetric=False, + num_bins=2048, + percentile=99.999, + ): + """ :param model: ONNX model to calibrate. It can be a ModelProto or a model path :param op_types_to_calibrate: operator types to calibrate. By default, calibrate all the float32/float16 tensors. :param augmented_model_path: save augmented model to this path. @@ -440,9 +511,18 @@ def __init__(self, :param symmetric: make range of tensor symmetric (central point is 0). :param num_quantized_bins: number of quantized bins. Default 128. :param percentile: A float number between [0, 100]. Default 99.99. - ''' - super(PercentileCalibrater, self).__init__(model, op_types_to_calibrate, augmented_model_path, use_external_data_format, - method=method, symmetric=symmetric, num_bins=num_bins, percentile=percentile) + """ + super(PercentileCalibrater, self).__init__( + model, + op_types_to_calibrate, + augmented_model_path, + use_external_data_format, + method=method, + symmetric=symmetric, + num_bins=num_bins, + percentile=percentile, + ) + class CalibrationDataCollector(metaclass=abc.ABCMeta): """ @@ -453,18 +533,19 @@ class CalibrationDataCollector(metaclass=abc.ABCMeta): def collect(self, name_to_arr): """ Generate informative data based on given data. - name_to_arr : dict - tensor name to NDArray data + name_to_arr : dict + tensor name to NDArray data """ raise NotImplementedError @abc.abstractmethod def compute_collection_result(self): """ - Get the optimal result among collection data. + Get the optimal result among collection data. """ raise NotImplementedError + class HistogramCollector(CalibrationDataCollector): """ Collecting histogram for each tensor. Percentile and Entropy method are supported. @@ -473,12 +554,13 @@ class HistogramCollector(CalibrationDataCollector): ref: https://docs.nvidia.com/deeplearning/tensorrt/pytorch-quantization-toolkit/docs/_modules/ pytorch_quantization/calib/histogram.html """ + def __init__(self, method, symmetric, num_bins, num_quantized_bins, percentile): self.histogram_dict = {} self.method = method self.symmetric = symmetric self.num_bins = num_bins - self.num_quantized_bins= num_quantized_bins + self.num_quantized_bins = num_quantized_bins self.percentile = percentile def get_histogram_dict(self): @@ -489,24 +571,24 @@ def collect(self, name_to_arr): # TODO: Currently we have different collect() for entropy and percentile method respectively. # Need unified collect in the future. - if self.method == 'entropy': + if self.method == "entropy": return self.collect_value(name_to_arr) - elif self.method == 'percentile': + elif self.method == "percentile": if self.symmetric: return self.collect_absolute_value(name_to_arr) else: return self.collect_value(name_to_arr) else: - raise ValueError('Only \'entropy\' or \'percentile\' method are supported') + raise ValueError("Only 'entropy' or 'percentile' method are supported") def collect_absolute_value(self, name_to_arr): - ''' + """ Collect histogram on absolute value - ''' + """ for tensor, data_arr in name_to_arr.items(): data_arr = np.asarray(data_arr) data_arr = data_arr.flatten() - data_arr = np.absolute(data_arr) # only consider absolute value + data_arr = np.absolute(data_arr) # only consider absolute value if tensor not in self.histogram_dict: # first time it uses num_bins to compute histogram. @@ -524,13 +606,13 @@ def collect_absolute_value(self, name_to_arr): new_bin_edges = np.arange(old_hist_edges[-1] + width, temp_amax + width, width) old_hist_edges = np.hstack((old_hist_edges, new_bin_edges)) hist, hist_edges = np.histogram(data_arr, bins=old_hist_edges) - hist[:len(old_hist)] += old_hist + hist[: len(old_hist)] += old_hist self.histogram_dict[tensor] = (hist, hist_edges) def collect_value(self, name_to_arr): - ''' + """ Collect histogram on real value - ''' + """ for tensor, data_arr in name_to_arr.items(): data_arr = np.asarray(data_arr) data_arr = data_arr.flatten() @@ -546,10 +628,18 @@ def collect_value(self, name_to_arr): if tensor in self.histogram_dict: old_histogram = self.histogram_dict[tensor] - self.histogram_dict[tensor] = self.merge_histogram(old_histogram, data_arr, min_value, max_value, threshold) + self.histogram_dict[tensor] = self.merge_histogram( + old_histogram, data_arr, min_value, max_value, threshold + ) else: hist, hist_edges = np.histogram(data_arr, self.num_bins, range=(-threshold, threshold)) - self.histogram_dict[tensor] = (hist, hist_edges, min_value, max_value, threshold) + self.histogram_dict[tensor] = ( + hist, + hist_edges, + min_value, + max_value, + threshold, + ) def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_threshold): @@ -557,7 +647,13 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_thresho if new_threshold <= old_threshold: new_hist, _ = np.histogram(data_arr, len(old_hist), range=(-old_threshold, old_threshold)) - return (new_hist + old_hist, old_hist_edges, min(old_min, new_min), max(old_max, new_max), old_threshold) + return ( + new_hist + old_hist, + old_hist_edges, + min(old_min, new_min), + max(old_max, new_max), + old_threshold, + ) else: if old_threshold == 0: hist, hist_edges = np.histogram(data_arr, len(old_hist), range=(-new_threshold, new_threshold)) @@ -565,24 +661,30 @@ def merge_histogram(self, old_histogram, data_arr, new_min, new_max, new_thresho else: old_num_bins = len(old_hist) old_stride = 2 * old_threshold / old_num_bins - half_increased_bins = int((new_threshold - old_threshold) // old_stride + 1) + half_increased_bins = int((new_threshold - old_threshold) // old_stride + 1) new_num_bins = old_num_bins + 2 * half_increased_bins new_threshold = half_increased_bins * old_stride + old_threshold hist, hist_edges = np.histogram(data_arr, new_num_bins, range=(-new_threshold, new_threshold)) - hist[half_increased_bins:new_num_bins-half_increased_bins] += old_hist - return (hist, hist_edges, min(old_min, new_min), max(old_max, new_max), new_threshold) + hist[half_increased_bins : new_num_bins - half_increased_bins] += old_hist + return ( + hist, + hist_edges, + min(old_min, new_min), + max(old_max, new_max), + new_threshold, + ) def compute_collection_result(self): if not self.histogram_dict or len(self.histogram_dict) == 0: raise ValueError("Histogram has not been collected. Please run collect() first.") print("Finding optimal threshold for each tensor using {} algorithm ...".format(self.method)) - if self.method == 'entropy': + if self.method == "entropy": return self.compute_entropy() - elif self.method == 'percentile': + elif self.method == "percentile": return self.compute_percentile() else: - raise ValueError('Only \'entropy\' or \'percentile\' method are supported') + raise ValueError("Only 'entropy' or 'percentile' method are supported") def compute_percentile(self): if self.percentile < 0 or self.percentile > 100: @@ -591,7 +693,7 @@ def compute_percentile(self): histogram_dict = self.histogram_dict percentile = self.percentile - thresholds_dict = {} # per tensor thresholds + thresholds_dict = {} # per tensor thresholds print("Number of tensors : {}".format(len(histogram_dict))) print("Number of histogram bins : {}".format(self.num_bins)) @@ -601,15 +703,21 @@ def compute_percentile(self): hist = histogram[0] hist_edges = histogram[1] total = hist.sum() - cdf = np.cumsum(hist/total) + cdf = np.cumsum(hist / total) if self.symmetric: idx_right = np.searchsorted(cdf, percentile / 100.0) - thresholds_dict[tensor] = (-float(hist_edges[idx_right]), float(hist_edges[idx_right])) + thresholds_dict[tensor] = ( + -float(hist_edges[idx_right]), + float(hist_edges[idx_right]), + ) else: percent_to_cut_one_side = (100.0 - percentile) / 200.0 idx_right = np.searchsorted(cdf, 1.0 - percent_to_cut_one_side) idx_left = np.searchsorted(cdf, percent_to_cut_one_side) - thresholds_dict[tensor] = (float(hist_edges[idx_left]), float(hist_edges[idx_right])) + thresholds_dict[tensor] = ( + float(hist_edges[idx_left]), + float(hist_edges[idx_right]), + ) # Plot histogram for debug only if False: @@ -621,10 +729,14 @@ def compute_entropy(self): histogram_dict = self.histogram_dict num_quantized_bins = self.num_quantized_bins - thresholds_dict = {} # per tensor thresholds + thresholds_dict = {} # per tensor thresholds print("Number of tensors : {}".format(len(histogram_dict))) - print("Number of histogram bins : {} (The number may increase depends on the data it collects)".format(self.num_bins)) + print( + "Number of histogram bins : {} (The number may increase depends on the data it collects)".format( + self.num_bins + ) + ) print("Number of quantized bins : {}".format(self.num_quantized_bins)) for tensor, histogram in histogram_dict.items(): @@ -643,17 +755,18 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): `q` is a truncated version of the original distribution. Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf """ - from scipy.stats import entropy import copy + from scipy.stats import entropy + hist = histogram[0] hist_edges = histogram[1] num_bins = hist.size zero_bin_index = num_bins // 2 num_half_quantized_bin = num_quantized_bins // 2 - + kl_divergence = np.zeros(zero_bin_index - num_half_quantized_bin + 1) - thresholds = [(0, 0) for i in range(kl_divergence.size)] + thresholds = [(0, 0) for i in range(kl_divergence.size)] # <------------ num bins ----------------> # <--- quantized bins ----> @@ -670,33 +783,36 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): # start index end index (end of iteration) for i in range(num_half_quantized_bin, zero_bin_index + 1, 1): - start_index = zero_bin_index - i + start_index = zero_bin_index - i end_index = zero_bin_index + i + 1 if (zero_bin_index + i + 1) <= num_bins else num_bins - thresholds[i - num_half_quantized_bin] = (float(hist_edges[start_index]), float(hist_edges[end_index])) + thresholds[i - num_half_quantized_bin] = ( + float(hist_edges[start_index]), + float(hist_edges[end_index]), + ) sliced_distribution = copy.deepcopy(hist[start_index:end_index]) # reference distribution p - p = sliced_distribution.copy() # a copy of np array - left_outliers_count = sum(hist[:start_index]) + p = sliced_distribution.copy() # a copy of np array + left_outliers_count = sum(hist[:start_index]) right_outliers_count = sum(hist[end_index:]) p[0] += left_outliers_count p[-1] += right_outliers_count # nonzeros[i] incidates whether p[i] is non-zero nonzeros = (p != 0).astype(np.int64) - - # quantize p.size bins into quantized bins (default 128 bins) + + # quantize p.size bins into quantized bins (default 128 bins) quantized_bins = np.zeros(num_quantized_bins, dtype=np.int64) num_merged_bins = sliced_distribution.size // num_quantized_bins # merge bins into quantized bins for index in range(num_quantized_bins): - start = index * num_merged_bins + start = index * num_merged_bins end = start + num_merged_bins - quantized_bins[index] = sum(sliced_distribution[start:end]) - quantized_bins[-1] += sum(sliced_distribution[num_quantized_bins * num_merged_bins:]) + quantized_bins[index] = sum(sliced_distribution[start:end]) + quantized_bins[-1] += sum(sliced_distribution[num_quantized_bins * num_merged_bins :]) # in order to compare p and q, we need to make length of q equals to length of p # expand quantized bins into p.size bins @@ -708,63 +824,71 @@ def get_entropy_threshold(self, histogram, num_quantized_bins): norm = sum(nonzeros[start:end]) if norm != 0: q[start:end] = float(quantized_bins[index]) / float(norm) - + p = smooth_distribution(p) q = smooth_distribution(q) if isinstance(q, np.ndarray): kl_divergence[i - num_half_quantized_bin] = entropy(p, q) else: - kl_divergence[i - num_half_quantized_bin] = float('inf') + kl_divergence[i - num_half_quantized_bin] = float("inf") min_kl_divergence_idx = np.argmin(kl_divergence) - optimal_threshold = thresholds[min_kl_divergence_idx] + optimal_threshold = thresholds[min_kl_divergence_idx] return optimal_threshold -def create_calibrator(model, - op_types_to_calibrate=[], - augmented_model_path='augmented_model.onnx', - calibrate_method=CalibrationMethod.MinMax, - use_external_data_format=False, - extra_options={}): +def create_calibrator( + model, + op_types_to_calibrate=[], + augmented_model_path="augmented_model.onnx", + calibrate_method=CalibrationMethod.MinMax, + use_external_data_format=False, + extra_options={}, +): if calibrate_method == CalibrationMethod.MinMax: # default settings for min-max algorithm - symmetric = False if 'symmetric' not in extra_options else extra_options['symmetric'] - moving_average = False if 'moving_average' not in extra_options else extra_options['moving_average'] - averaging_constant = 0.01 if 'averaging_constant' not in extra_options else extra_options['averaging_constant'] + symmetric = False if "symmetric" not in extra_options else extra_options["symmetric"] + moving_average = False if "moving_average" not in extra_options else extra_options["moving_average"] + averaging_constant = 0.01 if "averaging_constant" not in extra_options else extra_options["averaging_constant"] return MinMaxCalibrater( - model, op_types_to_calibrate, augmented_model_path, + model, + op_types_to_calibrate, + augmented_model_path, use_external_data_format=use_external_data_format, symmetric=symmetric, moving_average=moving_average, - averaging_constant=averaging_constant + averaging_constant=averaging_constant, ) elif calibrate_method == CalibrationMethod.Entropy: # default settings for entropy algorithm - num_bins = 128 if 'num_bins' not in extra_options else extra_options['num_bins'] - num_quantized_bins = 128 if 'num_quantized_bins' not in extra_options else extra_options['num_quantized_bins'] - symmetric = False if 'symmetric' not in extra_options else extra_options['symmetric'] + num_bins = 128 if "num_bins" not in extra_options else extra_options["num_bins"] + num_quantized_bins = 128 if "num_quantized_bins" not in extra_options else extra_options["num_quantized_bins"] + symmetric = False if "symmetric" not in extra_options else extra_options["symmetric"] return EntropyCalibrater( - model, op_types_to_calibrate, augmented_model_path, + model, + op_types_to_calibrate, + augmented_model_path, use_external_data_format=use_external_data_format, symmetric=symmetric, num_bins=num_bins, - num_quantized_bins=num_quantized_bins + num_quantized_bins=num_quantized_bins, ) elif calibrate_method == CalibrationMethod.Percentile: # default settings for percentile algorithm - num_bins = 2048 if 'num_bins' not in extra_options else extra_options['num_bins'] - percentile = 99.999 if 'percentile' not in extra_options else extra_options['percentile'] - symmetric = True if 'symmetric' not in extra_options else extra_options['symmetric'] + num_bins = 2048 if "num_bins" not in extra_options else extra_options["num_bins"] + percentile = 99.999 if "percentile" not in extra_options else extra_options["percentile"] + symmetric = True if "symmetric" not in extra_options else extra_options["symmetric"] return PercentileCalibrater( - model, op_types_to_calibrate, augmented_model_path, + model, + op_types_to_calibrate, + augmented_model_path, use_external_data_format=use_external_data_format, symmetric=symmetric, num_bins=num_bins, - percentile=percentile + percentile=percentile, ) - raise ValueError('Unsupported calibration method {}'.format(calibrate_method)) + raise ValueError("Unsupported calibration method {}".format(calibrate_method)) diff --git a/onnxruntime/python/tools/quantization/onnx_model.py b/onnxruntime/python/tools/quantization/onnx_model.py index dfa2e274d3d41..566a16db3fe67 100644 --- a/onnxruntime/python/tools/quantization/onnx_model.py +++ b/onnxruntime/python/tools/quantization/onnx_model.py @@ -1,8 +1,11 @@ -import onnx import itertools -from .quant_utils import find_by_name, attribute_to_kwarg from pathlib import Path +import onnx + +from .quant_utils import attribute_to_kwarg, find_by_name + + class ONNXModel: def __init__(self, model): self.model = model @@ -121,19 +124,19 @@ def get_parent(self, node, idx, output_name_to_node=None): return output_name_to_node[input] def find_node_by_name(self, node_name, new_nodes_list, graph): - ''' + """ Find out if a node exists in a graph or a node is in the new set of nodes created during quantization. Return the node found. - ''' - graph_nodes_list = list(graph.node) #deep copy + """ + graph_nodes_list = list(graph.node) # deep copy graph_nodes_list.extend(new_nodes_list) node = find_by_name(node_name, graph_nodes_list) return node def find_nodes_by_initializer(self, graph, initializer): - ''' + """ Find all nodes with given initializer as an input. - ''' + """ nodes = [] for node in graph.node: for node_input in node.input: @@ -174,19 +177,19 @@ def __replace_gemm_with_matmul(graph_path): kwargs.update(kv) node = onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs) - if node.op_type == 'Gemm': + if node.op_type == "Gemm": alpha = 1.0 beta = 1.0 transA = 0 transB = 0 for attr in node.attribute: - if attr.name == 'alpha': + if attr.name == "alpha": alpha = onnx.helper.get_attribute_value(attr) - elif attr.name == 'beta': + elif attr.name == "beta": beta = onnx.helper.get_attribute_value(attr) - elif attr.name == 'transA': + elif attr.name == "transA": transA = onnx.helper.get_attribute_value(attr) - elif attr.name == 'transB': + elif attr.name == "transB": transB = onnx.helper.get_attribute_value(attr) if alpha == 1.0 and beta == 1.0 and transA == 0: inputB = node.input[1] @@ -204,25 +207,30 @@ def __replace_gemm_with_matmul(graph_path): break Bs_graph.initializer.extend([B_trans]) else: - inputB += '_Transposed' - transpose_node = onnx.helper.make_node('Transpose', - inputs=[node.input[1]], - outputs=[inputB], - name=node.name + '_Transpose' if node.name != "" else "") + inputB += "_Transposed" + transpose_node = onnx.helper.make_node( + "Transpose", + inputs=[node.input[1]], + outputs=[inputB], + name=node.name + "_Transpose" if node.name != "" else "", + ) new_nodes.append(transpose_node) matmul_node = onnx.helper.make_node( - 'MatMul', + "MatMul", inputs=[node.input[0], inputB], - outputs=[node.output[0] + ('_MatMul' if len(node.input) > 2 else '')], - name=node.name + '_MatMul' if node.name != "" else "") + outputs=[node.output[0] + ("_MatMul" if len(node.input) > 2 else "")], + name=node.name + "_MatMul" if node.name != "" else "", + ) new_nodes.append(matmul_node) if len(node.input) > 2: - add_node = onnx.helper.make_node('Add', - inputs=[node.output[0] + '_MatMul', node.input[2]], - outputs=node.output, - name=node.name + '_Add' if node.name != "" else "") + add_node = onnx.helper.make_node( + "Add", + inputs=[node.output[0] + "_MatMul", node.input[2]], + outputs=node.output, + name=node.name + "_Add" if node.name != "" else "", + ) new_nodes.append(add_node) # unsupported @@ -233,7 +241,7 @@ def __replace_gemm_with_matmul(graph_path): else: new_nodes.append(node) - graph.ClearField('node') + graph.ClearField("node") graph.node.extend(new_nodes) graph_path.pop() return graph @@ -243,14 +251,16 @@ def replace_gemm_with_matmul(self): ONNXModel.__replace_gemm_with_matmul(graph_path) def save_model_to_file(self, output_path, use_external_data_format=False): - ''' + """ Save model to external data, which is needed for model size > 2GB - ''' + """ self.topological_sort() if use_external_data_format: - onnx.external_data_helper.convert_model_to_external_data(self.model, - all_tensors_to_one_file=True, - location=Path(output_path).name + ".data") + onnx.external_data_helper.convert_model_to_external_data( + self.model, + all_tensors_to_one_file=True, + location=Path(output_path).name + ".data", + ) onnx.save_model(self.model, output_path) @staticmethod @@ -278,12 +288,15 @@ def replace_output_of_all_nodes(self, old_output_name, new_output_name): def remove_unused_constant(self): input_name_to_nodes = self.input_name_to_nodes() - #remove unused constant + # remove unused constant unused_nodes = [] nodes = self.nodes() for node in nodes: - if node.op_type == "Constant" and not self.is_graph_output( - node.output[0]) and node.output[0] not in input_name_to_nodes: + if ( + node.op_type == "Constant" + and not self.is_graph_output(node.output[0]) + and node.output[0] not in input_name_to_nodes + ): unused_nodes.append(node) self.remove_nodes(unused_nodes) @@ -308,13 +321,13 @@ def is_graph_output(self, output_name): # TODO:use OnnxModel.graph_topological_sort(self.model.graph) from transformers.onnx_model # Currently it breaks Openvino/Linux training gpu pipeline so hold off for 1.8 release def topological_sort(self): - deps_count = [0]*len(self.nodes()) # dependency count of each node - deps_to_nodes = {} # input to node indice + deps_count = [0] * len(self.nodes()) # dependency count of each node + deps_to_nodes = {} # input to node indice sorted_nodes = [] # initialize sorted_nodes for node_idx, node in enumerate(self.nodes()): # CANNOT use len(node.input) directly because input can be optional - deps_count[node_idx] = sum(1 for _ in node.input if _ ) - if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs + deps_count[node_idx] = sum(1 for _ in node.input if _) + if deps_count[node_idx] == 0: # Constant doesn't depend on any inputs sorted_nodes.append(self.nodes()[node_idx]) continue @@ -353,6 +366,6 @@ def topological_sort(self): end = end + 1 start = start + 1 - assert(end == len(self.graph().node)), "Graph is not a DAG" - self.graph().ClearField('node') - self.graph().node.extend(sorted_nodes) \ No newline at end of file + assert end == len(self.graph().node), "Graph is not a DAG" + self.graph().ClearField("node") + self.graph().node.extend(sorted_nodes) diff --git a/onnxruntime/python/tools/quantization/onnx_quantizer.py b/onnxruntime/python/tools/quantization/onnx_quantizer.py index 000d6e71123ac..39d53620f62db 100644 --- a/onnxruntime/python/tools/quantization/onnx_quantizer.py +++ b/onnxruntime/python/tools/quantization/onnx_quantizer.py @@ -3,30 +3,62 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import logging import os import struct from pathlib import Path -import numpy as np -import logging +import numpy as np import onnx import onnx.numpy_helper from onnx import onnx_pb as onnx_proto -from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel -from .quant_utils import QuantizationMode, QuantizedValueType, QuantizedInitializer, QuantizedValue -from .quant_utils import find_by_name, get_elem_index, get_mul_node, generate_identified_filename, attribute_to_kwarg, type_to_name -from .quant_utils import quantize_nparray, quantize_data, compute_scale_zp, get_qrange_for_qType, get_qmin_qmax_for_qType -from .quant_utils import save_and_reload_model, model_has_infer_metadata, add_infer_metadata -from .quant_utils import QuantType, onnx_domain, __producer__, __version__ - -from .registry import CreateOpQuantizer, CreateDefaultOpQuantizer +from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions from .onnx_model import ONNXModel +from .quant_utils import ( + QuantizationMode, + QuantizedInitializer, + QuantizedValue, + QuantizedValueType, + QuantType, + __producer__, + __version__, + add_infer_metadata, + attribute_to_kwarg, + compute_scale_zp, + find_by_name, + generate_identified_filename, + get_elem_index, + get_mul_node, + get_qmin_qmax_for_qType, + get_qrange_for_qType, + model_has_infer_metadata, + onnx_domain, + quantize_data, + quantize_nparray, + save_and_reload_model, + type_to_name, +) +from .registry import CreateDefaultOpQuantizer, CreateOpQuantizer + class ONNXQuantizer: - def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, input_qType, tensors_range, - nodes_to_quantize, nodes_to_exclude, op_types_to_quantize, extra_options={}): + def __init__( + self, + model, + per_channel, + reduce_range, + mode, + static, + weight_qType, + input_qType, + tensors_range, + nodes_to_quantize, + nodes_to_exclude, + op_types_to_quantize, + extra_options={}, + ): if not model_has_infer_metadata(model): model = save_and_reload_model(model) @@ -35,7 +67,8 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, self.value_infos.update({it.name: it for it in model.graph.input}) self.model = ONNXModel(model) - if not static: self.model.replace_gemm_with_matmul() + if not static: + self.model.replace_gemm_with_matmul() self.per_channel = per_channel # weight-pack per channel self.reduce_range = reduce_range @@ -44,16 +77,28 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, self.fuse_dynamic_quant = False self.extra_options = extra_options if extra_options is not None else {} - self.enable_subgraph_quantization = 'EnableSubgraph' in self.extra_options and self.extra_options['EnableSubgraph'] - self.force_quantize_no_input_check = 'ForceQuantizeNoInputCheck' in self.extra_options and self.extra_options['ForceQuantizeNoInputCheck'] - self.q_matmul_const_b_only = 'MatMulConstBOnly' in self.extra_options and self.extra_options['MatMulConstBOnly'] + self.enable_subgraph_quantization = ( + "EnableSubgraph" in self.extra_options and self.extra_options["EnableSubgraph"] + ) + self.force_quantize_no_input_check = ( + "ForceQuantizeNoInputCheck" in self.extra_options and self.extra_options["ForceQuantizeNoInputCheck"] + ) + self.q_matmul_const_b_only = "MatMulConstBOnly" in self.extra_options and self.extra_options["MatMulConstBOnly"] is_weight_int8 = weight_qType == QuantType.QInt8 - self.is_weight_symmetric = is_weight_int8 if 'WeightSymmetric' not in self.extra_options else self.extra_options['WeightSymmetric'] - self.is_activation_symmetric = False if 'ActivationSymmetric' not in self.extra_options else self.extra_options['ActivationSymmetric'] - - self.input_qType = onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 - self.weight_qType = onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 - ''' + self.is_weight_symmetric = ( + is_weight_int8 if "WeightSymmetric" not in self.extra_options else self.extra_options["WeightSymmetric"] + ) + self.is_activation_symmetric = ( + False if "ActivationSymmetric" not in self.extra_options else self.extra_options["ActivationSymmetric"] + ) + + self.input_qType = ( + onnx_proto.TensorProto.INT8 if input_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 + ) + self.weight_qType = ( + onnx_proto.TensorProto.INT8 if weight_qType == QuantType.QInt8 else onnx_proto.TensorProto.UINT8 + ) + """ Dictionary specifying the min and max values for tensors. It has following format: { "param_name": [min, max] @@ -63,15 +108,15 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, 'Conv_3:0': [np.float32(0), np.float32(0.5)], 'Conv_4:0': [np.float32(1), np.float32(3.5)] } - ''' + """ self.tensors_range = tensors_range self.nodes_to_quantize = nodes_to_quantize # specific nodes to quantize self.nodes_to_exclude = nodes_to_exclude # specific nodes to exclude self.op_types_to_quantize = op_types_to_quantize self.new_nodes = [] self.parent = None - self.graph_scope = "/" # for human readable debug information - self.tensor_names = { } # in case the shape inference not totally working + self.graph_scope = "/" # for human readable debug information + self.tensor_names = {} # in case the shape inference not totally working self.tensor_names.update({ot.name: 1 for ot in model.graph.output}) self.tensor_names.update({it.name: 1 for it in model.graph.input}) for node in self.model.model.graph.node: @@ -80,7 +125,7 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, self.opset_version = self.check_opset_version() if not self.mode in QuantizationMode: - raise ValueError('unsupported quantization mode {}'.format(self.mode)) + raise ValueError("unsupported quantization mode {}".format(self.mode)) self.quantization_params = self.calculate_quantization_params() @@ -101,37 +146,46 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, # routines for subgraph support def quantize_subgraph(self, subgraph, graph_key): - ''' - generate submodel for the subgraph, so that we re-utilize current quantization implementation. - quantize the submodel - update subgraph and set it back to node - ''' - warped_model = onnx.helper.make_model(subgraph, producer_name='onnx-quantizer', - opset_imports=self.model.model.opset_import) + """ + generate submodel for the subgraph, so that we re-utilize current quantization implementation. + quantize the submodel + update subgraph and set it back to node + """ + warped_model = onnx.helper.make_model( + subgraph, + producer_name="onnx-quantizer", + opset_imports=self.model.model.opset_import, + ) add_infer_metadata(warped_model) - sub_quanitzer = ONNXQuantizer(warped_model, - self.per_channel, - self.reduce_range, - self.mode, - self.static, - self.weight_qType, - self.input_qType, - self.tensors_range, - self.nodes_to_quantize, - self.nodes_to_exclude, - self.op_types_to_quantize, - self.extra_options) + sub_quanitzer = ONNXQuantizer( + warped_model, + self.per_channel, + self.reduce_range, + self.mode, + self.static, + self.weight_qType, + self.input_qType, + self.tensors_range, + self.nodes_to_quantize, + self.nodes_to_exclude, + self.op_types_to_quantize, + self.extra_options, + ) sub_quanitzer.parent = self sub_quanitzer.graph_scope = "{}{}/".format(self.graph_scope, graph_key) sub_quanitzer.quantize_model() return sub_quanitzer.model.model.graph def quantize_node_with_sub_graph(self, node): - ''' + """ Check subgraph, if any, quantize it and replace it. return new_nodes added for quantizing subgraph - ''' - graph_attrs = [attr for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS] + """ + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] if len(graph_attrs) == 0: return node node_name = node.name if node.name != "" else "{}_node_count_{}".format(node.op_type, len(self.new_nodes)) @@ -142,7 +196,14 @@ def quantize_node_with_sub_graph(self, node): elif attr.type == onnx.AttributeProto.GRAPHS: value = [] for subgraph in attr.graphs: - value.extend([self.quantize_subgraph(subgraph, "{}:{}:{}".format(node_name, attr.name, len(value)))]) + value.extend( + [ + self.quantize_subgraph( + subgraph, + "{}:{}:{}".format(node_name, attr.name, len(value)), + ) + ] + ) kv = {attr.name: value} else: kv = attribute_to_kwarg(attr) @@ -154,19 +215,23 @@ def check_opset_version(self): opset for opset in self.model.model.opset_import if not opset.domain or opset.domain == "ai.onnx" ] if 1 != len(ai_onnx_domain): - raise ValueError('Failed to find proper ai.onnx domain') + raise ValueError("Failed to find proper ai.onnx domain") opset_version = ai_onnx_domain[0].version if opset_version == 10: logging.warning( - "The original model opset version is {}, which does not support node fusions. Please update the model to opset >= 11 for better performance." - .format(opset_version)) + "The original model opset version is {}, which does not support node fusions. Please update the model to opset >= 11 for better performance.".format( + opset_version + ) + ) return 10 if opset_version < 10: logging.warning( - "The original model opset version is {}, which does not support quantization. Please update the model to opset >= 11. Updating the model automatically to opset 11. Please verify the quantized model." - .format(opset_version)) + "The original model opset version is {}, which does not support quantization. Please update the model to opset >= 11. Updating the model automatically to opset 11. Please verify the quantized model.".format( + opset_version + ) + ) self.model.model.opset_import.remove(ai_onnx_domain[0]) self.model.model.opset_import.extend([onnx.helper.make_opsetid("", 11)]) opset_version = 11 @@ -175,39 +240,47 @@ def check_opset_version(self): return opset_version def has_QDQ_nodes(self): - ''' - Detect if model already has QuantizeLinear or DequantizeLinear. - ''' - return any(node.op_type == 'QuantizeLinear' or node.op_type == 'DequantizeLinear' for node in self.model.nodes()) + """ + Detect if model already has QuantizeLinear or DequantizeLinear. + """ + return any( + node.op_type == "QuantizeLinear" or node.op_type == "DequantizeLinear" for node in self.model.nodes() + ) def remove_fake_quantized_nodes(self): - ''' - Detect and remove the quantize/dequantizelinear node pairs(fake quantized nodes in Quantization-Aware training) - and reconnect and update the nodes. - ''' + """ + Detect and remove the quantize/dequantizelinear node pairs(fake quantized nodes in Quantization-Aware training) + and reconnect and update the nodes. + """ nodes_to_remove = [] initializers_to_remove = [] for curr_node in self.model.nodes(): - if curr_node.op_type == 'QuantizeLinear': + if curr_node.op_type == "QuantizeLinear": next_node, prev_node, succ_node = None, None, None for child_node in self.model.get_children(curr_node): - if child_node.op_type == 'DequantizeLinear': + if child_node.op_type == "DequantizeLinear": next_node = child_node if next_node is None: raise ValueError( "Remove fake-quantized node pair Error: DequantizeLinear node is not found for {}.".format( - curr_node.name)) + curr_node.name + ) + ) prev_node = self.model.get_parent(curr_node, 0) if prev_node is None: - raise ValueError("Remove fake-quantized node pair Error: Parent node is not found for {}.".format( - curr_node.name)) + raise ValueError( + "Remove fake-quantized node pair Error: Parent node is not found for {}.".format(curr_node.name) + ) succ_nodes = self.model.get_children(next_node) if len(succ_nodes) == 0: - raise ValueError("Remove fake-quantized node pair Error: No successive nodes found for {}.".format( - next_node.name)) + raise ValueError( + "Remove fake-quantized node pair Error: No successive nodes found for {}.".format( + next_node.name + ) + ) # TODO: convert it to the specified input_type scale_tensor_name = curr_node.input[1] @@ -216,7 +289,7 @@ def remove_fake_quantized_nodes(self): initializer_zp = find_by_name(zp_tensor_name, self.model.initializer()) zp_and_scale = [ onnx.numpy_helper.to_array(initializer_zp), - onnx.numpy_helper.to_array(initializer_scale) + onnx.numpy_helper.to_array(initializer_scale), ] # connect the previous and successive node input and output @@ -226,8 +299,10 @@ def remove_fake_quantized_nodes(self): succ_node.input[succ_idx] = curr_node.input[0] else: raise ValueError( - "Remove fake-quantized node pair Error: Connection failed. No matched successive node input found for {}." - .format(next_node.name)) + "Remove fake-quantized node pair Error: Connection failed. No matched successive node input found for {}.".format( + next_node.name + ) + ) param_name = curr_node.input[0] if self.quantization_params is None: @@ -255,11 +330,14 @@ def find_initializer_in_path(self, initializer_name): return False def should_quantize(self, node): - if self.nodes_to_quantize is not None and len( - self.nodes_to_quantize) != 0 and node.name not in self.nodes_to_quantize: + if ( + self.nodes_to_quantize is not None + and len(self.nodes_to_quantize) != 0 + and node.name not in self.nodes_to_quantize + ): return False - if (node.op_type not in self.op_types_to_quantize): + if node.op_type not in self.op_types_to_quantize: return False if self.nodes_to_exclude is not None and node.name in self.nodes_to_exclude: @@ -283,7 +361,8 @@ def quantize_model(self): if self.has_QDQ_nodes(): logging.warning( "Please check if the model is already quantized." - "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly.") + "Note you don't need to quantize a QAT model. OnnxRuntime support to run QAT model directly." + ) for node in self.model.nodes(): # quantize subgraphes if have @@ -305,7 +384,7 @@ def quantize_model(self): # extend is used to append to the list for a protobuf fields # https://developers.google.com/protocol-buffers/docs/reference/python-generated?csw=1#fields - self.model.graph().ClearField('node') + self.model.graph().ClearField("node") self.model.graph().node.extend(self.new_nodes) # Remove ununsed initializers from graph, starting from the top level graph. @@ -324,8 +403,11 @@ def tensor_proto_to_array(initializer): if initializer.data_type == onnx_proto.TensorProto.FLOAT: weights = onnx.numpy_helper.to_array(initializer) else: - raise ValueError('Only float type quantization is supported. Weights {} is {}. '.format( - initializer.name, type_to_name[initializer.data_type])) + raise ValueError( + "Only float type quantization is supported. Weights {} is {}. ".format( + initializer.name, type_to_name[initializer.data_type] + ) + ) return weights def is_input_a_weight(self, input_name): @@ -344,65 +426,93 @@ def is_valid_quantize_weight(self, weight_name): return self.parent.is_valid_quantize_weight(weight_name) def _get_dynamic_input_quantization_params(self, input_name, nodes_list, qType): - ''' + """ Create nodes for dynamic quantization of input and add them to nodes_list. parameter input_name: Name of the input. parameter nodes_list: new nodes are appended to this list. parameter qType: type to quantize to. return: scale_name, zero_point_name, scale_shape, zero_point_shape. - ''' + """ if qType == onnx_proto.TensorProto.INT8: return self._get_dynamic_input_quantization_params_int8(input_name, nodes_list) return self._get_dynamic_input_quantization_params_uint8(input_name, nodes_list) def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list): - ''' + """ Create nodes for dynamic quantization of input to int8 and add them to nodes_list parameter input_name: Name of the input. parameter nodes_list: new nodes are appended to this list. return: scale_name, zero_point_name, scale_shape, zero_point_shape. - ''' + """ qType = onnx_proto.TensorProto.INT8 # Reduce min and Reduce max input_scale_name = input_name + "_scale" reduce_min_name = input_name + "_ReduceMin" - reduce_min_node = onnx.helper.make_node("ReduceMin", [input_name], [reduce_min_name + ":0"], - reduce_min_name, - keepdims=0) + reduce_min_node = onnx.helper.make_node( + "ReduceMin", + [input_name], + [reduce_min_name + ":0"], + reduce_min_name, + keepdims=0, + ) nodes_list.append(reduce_min_node) reduce_max_name = input_name + "_ReduceMax" - reduce_max_node = onnx.helper.make_node("ReduceMax", [input_name], [reduce_max_name + ":0"], - reduce_max_name, - keepdims=0) + reduce_max_node = onnx.helper.make_node( + "ReduceMax", + [input_name], + [reduce_max_name + ":0"], + reduce_max_name, + keepdims=0, + ) nodes_list.append(reduce_max_node) # Compute scale # Find abs(rmin) reduce_min_abs_name = reduce_min_name + "_Abs" - reduce_min_abs_node = onnx.helper.make_node("Abs", [reduce_min_node.output[0]], [reduce_min_abs_name + ":0"], - reduce_min_abs_name) + reduce_min_abs_node = onnx.helper.make_node( + "Abs", + [reduce_min_node.output[0]], + [reduce_min_abs_name + ":0"], + reduce_min_abs_name, + ) nodes_list.append(reduce_min_abs_node) # Find abs(rmax) reduce_max_abs_name = reduce_max_name + "_Abs" - reduce_max_abs_node = onnx.helper.make_node("Abs", [reduce_max_node.output[0]], [reduce_max_abs_name + ":0"], - reduce_max_abs_name) + reduce_max_abs_node = onnx.helper.make_node( + "Abs", + [reduce_max_node.output[0]], + [reduce_max_abs_name + ":0"], + reduce_max_abs_name, + ) nodes_list.append(reduce_max_abs_node) # Compute max of abs(rmin) and abs(rmax) abs_max_name = input_name + "_Abs_Max" - abs_max_node = onnx.helper.make_node("Max", [reduce_min_abs_node.output[0], reduce_max_abs_node.output[0]], - [abs_max_name + ":0"], abs_max_name) + abs_max_node = onnx.helper.make_node( + "Max", + [reduce_min_abs_node.output[0], reduce_max_abs_node.output[0]], + [abs_max_name + ":0"], + abs_max_name, + ) nodes_list.append(abs_max_node) # and divide by (quantize_range/2.0) which will be equal to max(...)*2.0/quantize_range - initializer_div = onnx.helper.make_tensor(self.fixed_qrange_int8_name, onnx_proto.TensorProto.FLOAT, [], - [get_qrange_for_qType(qType) / 2.0]) + initializer_div = onnx.helper.make_tensor( + self.fixed_qrange_int8_name, + onnx_proto.TensorProto.FLOAT, + [], + [get_qrange_for_qType(qType) / 2.0], + ) self.model.add_initializer(initializer_div) scale_div_name = input_name + "scale_Div" - scale_div_node = onnx.helper.make_node("Div", [abs_max_node.output[0], self.fixed_qrange_int8_name], - [input_scale_name], scale_div_name) + scale_div_node = onnx.helper.make_node( + "Div", + [abs_max_node.output[0], self.fixed_qrange_int8_name], + [input_scale_name], + scale_div_name, + ) nodes_list.append(scale_div_node) # Zero point @@ -412,32 +522,44 @@ def _get_dynamic_input_quantization_params_int8(self, input_name, nodes_list): return input_scale_name, self.fixed_zero_zp_name, [], [] def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list): - ''' + """ Create nodes for dynamic quantization of input to uint8 and add them to nodes_list parameter input_name: Name of the input. parameter nodes_list: new nodes are appended to this list. return: scale_name, zero_point_name, scale_shape, zero_point_shape. - ''' + """ qType = onnx_proto.TensorProto.UINT8 # Reduce min and Reduce max input_scale_name = input_name + "_scale" input_zp_name = input_name + "_zero_point" reduce_min_name = input_name + "_ReduceMin" - reduce_min_node = onnx.helper.make_node("ReduceMin", [input_name], [reduce_min_name + ":0"], - reduce_min_name, - keepdims=0) + reduce_min_node = onnx.helper.make_node( + "ReduceMin", + [input_name], + [reduce_min_name + ":0"], + reduce_min_name, + keepdims=0, + ) nodes_list.append(reduce_min_node) reduce_max_name = input_name + "_ReduceMax" - reduce_max_node = onnx.helper.make_node("ReduceMax", [input_name], [reduce_max_name + ":0"], - reduce_max_name, - keepdims=0) + reduce_max_node = onnx.helper.make_node( + "ReduceMax", + [input_name], + [reduce_max_name + ":0"], + reduce_max_name, + keepdims=0, + ) nodes_list.append(reduce_max_node) # Add tensors for quantize range and zero value. - initializer_qrange = onnx.helper.make_tensor(self.fixed_qrange_uint8_name, onnx_proto.TensorProto.FLOAT, [], - [get_qrange_for_qType(qType)]) + initializer_qrange = onnx.helper.make_tensor( + self.fixed_qrange_uint8_name, + onnx_proto.TensorProto.FLOAT, + [], + [get_qrange_for_qType(qType)], + ) self.model.add_initializer(initializer_qrange) initializer_qvalue = onnx.helper.make_tensor(self.fixed_zero_name, onnx_proto.TensorProto.FLOAT, [], [0.0]) self.model.add_initializer(initializer_qvalue) @@ -445,25 +567,41 @@ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list): # Compute Scale # Subtract rmax and rmin scale_sub_name = input_name + "_scale_Sub" - scale_sub_node = onnx.helper.make_node("Sub", [reduce_max_node.output[0], reduce_min_node.output[0]], - [scale_sub_name + ":0"], scale_sub_name) + scale_sub_node = onnx.helper.make_node( + "Sub", + [reduce_max_node.output[0], reduce_min_node.output[0]], + [scale_sub_name + ":0"], + scale_sub_name, + ) nodes_list.append(scale_sub_node) # and divide by quantize range scale_div_name = input_name + "_scale_Div" - scale_div_node = onnx.helper.make_node("Div", [scale_sub_node.output[0], self.fixed_qrange_uint8_name], - [input_scale_name], scale_div_name) + scale_div_node = onnx.helper.make_node( + "Div", + [scale_sub_node.output[0], self.fixed_qrange_uint8_name], + [input_scale_name], + scale_div_name, + ) nodes_list.append(scale_div_node) # Compute zero point # Subtract zero and rmin zp_sub_name = input_name + "_zero_point_Sub" - zp_sub_node = onnx.helper.make_node("Sub", [self.fixed_zero_name, reduce_min_node.output[0]], - [zp_sub_name + ":0"], zp_sub_name) + zp_sub_node = onnx.helper.make_node( + "Sub", + [self.fixed_zero_name, reduce_min_node.output[0]], + [zp_sub_name + ":0"], + zp_sub_name, + ) nodes_list.append(zp_sub_node) # Divide by scale zp_div_name = input_name + "_zero_point_Div" - zp_div_node = onnx.helper.make_node("Div", [zp_sub_node.output[0], input_scale_name], [zp_div_name + ":0"], - zp_div_name) + zp_div_node = onnx.helper.make_node( + "Div", + [zp_sub_node.output[0], input_scale_name], + [zp_div_name + ":0"], + zp_div_name, + ) nodes_list.append(zp_div_node) # Compute floor zp_floor_name = input_name + "_zero_point_Floor" @@ -477,21 +615,23 @@ def _get_dynamic_input_quantization_params_uint8(self, input_name, nodes_list): return input_scale_name, input_zp_name, [], [] def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=None): - ''' + """ Create initializers and inputs in the graph for zero point and scale of output. Zero point and scale values are obtained from self.quantization_params if specified. parameter param_name: Name of the quantization parameter. return: result, scale_name, zero_point_name, scale_shape, zero_point_shape. - ''' + """ if use_scale is None or use_zeropoint is None: if self.quantization_params is None or param_name not in self.quantization_params: - logging.info("Quantization parameters for tensor:\"{}\" not specified".format(param_name)) + logging.info('Quantization parameters for tensor:"{}" not specified'.format(param_name)) return False, "", "", "", "" params = self.quantization_params[param_name] if params is None or len(params) != 2: - raise ValueError("Quantization parameters should contain zero point and scale. " - "Specified values for output {}: {}".format(param_name, params)) + raise ValueError( + "Quantization parameters should contain zero point and scale. " + "Specified values for output {}: {}".format(param_name, params) + ) zero_point_values = [params[0]] scale_values = [params[1]] @@ -514,7 +654,7 @@ def _get_quantization_params(self, param_name, use_scale=None, use_zeropoint=Non return True, scale_name, zero_point_name, scale_shape, zero_point_shape def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=None, given_zp_name=None): - ''' + """ Given an input for a node (which is not a initializer), this function - add nodes to compute zero point and scale for this input if they don't exist. @@ -526,7 +666,7 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N :param given_scale_name: if those inputs need to be quanitzed using this scale tensor. :param given_zp_name: if those inputs to be quantized using this zeropoint tensor. :return: List of newly created nodes in NodeProto format. - ''' + """ input_name = node.input[input_index] output_name = input_name + "_quantized" ql_node_name = input_name + "_QuantizeLinear" @@ -538,8 +678,12 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N nodes = [] if data_found == True: - qlinear_node = onnx.helper.make_node("QuantizeLinear", [input_name, scale_name, zp_name], - [output_name], ql_node_name) + qlinear_node = onnx.helper.make_node( + "QuantizeLinear", + [input_name, scale_name, zp_name], + [output_name], + ql_node_name, + ) else: if self.static: return None @@ -548,13 +692,25 @@ def _get_quantize_input_nodes(self, node, input_index, qType, given_scale_name=N if self.fuse_dynamic_quant and qType == onnx_proto.TensorProto.UINT8: scale_name = input_name + "_scale" zp_name = input_name + "_zero_point" - qlinear_node = onnx.helper.make_node("DynamicQuantizeLinear", [input_name], - [output_name, scale_name, zp_name], ql_node_name) + qlinear_node = onnx.helper.make_node( + "DynamicQuantizeLinear", + [input_name], + [output_name, scale_name, zp_name], + ql_node_name, + ) else: - scale_name, zp_name, scale_shape, zp_shape = \ - self._get_dynamic_input_quantization_params(input_name, nodes, qType) - qlinear_node = onnx.helper.make_node("QuantizeLinear", [input_name, scale_name, zp_name], - [output_name], ql_node_name) + ( + scale_name, + zp_name, + scale_shape, + zp_shape, + ) = self._get_dynamic_input_quantization_params(input_name, nodes, qType) + qlinear_node = onnx.helper.make_node( + "QuantizeLinear", + [input_name, scale_name, zp_name], + [output_name], + ql_node_name, + ) self.quantized_value_map[input_name] = QuantizedValue(input_name, output_name, scale_name, zp_name, qType) return nodes + [qlinear_node] @@ -566,10 +722,10 @@ def find_quantized_value(self, input_name): return self.parent.find_quantized_value(input_name) return None - def quantize_bias_static(self, bias_name, input_name, weight_name, beta = 1.0): - ''' + def quantize_bias_static(self, bias_name, input_name, weight_name, beta=1.0): + """ Quantized the bias. Zero Point == 0 and Scale == Input_Scale * Weight_Scale - ''' + """ # Handle case where bias already in quantizatio map if bias_name in self.quantized_value_map: @@ -619,22 +775,40 @@ def quantize_bias_static(self, bias_name, input_name, weight_name, beta = 1.0): packed_bias_zp_initializer = onnx.numpy_helper.from_array(bias_zp_data, quantized_bias_zp_name) self.model.initializer().extend([packed_bias_zp_initializer]) - assert (bias_name not in self.quantized_value_map) - quantized_value = QuantizedValue(bias_name, quantized_bias_name, quantized_bias_scale_name, - quantized_bias_zp_name, QuantizedValueType.Initializer, - 0 if bias_scale_data.size > 1 else None) + assert bias_name not in self.quantized_value_map + quantized_value = QuantizedValue( + bias_name, + quantized_bias_name, + quantized_bias_scale_name, + quantized_bias_zp_name, + QuantizedValueType.Initializer, + 0 if bias_scale_data.size > 1 else None, + ) self.quantized_value_map[bias_name] = quantized_value return quantized_bias_name def contains_tensor(self, tensor_name): - ''' + """ only check for value info and newly generated tensor names, initializers are checked seperately - ''' - return (tensor_name in self.value_infos) or (tensor_name in self.tensor_names) or (tensor_name in self.generated_value_names) - - def quantize_inputs(self, node, indices, initializer_use_weight_qType=True, reduce_range=False, op_level_per_channel=False, axis=-1, from_subgraph=False): - ''' + """ + return ( + (tensor_name in self.value_infos) + or (tensor_name in self.tensor_names) + or (tensor_name in self.generated_value_names) + ) + + def quantize_inputs( + self, + node, + indices, + initializer_use_weight_qType=True, + reduce_range=False, + op_level_per_channel=False, + axis=-1, + from_subgraph=False, + ): + """ Given a node, this function quantizes the inputs as follows: - If input is an initializer, quantize the initializer data, replace old initializer with new initializer @@ -645,7 +819,7 @@ def quantize_inputs(self, node, indices, initializer_use_weight_qType=True, redu List of zero point names used for input quantization, List of scale names used for input quantization, List of new QuantizeLinear nodes created) - ''' + """ scale_names = [] zero_point_names = [] @@ -667,21 +841,27 @@ def quantize_inputs(self, node, indices, initializer_use_weight_qType=True, redu initializer = find_by_name(node_input, self.model.initializer()) if initializer is not None: if self.per_channel and op_level_per_channel: - q_weight_name, zp_name, scale_name = self.quantize_weight_per_channel( - initializer.name, self.weight_qType if initializer_use_weight_qType else self.input_qType, - axis, reduce_range) + (q_weight_name, zp_name, scale_name,) = self.quantize_weight_per_channel( + initializer.name, + self.weight_qType if initializer_use_weight_qType else self.input_qType, + axis, + reduce_range, + ) else: q_weight_name, zp_name, scale_name = self.quantize_weight( - initializer, self.weight_qType if initializer_use_weight_qType else self.input_qType, - reduce_range) + initializer, + self.weight_qType if initializer_use_weight_qType else self.input_qType, + reduce_range, + ) quantized_input_names.append(q_weight_name) zero_point_names.append(zp_name) scale_names.append(scale_name) elif self.contains_tensor(node_input): # Add QuantizeLinear node. - qlinear_node = self.model.find_node_by_name(node_input + "_QuantizeLinear", self.new_nodes, - self.model.graph()) + qlinear_node = self.model.find_node_by_name( + node_input + "_QuantizeLinear", self.new_nodes, self.model.graph() + ) if qlinear_node is None: quantize_input_nodes = self._get_quantize_input_nodes(node, input_index, self.input_qType) if quantize_input_nodes is None: @@ -701,35 +881,47 @@ def quantize_inputs(self, node, indices, initializer_use_weight_qType=True, redu scale_names.append(qlinear_node.output[1]) zero_point_names.append(qlinear_node.output[2]) elif self.parent is not None: - (parent_quantized_input_names, parent_zero_point_names, parent_scale_names, _) = self.parent.quantize_inputs( + ( + parent_quantized_input_names, + parent_zero_point_names, + parent_scale_names, + _, + ) = self.parent.quantize_inputs( node, [input_index], initializer_use_weight_qType=initializer_use_weight_qType, reduce_range=reduce_range, op_level_per_channel=op_level_per_channel, axis=axis, - from_subgraph=True) + from_subgraph=True, + ) quantized_input_names.append(parent_quantized_input_names[0]) scale_names.append(parent_scale_names[0]) zero_point_names.append(parent_zero_point_names[0]) # node should not be add this child level here else: - raise ValueError('Invalid tensor name to quantize: {} @graph scope{}'.format(node_input, self.graph_scope)) + raise ValueError( + "Invalid tensor name to quantize: {} @graph scope{}".format(node_input, self.graph_scope) + ) return (quantized_input_names, zero_point_names, scale_names, nodes) def quantize_weight(self, weight, qType, reduce_range=False, keep_float_weight=False): - ''' - :param weight: TensorProto initializer - :param qType: type to quantize to - :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. - If keep_float_weight is False, quantize the weight, or don't quantize the weight. - :return: quantized weight name, zero point name, scale name - ''' + """ + :param weight: TensorProto initializer + :param qType: type to quantize to + :param keep_float_weight: Whether to quantize the weight. In some cases, we only want to qunatize scale and zero point. + If keep_float_weight is False, quantize the weight, or don't quantize the weight. + :return: quantized weight name, zero point name, scale name + """ # Find if this input is already quantized if weight.name in self.quantized_value_map: quantized_value = self.quantized_value_map[weight.name] - return (quantized_value.q_name, quantized_value.zp_name, quantized_value.scale_name) + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) q_weight_name = weight.name + "_quantized" zp_name = weight.name + "_zero_point" @@ -737,32 +929,52 @@ def quantize_weight(self, weight, qType, reduce_range=False, keep_float_weight=F # Update packed weight, zero point, and scale initializers weight_data = self.tensor_proto_to_array(weight) - _, _, zero_point, scale, q_weight_data = quantize_data(weight_data.flatten().tolist(), - qType, self.is_weight_symmetric, - self.reduce_range and reduce_range) + _, _, zero_point, scale, q_weight_data = quantize_data( + weight_data.flatten().tolist(), + qType, + self.is_weight_symmetric, + self.reduce_range and reduce_range, + ) scale_initializer = onnx.helper.make_tensor(scale_name, onnx_proto.TensorProto.FLOAT, [], [scale]) zero_initializer = onnx.helper.make_tensor(zp_name, qType, [], [zero_point]) self.model.initializer().extend([scale_initializer, zero_initializer]) if not keep_float_weight: - q_weight_data = np.asarray(q_weight_data, - dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[qType]).reshape(weight.dims) + q_weight_data = np.asarray(q_weight_data, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[qType]).reshape( + weight.dims + ) q_weight_initializer = onnx.numpy_helper.from_array(q_weight_data, q_weight_name) self.model.initializer().extend([q_weight_initializer]) # Log entry for this quantized weight - quantized_value = QuantizedValue(weight.name, q_weight_name, scale_name, zp_name, - QuantizedValueType.Initializer, None) + quantized_value = QuantizedValue( + weight.name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) self.quantized_value_map[weight.name] = quantized_value return q_weight_name, zp_name, scale_name - def quantize_weight_per_channel(self, weight_name, weight_qType, channel_axis, reduce_range=True, - keep_float_weight=False): + def quantize_weight_per_channel( + self, + weight_name, + weight_qType, + channel_axis, + reduce_range=True, + keep_float_weight=False, + ): # Find if this input is already quantized if weight_name in self.quantized_value_map: quantized_value = self.quantized_value_map[weight_name] - return (quantized_value.q_name, quantized_value.zp_name, quantized_value.scale_name) + return ( + quantized_value.q_name, + quantized_value.zp_name, + quantized_value.scale_name, + ) initializer = find_by_name(weight_name, self.model.initializer()) if initializer is None: @@ -778,8 +990,11 @@ def quantize_weight_per_channel(self, weight_name, weight_qType, channel_axis, r for i in range(channel_count): per_channel_data = weights.take(i, channel_axis) rmin, rmax, zero_point, scale, quantized_per_channel_data = quantize_data( - per_channel_data.flatten().tolist(), weight_qType, - self.is_weight_symmetric or weight_qType == onnx_proto.TensorProto.INT8, self.reduce_range and reduce_range) + per_channel_data.flatten().tolist(), + weight_qType, + self.is_weight_symmetric or weight_qType == onnx_proto.TensorProto.INT8, + self.reduce_range and reduce_range, + ) rmin_list.append(rmin) rmax_list.append(rmax) zero_point_list.append(zero_point) @@ -798,56 +1013,70 @@ def quantize_weight_per_channel(self, weight_name, weight_qType, channel_axis, r zp_name = weight_name + "_zero_point" scale_name = weight_name + "_scale" - quantized_value = QuantizedValue(weight_name, q_weight_name, scale_name, zp_name, - QuantizedValueType.Initializer, None) + quantized_value = QuantizedValue( + weight_name, + q_weight_name, + scale_name, + zp_name, + QuantizedValueType.Initializer, + None, + ) self.quantized_value_map[weight_name] = quantized_value # Update packed weight, zero point, and scale initializers zero_scale_shape = [initializer.dims[channel_axis]] - scale_initializer = onnx.helper.make_tensor(scale_name, onnx_proto.TensorProto.FLOAT, zero_scale_shape, - scale_list) + scale_initializer = onnx.helper.make_tensor( + scale_name, onnx_proto.TensorProto.FLOAT, zero_scale_shape, scale_list + ) zero_initializer = onnx.helper.make_tensor(zp_name, weight_qType, zero_scale_shape, zero_point_list) self.model.initializer().extend([scale_initializer, zero_initializer]) if not keep_float_weight: quantized_weights = np.asarray( - quantized_weights, dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType]).reshape(initializer.dims) + quantized_weights, + dtype=onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[weight_qType], + ).reshape(initializer.dims) q_weight_initializer = onnx.numpy_helper.from_array(quantized_weights, q_weight_name) self.model.initializer().extend([q_weight_initializer]) return (q_weight_name, zp_name, scale_name) def _dequantize_value(self, value_name): - ''' + """ Given a value (input/output) which is quantized, add a DequantizeLinear node to dequantize it back to float32 parameter value_name: value to dequantize parameter new_nodes_list: List of new nodes created before processing current node return: None if there is already a DequantizeLinear node that dequantizes it A DequantizeLinear node otherwise - ''' + """ if (value_name in self.quantized_value_map) and (value_name not in self.generated_value_names): quantized_value = self.quantized_value_map[value_name] # Add DequantizeLinear Node for this input dqlinear_name = value_name + "_DequantizeLinear" dqlinear_node = self.model.find_node_by_name(dqlinear_name, self.new_nodes, self.model.graph()) if dqlinear_node is None: - dqlinear_inputs = [quantized_value.q_name, quantized_value.scale_name, quantized_value.zp_name] - dequantize_node = onnx.helper.make_node("DequantizeLinear", dqlinear_inputs, [value_name], - dqlinear_name) + dqlinear_inputs = [ + quantized_value.q_name, + quantized_value.scale_name, + quantized_value.zp_name, + ] + dequantize_node = onnx.helper.make_node( + "DequantizeLinear", dqlinear_inputs, [value_name], dqlinear_name + ) return dequantize_node else: # DQ op is already present, assert it's output matches the input of current node - assert (value_name == dqlinear_node.output[0]) + assert value_name == dqlinear_node.output[0] return None def _dequantize_outputs(self): - ''' + """ Dequantize output if it is quantized parameter new_nodes_list: List of new nodes created before processing current node return: List of new nodes created - ''' + """ for output in self.model.graph().output: dequantize_node = self._dequantize_value(output.name) @@ -860,7 +1089,7 @@ def calculate_quantization_params(self): # adjust tensor_ranges for input of Clip and Relu node for node in self.model.nodes(): - if node.op_type not in ['Clip', 'Relu']: + if node.op_type not in ["Clip", "Relu"]: continue if not self.should_quantize(node): continue @@ -875,40 +1104,49 @@ def calculate_quantization_params(self): rmin, rmax = self.tensors_range[tensor_name] qmin, qmax = get_qmin_qmax_for_qType(self.input_qType, symmetric=self.is_activation_symmetric) - quantization_params[tensor_name] = compute_scale_zp(rmin, rmax, - qmin, qmax, - self.is_activation_symmetric) + quantization_params[tensor_name] = compute_scale_zp(rmin, rmax, qmin, qmax, self.is_activation_symmetric) return quantization_params - # static method def CleanGraphInitializers(graph, model): - ''' + """ Clean unused initializers including which is caused by quantizing the model. return cleaned graph, and list of tensor names from this graph and all its subgraphes that can not be found in this graph and its subgraphes - ''' + """ requesting_tensor_names = {} - requesting_tensor_names.update({input_name: 1 for node in graph.node for input_name in node.input if input_name}) + requesting_tensor_names.update( + {input_name: 1 for node in graph.node for input_name in node.input if input_name} + ) requesting_tensor_names.update({g_out.name: 1 for g_out in graph.output if g_out.name}) new_nodes = [] for node in graph.node: node_2_add = node - graph_attrs = [attr for attr in node.attribute if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS] + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] if len(graph_attrs) > 0: kwargs = {} for attr in node.attribute: kv = {} if attr.type == onnx.AttributeProto.GRAPH: - cleaned_sub_graph, sub_requesting_tensor_names = ONNXQuantizer.CleanGraphInitializers(attr.g, model) + ( + cleaned_sub_graph, + sub_requesting_tensor_names, + ) = ONNXQuantizer.CleanGraphInitializers(attr.g, model) kv = {attr.name: cleaned_sub_graph} requesting_tensor_names.update({gn: 1 for gn in sub_requesting_tensor_names}) elif attr.type == onnx.AttributeProto.GRAPHS: cleaned_graphes = [] for subgraph in attr.graphs: - cleaned_sub_graph, sub_requesting_tensor_names = ONNXQuantizer.CleanGraphInitializers(subgraph, model) + ( + cleaned_sub_graph, + sub_requesting_tensor_names, + ) = ONNXQuantizer.CleanGraphInitializers(subgraph, model) cleaned_graphes.extend([cleaned_sub_graph]) requesting_tensor_names.update({gn: 1 for gn in sub_requesting_tensor_names}) kv = {attr.name: cleaned_graphes} @@ -918,7 +1156,7 @@ def CleanGraphInitializers(graph, model): node_2_add = onnx.helper.make_node(node.op_type, node.input, node.output, name=node.name, **kwargs) new_nodes.extend([node_2_add]) - graph.ClearField('node') + graph.ClearField("node") graph.node.extend(new_nodes) generated_names = {} @@ -945,7 +1183,11 @@ def CleanGraphInitializers(graph, model): graph.input.remove(name_to_input[ini_tensor.name]) except StopIteration: if model.ir_version < 4: - print("Warning: invalid weight name {} found in the graph (not a graph input)".format(ini_tensor.name)) + print( + "Warning: invalid weight name {} found in the graph (not a graph input)".format( + ini_tensor.name + ) + ) for input in graph.input: if input.name in requesting_tensor_names: diff --git a/onnxruntime/python/tools/quantization/operators/__init__.py b/onnxruntime/python/tools/quantization/operators/__init__.py index 0a0ac70418765..b2830b3dfa1da 100644 --- a/onnxruntime/python/tools/quantization/operators/__init__.py +++ b/onnxruntime/python/tools/quantization/operators/__init__.py @@ -1,2 +1,2 @@ -#from .base_operator import QuantOperatorBase -#from .matmul import MatMulInteger +# from .base_operator import QuantOperatorBase +# from .matmul import MatMulInteger diff --git a/onnxruntime/python/tools/quantization/operators/activation.py b/onnxruntime/python/tools/quantization/operators/activation.py index 6751c03d46835..3df1d1c631522 100644 --- a/onnxruntime/python/tools/quantization/operators/activation.py +++ b/onnxruntime/python/tools/quantization/operators/activation.py @@ -1,8 +1,9 @@ import onnx +from onnx import onnx_pb as onnx_proto + +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase from .qdq_base_operator import QDQOperatorBase -from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain -from onnx import onnx_pb as onnx_proto class QLinearActivation(QuantOperatorBase): @@ -11,7 +12,7 @@ def __init__(self, onnx_quantizer, onnx_node): def QuantizeClipRelu(self): node = self.node - assert (node.op_type == "Relu" or node.op_type == 'Clip') + assert node.op_type == "Relu" or node.op_type == "Clip" # When mode is QLinearOps, the output quantization params are calculated based on outputs from # activation nodes, therefore these nodes can be removed from the graph if they follow a quantized op. @@ -25,22 +26,34 @@ def QuantizeClipRelu(self): def quantize(self): node = self.node - if node.op_type == "Relu" or node.op_type == 'Clip': + if node.op_type == "Relu" or node.op_type == "Clip": self.QuantizeClipRelu() return - nnapi_sigmoid_option = 'extra.Sigmoid.nnapi' - sigmoid_nnapi_mode = (node.op_type == 'Sigmoid' and - nnapi_sigmoid_option in self.quantizer.extra_options and - self.quantizer.extra_options[nnapi_sigmoid_option]) + nnapi_sigmoid_option = "extra.Sigmoid.nnapi" + sigmoid_nnapi_mode = ( + node.op_type == "Sigmoid" + and nnapi_sigmoid_option in self.quantizer.extra_options + and self.quantizer.extra_options[nnapi_sigmoid_option] + ) use_scale = 1 / 256.0 if sigmoid_nnapi_mode else None use_zeropoint = 0 if sigmoid_nnapi_mode else None # No assert on op_type as it is controlled by registry # only try to quantize when given quantization parameters for it - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0], use_scale, use_zeropoint) - quantized_input_names, zero_point_names, scale_names, nodes = self.quantizer.quantize_inputs(node, [0]) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0], use_scale, use_zeropoint) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0]) if not data_found or quantized_input_names is None: return super().quantize() @@ -54,15 +67,29 @@ def quantize(self): kwargs["domain"] = ms_domain qlinear_activation_inputs = [ - quantized_input_names[0], scale_names[0], zero_point_names[0], output_scale_name, output_zp_name + quantized_input_names[0], + scale_names[0], + zero_point_names[0], + output_scale_name, + output_zp_name, ] - qlinear_activation_node = onnx.helper.make_node("QLinear" + node.op_type, qlinear_activation_inputs, - [qlinear_activation_output], qlinear_activation_name, **kwargs) + qlinear_activation_node = onnx.helper.make_node( + "QLinear" + node.op_type, + qlinear_activation_inputs, + [qlinear_activation_output], + qlinear_activation_name, + **kwargs + ) # Create an entry for this quantized value - q_output = QuantizedValue(node.output[0], qlinear_activation_output, output_scale_name, output_zp_name, - QuantizedValueType.Input) + q_output = QuantizedValue( + node.output[0], + qlinear_activation_output, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = q_output nodes.append(qlinear_activation_node) diff --git a/onnxruntime/python/tools/quantization/operators/argmax.py b/onnxruntime/python/tools/quantization/operators/argmax.py index d24c5cee7082f..66cd68214920b 100644 --- a/onnxruntime/python/tools/quantization/operators/argmax.py +++ b/onnxruntime/python/tools/quantization/operators/argmax.py @@ -1,5 +1,6 @@ from .base_operator import QuantOperatorBase + # Use the quantized tensor as input without DQ. class QArgMax(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): @@ -14,4 +15,4 @@ def quantize(self): return node.input[0] = quantized_input_value.q_name - self.quantizer.new_nodes += [node] \ No newline at end of file + self.quantizer.new_nodes += [node] diff --git a/onnxruntime/python/tools/quantization/operators/attention.py b/onnxruntime/python/tools/quantization/operators/attention.py index 078605f378437..0b29f5a2e1e13 100644 --- a/onnxruntime/python/tools/quantization/operators/attention.py +++ b/onnxruntime/python/tools/quantization/operators/attention.py @@ -1,10 +1,12 @@ import onnx -from .base_operator import QuantOperatorBase -from ..quant_utils import attribute_to_kwarg, ms_domain from onnx import onnx_pb as onnx_proto -''' + +from ..quant_utils import attribute_to_kwarg, ms_domain +from .base_operator import QuantOperatorBase + +""" Quantize Attention -''' +""" class AttentionQuant(QuantOperatorBase): @@ -12,23 +14,27 @@ def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def quantize(self): - ''' - parameter node: Attention node. - parameter new_nodes_list: List of new nodes created before processing this node. - return: a list of nodes in topological order that represents quantized Attention node. - ''' + """ + parameter node: Attention node. + parameter new_nodes_list: List of new nodes created before processing this node. + return: a list of nodes in topological order that represents quantized Attention node. + """ node = self.node - assert (node.op_type == "Attention") + assert node.op_type == "Attention" # TODO This is a temporary fix to stop exporting QAttention with qkv_hidden_sizes # attribute. This needs to be removed once the QAttention for varied q,k,v sizes # is implemented for attr in node.attribute: - if 'qkv_hidden_sizes' == attr.name: + if "qkv_hidden_sizes" == attr.name: return super().quantize() - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], reduce_range=True, op_level_per_channel=True) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], reduce_range=True, op_level_per_channel=True) if quantized_input_names is None: return super().quantize() diff --git a/onnxruntime/python/tools/quantization/operators/base_operator.py b/onnxruntime/python/tools/quantization/operators/base_operator.py index f6b510ff675ff..dd6a7a1a7000a 100644 --- a/onnxruntime/python/tools/quantization/operators/base_operator.py +++ b/onnxruntime/python/tools/quantization/operators/base_operator.py @@ -4,13 +4,13 @@ def __init__(self, onnx_quantizer, onnx_node): self.node = onnx_node def quantize(self): - ''' + """ Given a node which does not support quantization, this method checks whether the input to this node is quantized and adds a DequantizeLinear node to dequantize this input back to FP32 parameter node: Current node parameter new_nodes_list: List of new nodes created before processing current node return: List of new nodes created - ''' + """ nodes = [] for index, node_input in enumerate(self.node.input): dequantize_node = self.quantizer._dequantize_value(node_input) @@ -18,4 +18,4 @@ def quantize(self): self.quantizer.new_nodes.append(dequantize_node) # Append the original node - self.quantizer.new_nodes.append(self.node) \ No newline at end of file + self.quantizer.new_nodes.append(self.node) diff --git a/onnxruntime/python/tools/quantization/operators/binary_op.py b/onnxruntime/python/tools/quantization/operators/binary_op.py index d2097eb0234a7..7cfd69c204037 100644 --- a/onnxruntime/python/tools/quantization/operators/binary_op.py +++ b/onnxruntime/python/tools/quantization/operators/binary_op.py @@ -1,8 +1,9 @@ import onnx -from .base_operator import QuantOperatorBase -from ..quant_utils import attribute_to_kwarg, ms_domain, QuantizedValue, QuantizedValueType from onnx import onnx_pb as onnx_proto +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain +from .base_operator import QuantOperatorBase + class QLinearBinaryOp(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): @@ -11,10 +12,19 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], initializer_use_weight_qType=False) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], initializer_use_weight_qType=False) if not data_found or quantized_input_names is None: return super().quantize() @@ -40,14 +50,23 @@ def quantize(self): qlinear_binary_math_inputs.append(output_scale_name) qlinear_binary_math_inputs.append(output_zp_name) - qlinear_binary_math_node = onnx.helper.make_node("QLinear" + node.op_type, qlinear_binary_math_inputs, - [qlinear_binary_math_output], qlinear_binary_math_name, - **kwargs) + qlinear_binary_math_node = onnx.helper.make_node( + "QLinear" + node.op_type, + qlinear_binary_math_inputs, + [qlinear_binary_math_output], + qlinear_binary_math_name, + **kwargs + ) nodes.append(qlinear_binary_math_node) # Create an entry for this quantized value - q_output = QuantizedValue(node.output[0], qlinear_binary_math_output, output_scale_name, output_zp_name, - QuantizedValueType.Input) + q_output = QuantizedValue( + node.output[0], + qlinear_binary_math_output, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = q_output self.quantizer.new_nodes += nodes diff --git a/onnxruntime/python/tools/quantization/operators/concat.py b/onnxruntime/python/tools/quantization/operators/concat.py index 76c05828e5f40..2781e2600218e 100644 --- a/onnxruntime/python/tools/quantization/operators/concat.py +++ b/onnxruntime/python/tools/quantization/operators/concat.py @@ -1,7 +1,9 @@ import onnx + +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase from .qdq_base_operator import QDQOperatorBase -from ..quant_utils import QuantizedValue, attribute_to_kwarg, ms_domain, QuantizedValueType + class QLinearConcat(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): @@ -10,18 +12,31 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) - (q_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [*range(0, len(node.input))], initializer_use_weight_qType=False) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) + ( + q_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [*range(0, len(node.input))], initializer_use_weight_qType=False) if not data_found or q_input_names is None: return super().quantize() # Create an entry for output quantized value quantized_input_value = self.quantizer.quantized_value_map[node.input[0]] - quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", - output_scale_name, output_zp_name, - quantized_input_value.value_type) + quantized_output_value = QuantizedValue( + node.output[0], + node.output[0] + "_quantized", + output_scale_name, + output_zp_name, + quantized_input_value.value_type, + ) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value kwargs = {} @@ -33,11 +48,14 @@ def quantize(self): qlconcat_inputs = [output_scale_name, output_zp_name] for i in range(0, len(q_input_names)): qlconcat_inputs.extend([q_input_names[i], scale_names[i], zero_point_names[i]]) - qlconcat_node = onnx.helper.make_node("QLinearConcat", qlconcat_inputs, [quantized_output_value.q_name], qnode_name, **kwargs) + qlconcat_node = onnx.helper.make_node( + "QLinearConcat", qlconcat_inputs, [quantized_output_value.q_name], qnode_name, **kwargs + ) self.quantizer.new_nodes += nodes self.quantizer.new_nodes += [qlconcat_node] + class QDQConcat(QDQOperatorBase): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) diff --git a/onnxruntime/python/tools/quantization/operators/conv.py b/onnxruntime/python/tools/quantization/operators/conv.py index f5cb66885dd8b..f0eeababed129 100644 --- a/onnxruntime/python/tools/quantization/operators/conv.py +++ b/onnxruntime/python/tools/quantization/operators/conv.py @@ -1,9 +1,17 @@ -import onnx import numpy as np +import onnx +from onnx import onnx_pb as onnx_proto + +from ..quant_utils import ( + BiasToQuantize, + QuantizedValue, + QuantizedValueType, + attribute_to_kwarg, + find_by_name, + get_mul_node, +) from .base_operator import QuantOperatorBase from .qdq_base_operator import QDQOperatorBase -from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType, attribute_to_kwarg, BiasToQuantize -from onnx import onnx_pb as onnx_proto class ConvInteger(QuantOperatorBase): @@ -11,7 +19,7 @@ def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def add_bias(self, nodes, scaled_output): - ''' + """ Given a node, this function handles bias add by adding a "reshape" node on bias and an "add" node parameter nodes: new nodes would be appended into nodes parameter node: current node (Conv) @@ -19,7 +27,7 @@ def add_bias(self, nodes, scaled_output): parameter output: output of Conv parameter bias_name: bias of Conv return: the name of output - ''' + """ node = self.node model = self.quantizer.model # Add tensors for the shape to be reshaped to @@ -29,14 +37,15 @@ def add_bias(self, nodes, scaled_output): # Add reshape for correct broadcase output = node.output[0] - reshape_input_data = node.input[2] # bias of Conv + reshape_input_data = node.input[2] # bias of Conv reshape_input_shape = output + "_bias_reshape_shape" reshape_output = output + "_bias_reshape_output" shape = np.ones((len(weight.dims)), dtype=np.int64) shape[1] = -1 - init_shape = onnx.helper.make_tensor(reshape_input_shape, onnx_proto.TensorProto.INT64, [len(weight.dims)], - shape) + init_shape = onnx.helper.make_tensor( + reshape_input_shape, onnx_proto.TensorProto.INT64, [len(weight.dims)], shape + ) model.add_initializer(init_shape) reshape_node = onnx.helper.make_node("Reshape", [reshape_input_data, reshape_input_shape], [reshape_output]) @@ -48,10 +57,14 @@ def add_bias(self, nodes, scaled_output): def quantize(self): node = self.node - assert (node.op_type == "Conv") + assert node.op_type == "Conv" - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range) conv_integer_output = node.output[0] + "_output_quantized" conv_integer_name = node.name + "_quant" if node.name != "" else "" @@ -59,19 +72,24 @@ def quantize(self): kwargs = {} for attribute in node.attribute: kwargs.update(attribute_to_kwarg(attribute)) - conv_integer_node = onnx.helper.make_node("ConvInteger", quantized_input_names + zero_point_names, - [conv_integer_output], conv_integer_name, **kwargs) + conv_integer_node = onnx.helper.make_node( + "ConvInteger", quantized_input_names + zero_point_names, [conv_integer_output], conv_integer_name, **kwargs + ) nodes.append(conv_integer_node) # Add cast operation to cast convInteger output to float. cast_op_output = conv_integer_output + "_cast_output" - cast_node = onnx.helper.make_node("Cast", [conv_integer_output], [cast_op_output], - conv_integer_output + "_cast", - to=onnx_proto.TensorProto.FLOAT) + cast_node = onnx.helper.make_node( + "Cast", + [conv_integer_output], + [cast_op_output], + conv_integer_output + "_cast", + to=onnx_proto.TensorProto.FLOAT, + ) nodes.append(cast_node) # Add mul operation to multiply scales of two inputs. - assert (len(scale_names) == 2) + assert len(scale_names) == 2 if conv_integer_name != "": scales_mul_op = conv_integer_name + "_scales_mul" else: @@ -90,7 +108,13 @@ def quantize(self): # Add mul operation to multiply mul_scales_op result with output of ConvInteger # and make the output of this node the same as output of original conv node. output_scale_mul_op = conv_integer_name + "_output_scale_mul" if conv_integer_name != "" else "" - nodes.append(get_mul_node([cast_op_output, scales_mul_op_output], scaled_output_name, output_scale_mul_op)) + nodes.append( + get_mul_node( + [cast_op_output, scales_mul_op_output], + scaled_output_name, + output_scale_mul_op, + ) + ) if has_bias: self.add_bias(nodes, scaled_output_name) @@ -104,22 +128,36 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Conv") + assert node.op_type == "Conv" - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) if self.quantizer.is_input_a_weight(node.input[1]) and self.quantizer.is_per_channel(): - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0], reduce_range=self.quantizer.reduce_range) - quant_weight_tuple = self.quantizer.quantize_weight_per_channel(node.input[1], onnx_proto.TensorProto.INT8, - 0) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0], reduce_range=self.quantizer.reduce_range) + quant_weight_tuple = self.quantizer.quantize_weight_per_channel( + node.input[1], onnx_proto.TensorProto.INT8, 0 + ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) scale_names.append(quant_weight_tuple[2]) else: - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range) if not data_found or quantized_input_names is None: return super().quantize() @@ -153,13 +191,19 @@ def quantize(self): if bias_present: qlinear_conv_inputs.append(quantized_bias_name) - qlinear_conv_node = onnx.helper.make_node("QLinearConv", qlinear_conv_inputs, [qlinear_conv_output], - qlinear_conv_name, **kwargs) + qlinear_conv_node = onnx.helper.make_node( + "QLinearConv", qlinear_conv_inputs, [qlinear_conv_output], qlinear_conv_name, **kwargs + ) nodes.append(qlinear_conv_node) # Create an entry for this quantized value - q_output = QuantizedValue(node.output[0], qlinear_conv_output, output_scale_name, output_zp_name, - QuantizedValueType.Input) + q_output = QuantizedValue( + node.output[0], + qlinear_conv_output, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = q_output self.quantizer.new_nodes += nodes @@ -171,7 +215,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Conv") + assert node.op_type == "Conv" self.quantizer.quantize_tensor(node.input[0]) if not self.disable_qdq_for_node_output: diff --git a/onnxruntime/python/tools/quantization/operators/direct_q8.py b/onnxruntime/python/tools/quantization/operators/direct_q8.py index cabe709b20512..33dc04a5dddaf 100644 --- a/onnxruntime/python/tools/quantization/operators/direct_q8.py +++ b/onnxruntime/python/tools/quantization/operators/direct_q8.py @@ -1,6 +1,7 @@ +from ..quant_utils import QuantizedValue, QuantizedValueType from .base_operator import QuantOperatorBase from .qdq_base_operator import QDQOperatorBase -from ..quant_utils import QuantizedValue, QuantizedValueType + # For operators that support 8bits operations directly, and output could # reuse input[0]'s type, zeropoint, scale; For example,Transpose, Reshape, etc. @@ -19,9 +20,13 @@ def quantize(self): self.quantizer.new_nodes += [node] return - quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", - quantized_input_value.scale_name, quantized_input_value.zp_name, - quantized_input_value.value_type) + quantized_output_value = QuantizedValue( + node.output[0], + node.output[0] + "_quantized", + quantized_input_value.scale_name, + quantized_input_value.zp_name, + quantized_input_value.value_type, + ) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value node.input[0] = quantized_input_value.q_name @@ -30,19 +35,27 @@ def quantize(self): else: # Force quantize those ops if possible, use exclude node list if this is not you want - if (not self.quantizer.is_valid_quantize_weight(node.input[0])): + if not self.quantizer.is_valid_quantize_weight(node.input[0]): super().quantize() return - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0]) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0]) if quantized_input_names is None: return super().quantize() # Create an entry for output quantized value - quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", - scale_names[0], zero_point_names[0], - QuantizedValueType.Input) + quantized_output_value = QuantizedValue( + node.output[0], + node.output[0] + "_quantized", + scale_names[0], + zero_point_names[0], + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value node.input[0] = quantized_input_names[0] @@ -52,7 +65,6 @@ def quantize(self): self.quantizer.new_nodes += nodes - class QDQDirect8BitOp(QDQOperatorBase): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) diff --git a/onnxruntime/python/tools/quantization/operators/embed_layernorm.py b/onnxruntime/python/tools/quantization/operators/embed_layernorm.py index 1e96b7bd7508b..28e0fe3eb86e5 100644 --- a/onnxruntime/python/tools/quantization/operators/embed_layernorm.py +++ b/onnxruntime/python/tools/quantization/operators/embed_layernorm.py @@ -1,28 +1,32 @@ -import onnx import logging -from .base_operator import QuantOperatorBase -from ..quant_utils import attribute_to_kwarg, ms_domain + +import onnx from onnx import onnx_pb as onnx_proto -''' +from ..quant_utils import attribute_to_kwarg, ms_domain +from .base_operator import QuantOperatorBase + +""" Quantizes the EmbedLayerNorm fused ONNXRuntime Op. This Quant operator keeps the input and segment IDs at int32 but will quantize all initializer and weight inputs associated with the node to uint8. -''' +""" + + class EmbedLayerNormalizationQuant(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node - assert (node.op_type == "EmbedLayerNormalization") + assert node.op_type == "EmbedLayerNormalization" if len(node.output) > 2: logging.info(f"Quantization is not applied to {node.name} since it has 3 outputs") return super().quantize() - ''' + """ Pre-quantization EmbedLayerNorm inputs: [0] input_ids (int32) [1] segment_ids (int32) @@ -32,15 +36,19 @@ def quantize(self): [5] gamma (float32) [6] beta (float32) [7] mask (int32) (optional) - ''' - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [2, 3, 4, 5, 6]) + """ + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [2, 3, 4, 5, 6]) if quantized_input_names is None: return super().quantize() qembed_layer_norm_name = "" if node.name == "" else node.name + "_quant" - ''' + """ Quantized Input Tensor List [0] input_ids (int32) [1] segment_ids (int32) @@ -60,7 +68,7 @@ def quantize(self): [15] segment_embedding_zero_point (uint8) [16] gamma_zero_point (uint8) [17] beta_zero_point (uint8) - ''' + """ inputs = [] # 'input_ids' inputs.extend([node.input[0]]) @@ -98,8 +106,13 @@ def quantize(self): kwargs.update(attribute_to_kwarg(attribute)) kwargs["domain"] = ms_domain - qembed_layer_norm_node = onnx.helper.make_node("QEmbedLayerNormalization", inputs, node.output, - qembed_layer_norm_name, **kwargs) + qembed_layer_norm_node = onnx.helper.make_node( + "QEmbedLayerNormalization", + inputs, + node.output, + qembed_layer_norm_name, + **kwargs, + ) nodes.append(qembed_layer_norm_node) self.quantizer.new_nodes += nodes diff --git a/onnxruntime/python/tools/quantization/operators/gather.py b/onnxruntime/python/tools/quantization/operators/gather.py index 495f53b7f9868..c48804b99d283 100644 --- a/onnxruntime/python/tools/quantization/operators/gather.py +++ b/onnxruntime/python/tools/quantization/operators/gather.py @@ -1,10 +1,12 @@ import onnx -from .base_operator import QuantOperatorBase -from ..quant_utils import QuantizedValue, QuantizedValueType from onnx import onnx_pb as onnx_proto -''' + +from ..quant_utils import QuantizedValue, QuantizedValueType +from .base_operator import QuantOperatorBase + +""" Quantize Gather -''' +""" class GatherQuant(QuantOperatorBase): @@ -13,21 +15,30 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Gather") - if (not self.quantizer.is_valid_quantize_weight(node.input[0])): + assert node.op_type == "Gather" + if not self.quantizer.is_valid_quantize_weight(node.input[0]): super().quantize() return - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0]) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0]) if quantized_input_names is None: return super().quantize() gather_new_output = node.output[0] + "_quantized" # Create an entry for this quantized value - q_output = QuantizedValue(node.output[0], gather_new_output, scale_names[0], zero_point_names[0], - QuantizedValueType.Input) + q_output = QuantizedValue( + node.output[0], + gather_new_output, + scale_names[0], + zero_point_names[0], + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = q_output gather_original_output = node.output[0] diff --git a/onnxruntime/python/tools/quantization/operators/gavgpool.py b/onnxruntime/python/tools/quantization/operators/gavgpool.py index c34fd5b37635b..7527685072342 100644 --- a/onnxruntime/python/tools/quantization/operators/gavgpool.py +++ b/onnxruntime/python/tools/quantization/operators/gavgpool.py @@ -1,6 +1,7 @@ import onnx + +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase -from ..quant_utils import attribute_to_kwarg, ms_domain, QuantizedValue, QuantizedValueType class QGlobalAveragePool(QuantOperatorBase): @@ -9,7 +10,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "GlobalAveragePool") + assert node.op_type == "GlobalAveragePool" # If input to this node is not quantized then keep this node. if node.input[0] not in self.quantizer.quantized_value_map: @@ -19,13 +20,23 @@ def quantize(self): # Create an entry for output quantized value. quantized_input_value = self.quantizer.quantized_value_map[node.input[0]] - data_found, output_scale_name_from_parameter, output_zp_name_from_parameter, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) + ( + data_found, + output_scale_name_from_parameter, + output_zp_name_from_parameter, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) # Just use input scale and zp if parameters for output is not specified. output_scale_name = output_scale_name_from_parameter if data_found else quantized_input_value.scale_name output_zp_name = output_zp_name_from_parameter if data_found else quantized_input_value.zp_name - quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", output_scale_name, - output_zp_name, QuantizedValueType.Input) + quantized_output_value = QuantizedValue( + node.output[0], + node.output[0] + "_quantized", + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value kwargs = {} @@ -35,8 +46,17 @@ def quantize(self): kwargs["channels_last"] = 0 qnode_name = node.name + "_quant" if node.name != "" else "" - qnode = onnx.helper.make_node("QLinear" + node.op_type, [ - quantized_input_value.q_name, quantized_input_value.scale_name, quantized_input_value.zp_name, - output_scale_name, output_zp_name - ], [quantized_output_value.q_name], qnode_name, **kwargs) + qnode = onnx.helper.make_node( + "QLinear" + node.op_type, + [ + quantized_input_value.q_name, + quantized_input_value.scale_name, + quantized_input_value.zp_name, + output_scale_name, + output_zp_name, + ], + [quantized_output_value.q_name], + qnode_name, + **kwargs + ) self.quantizer.new_nodes += [qnode] diff --git a/onnxruntime/python/tools/quantization/operators/gemm.py b/onnxruntime/python/tools/quantization/operators/gemm.py index f297bfb428a19..ee804c742c038 100644 --- a/onnxruntime/python/tools/quantization/operators/gemm.py +++ b/onnxruntime/python/tools/quantization/operators/gemm.py @@ -1,55 +1,76 @@ -import onnx -import numpy as np import logging + +import numpy as np +import onnx +from onnx import onnx_pb as onnx_proto + +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, find_by_name, get_mul_node, ms_domain from .base_operator import QuantOperatorBase from .qdq_base_operator import QDQOperatorBase -from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain -from onnx import onnx_pb as onnx_proto def is_B_transposed(gemm_node): - transB_attribute = [attr for attr in gemm_node.attribute if attr.name == 'transB'] + transB_attribute = [attr for attr in gemm_node.attribute if attr.name == "transB"] if len(transB_attribute): return 0 < onnx.helper.get_attribute_value(transB_attribute[0]) return False + def get_beta(gemm_node): - beta_attribute = [attr for attr in gemm_node.attribute if attr.name == 'beta'] + beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"] if len(beta_attribute): return onnx.helper.get_attribute_value(beta_attribute[0]) return 1.0 + def set_default_beta(gemm_node): - beta_attribute = [attr for attr in gemm_node.attribute if attr.name == 'beta'] + beta_attribute = [attr for attr in gemm_node.attribute if attr.name == "beta"] if len(beta_attribute): beta_attribute[0].f = 1.0 return 1.0 + class QLinearGemm(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node - assert (node.op_type == "Gemm") + assert node.op_type == "Gemm" - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) if self.quantizer.is_input_a_weight(node.input[1]) and self.quantizer.is_per_channel(): - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0], reduce_range=self.quantizer.reduce_range) - quant_weight_tuple = self.quantizer.quantize_weight_per_channel(node.input[1], onnx_proto.TensorProto.INT8, - 0 if is_B_transposed(node) else 1) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0], reduce_range=self.quantizer.reduce_range) + quant_weight_tuple = self.quantizer.quantize_weight_per_channel( + node.input[1], + onnx_proto.TensorProto.INT8, + 0 if is_B_transposed(node) else 1, + ) quantized_input_names.append(quant_weight_tuple[0]) zero_point_names.append(quant_weight_tuple[1]) scale_names.append(quant_weight_tuple[2]) else: - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], reduce_range=self.quantizer.reduce_range) if not data_found or quantized_input_names is None: return super().quantize() @@ -59,7 +80,9 @@ def quantize(self): if not self.quantizer.is_input_a_weight(node.input[2]): return super().quantize() - quantized_bias_name = self.quantizer.quantize_bias_static(node.input[2], node.input[0], node.input[1], get_beta(self.node)) + quantized_bias_name = self.quantizer.quantize_bias_static( + node.input[2], node.input[0], node.input[1], get_beta(self.node) + ) qgemm_output = node.output[0] + "_quantized" qgemm_name = qgemm_name = node.name + "_quant" if node.name != "" else "" @@ -77,13 +100,17 @@ def quantize(self): qgemm_inputs.extend([quantized_bias_name, output_scale_name, output_zp_name]) - qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], - qgemm_name, **kwargs) + qgemm_node = onnx.helper.make_node("QGemm", qgemm_inputs, [qgemm_output], qgemm_name, **kwargs) nodes.append(qgemm_node) # Create an entry for this quantized value - q_output = QuantizedValue(node.output[0], qgemm_output, output_scale_name, output_zp_name, - QuantizedValueType.Input) + q_output = QuantizedValue( + node.output[0], + qgemm_output, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = q_output self.quantizer.new_nodes += nodes @@ -95,7 +122,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Gemm") + assert node.op_type == "Gemm" self.quantizer.quantize_tensor(node.input[0]) if not self.disable_qdq_for_node_output: @@ -112,6 +139,7 @@ def quantize(self): set_default_beta(self.node) else: logging.warning( - "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance." - .format(self.node.name)) - + "Bias of Gemm node '{}' is not constant. Please exclude this node for better performance.".format( + self.node.name + ) + ) diff --git a/onnxruntime/python/tools/quantization/operators/lstm.py b/onnxruntime/python/tools/quantization/operators/lstm.py index 3389a9bb638b0..87552a18a037e 100644 --- a/onnxruntime/python/tools/quantization/operators/lstm.py +++ b/onnxruntime/python/tools/quantization/operators/lstm.py @@ -1,11 +1,13 @@ -import onnx import numpy -from .base_operator import QuantOperatorBase -from ..quant_utils import attribute_to_kwarg, ms_domain, QuantType +import onnx from onnx import onnx_pb as onnx_proto -''' + +from ..quant_utils import QuantType, attribute_to_kwarg, ms_domain +from .base_operator import QuantOperatorBase + +""" Quantize LSTM -''' +""" class LSTMQuant(QuantOperatorBase): @@ -13,16 +15,17 @@ def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def quantize(self): - ''' - parameter node: LSTM node. - parameter new_nodes_list: List of new nodes created before processing this node. - return: a list of nodes in topological order that represents quantized Attention node. - ''' + """ + parameter node: LSTM node. + parameter new_nodes_list: List of new nodes created before processing this node. + return: a list of nodes in topological order that represents quantized Attention node. + """ node = self.node - assert (node.op_type == "LSTM") + assert node.op_type == "LSTM" - if (not self.quantizer.is_valid_quantize_weight(node.input[1]) - or not self.quantizer.is_valid_quantize_weight(node.input[2])): + if not self.quantizer.is_valid_quantize_weight(node.input[1]) or not self.quantizer.is_valid_quantize_weight( + node.input[2] + ): super().quantize() return @@ -30,7 +33,7 @@ def quantize(self): W = model.get_initializer(node.input[1]) R = model.get_initializer(node.input[2]) - if (len(W.dims) != 3 or len(R.dims) != 3): + if len(W.dims) != 3 or len(R.dims) != 3: super().quantize() return @@ -43,10 +46,12 @@ def quantize(self): W.dims[0] = W_num_dir * W_4_hidden_size R.dims[0] = R_num_dir * R_4_hidden_size - quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel(node.input[1], - onnx_proto.TensorProto.INT8, 0) - quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel(node.input[2], - onnx_proto.TensorProto.INT8, 0) + quant_input_weight_tuple = self.quantizer.quantize_weight_per_channel( + node.input[1], onnx_proto.TensorProto.INT8, 0 + ) + quant_recurrent_weight_tuple = self.quantizer.quantize_weight_per_channel( + node.input[2], onnx_proto.TensorProto.INT8, 0 + ) W_quant_weight = model.get_initializer(quant_input_weight_tuple[0]) R_quant_weight = model.get_initializer(quant_recurrent_weight_tuple[0]) @@ -87,10 +92,14 @@ def quantize(self): inputs.extend([node.input[5] if input_len > 5 else ""]) inputs.extend([node.input[6] if input_len > 6 else ""]) inputs.extend([node.input[7] if input_len > 7 else ""]) - inputs.extend([ - quant_input_weight_tuple[2], quant_input_weight_tuple[1], quant_recurrent_weight_tuple[2], - quant_recurrent_weight_tuple[1] - ]) + inputs.extend( + [ + quant_input_weight_tuple[2], + quant_input_weight_tuple[1], + quant_recurrent_weight_tuple[2], + quant_recurrent_weight_tuple[1], + ] + ) kwargs = {} for attribute in node.attribute: diff --git a/onnxruntime/python/tools/quantization/operators/matmul.py b/onnxruntime/python/tools/quantization/operators/matmul.py index 2d37eeb46e9ad..86a0202cdbce9 100644 --- a/onnxruntime/python/tools/quantization/operators/matmul.py +++ b/onnxruntime/python/tools/quantization/operators/matmul.py @@ -1,12 +1,15 @@ -import onnx import itertools + +import onnx +from onnx import onnx_pb as onnx_proto + +from ..quant_utils import QuantizedValue, QuantizedValueType, find_by_name, get_mul_node from .base_operator import QuantOperatorBase from .qdq_base_operator import QDQOperatorBase -from ..quant_utils import find_by_name, get_mul_node, QuantizedValue, QuantizedValueType -from onnx import onnx_pb as onnx_proto -''' + +""" Used when quantize mode is QuantizationMode.IntegerOps. -''' +""" class MatMulInteger(QuantOperatorBase): @@ -15,28 +18,43 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "MatMul") + assert node.op_type == "MatMul" - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], reduce_range=True, op_level_per_channel=True) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], reduce_range=True, op_level_per_channel=True) matmul_integer_output = node.output[0] + "_output_quantized" matmul_integer_name = node.name + "_quant" if node.name != "" else "" - matmul_integer_node = onnx.helper.make_node("MatMulInteger", quantized_input_names + zero_point_names, - [matmul_integer_output], matmul_integer_name) + matmul_integer_node = onnx.helper.make_node( + "MatMulInteger", + quantized_input_names + zero_point_names, + [matmul_integer_output], + matmul_integer_name, + ) nodes.append(matmul_integer_node) # Add cast operation to cast matmulInteger output to float. cast_op_output = matmul_integer_output + "_cast_output" - cast_node = onnx.helper.make_node("Cast", [matmul_integer_output], [cast_op_output], - matmul_integer_output + "_cast", - to=onnx_proto.TensorProto.FLOAT) + cast_node = onnx.helper.make_node( + "Cast", + [matmul_integer_output], + [cast_op_output], + matmul_integer_output + "_cast", + to=onnx_proto.TensorProto.FLOAT, + ) nodes.append(cast_node) # Add mul operation to multiply scales of two inputs. - assert (len(scale_names) == 2) - scales_mul_op = matmul_integer_name + "_scales_mul" if matmul_integer_name != "" else scale_names[ - 0] + "_" + scale_names[1] + "_mul" + assert len(scale_names) == 2 + scales_mul_op = ( + matmul_integer_name + "_scales_mul" + if matmul_integer_name != "" + else scale_names[0] + "_" + scale_names[1] + "_mul" + ) scales_mul_node = find_by_name(scales_mul_op, self.quantizer.new_nodes) if scales_mul_node is None: @@ -50,13 +68,19 @@ def quantize(self): output_scale_mul_op = "" if matmul_integer_name != "": output_scale_mul_op = matmul_integer_name + "_output_scale_mul" - nodes.append(get_mul_node([cast_op_output, scales_mul_op_output], node.output[0], output_scale_mul_op)) + nodes.append( + get_mul_node( + [cast_op_output, scales_mul_op_output], + node.output[0], + output_scale_mul_op, + ) + ) self.quantizer.new_nodes += nodes -''' +""" Used when quantize mode is QuantizationMode.QLinearOps -''' +""" class QLinearMatMul(QuantOperatorBase): @@ -65,12 +89,21 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "MatMul") - - (quantized_input_names, zero_point_names, scale_names, nodes) = \ - self.quantizer.quantize_inputs(node, [0, 1], reduce_range=True, op_level_per_channel=True) - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) + assert node.op_type == "MatMul" + + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0, 1], reduce_range=True, op_level_per_channel=True) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) if not data_found or quantized_input_names is None: return super().quantize() @@ -90,24 +123,34 @@ def quantize(self): qlinear_matmul_inputs.append(output_scale_name) qlinear_matmul_inputs.append(output_zp_name) - qlinear_matmul_node = onnx.helper.make_node("QLinearMatMul", qlinear_matmul_inputs, [qlinear_matmul_output], - qlinear_matmul_name) + qlinear_matmul_node = onnx.helper.make_node( + "QLinearMatMul", + qlinear_matmul_inputs, + [qlinear_matmul_output], + qlinear_matmul_name, + ) nodes.append(qlinear_matmul_node) # Create an entry for this quantized value - q_output = QuantizedValue(node.output[0], qlinear_matmul_output, output_scale_name, output_zp_name, - QuantizedValueType.Input) + q_output = QuantizedValue( + node.output[0], + qlinear_matmul_output, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = q_output self.quantizer.new_nodes += nodes + class QDQMatMul(QDQOperatorBase): def __init__(self, onnx_quantizer, onnx_node): super().__init__(onnx_quantizer, onnx_node) def quantize(self): node = self.node - assert (node.op_type == "MatMul") + assert node.op_type == "MatMul" if self.disable_qdq_for_node_output: nodes_to_iterate = node.input @@ -116,7 +159,7 @@ def quantize(self): for tensor_name in nodes_to_iterate: # only support per-channel quantization on weight - if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()) : + if self.quantizer.is_per_channel() and find_by_name(tensor_name, self.quantizer.model.initializer()): channel_axis = self.quantizer.qdq_op_type_per_channel_support_to_axis.get(node.op_type, 1) self.quantizer.quantize_tensor_per_channel(tensor_name, channel_axis) else: diff --git a/onnxruntime/python/tools/quantization/operators/maxpool.py b/onnxruntime/python/tools/quantization/operators/maxpool.py index 1eb2ce5565007..4fec7ee81d01f 100644 --- a/onnxruntime/python/tools/quantization/operators/maxpool.py +++ b/onnxruntime/python/tools/quantization/operators/maxpool.py @@ -7,7 +7,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "MaxPool") + assert node.op_type == "MaxPool" # if version is less than 12, go to normal quantize. if self.quantizer.opset_version < 12: @@ -24,7 +24,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "MaxPool") + assert node.op_type == "MaxPool" # if version is less than 12, just no change if self.quantizer.opset_version < 12: diff --git a/onnxruntime/python/tools/quantization/operators/pad.py b/onnxruntime/python/tools/quantization/operators/pad.py index 75a0cad10fa93..5ee93b75f1c27 100644 --- a/onnxruntime/python/tools/quantization/operators/pad.py +++ b/onnxruntime/python/tools/quantization/operators/pad.py @@ -1,7 +1,8 @@ -import onnx import numpy as np -from .base_operator import QuantOperatorBase +import onnx + from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, quantize_nparray +from .base_operator import QuantOperatorBase class QPad(QuantOperatorBase): @@ -10,7 +11,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Pad") + assert node.op_type == "Pad" # Only after version 11, it has the optional constant_value # If input[0] is not quantized, do not quanitize this node @@ -24,7 +25,7 @@ def quantize(self): kv = attribute_to_kwarg(attribute) kwargs.update(kv) - if 'mode' not in kwargs or kwargs['mode'] == b'constant': + if "mode" not in kwargs or kwargs["mode"] == b"constant": if len(node.input) > 2: # There is 3rd input 'constant_value' zp_tensor = self.quantizer.model.get_initializer(quantized_input_value.zp_name) scale_tensor = self.quantizer.model.get_initializer(quantized_input_value.scale_name) @@ -39,29 +40,43 @@ def quantize(self): scale_array = onnx.numpy_helper.to_array(scale_tensor) scale_value = scale_array.item() if scale_array.ndim == 0 else scale_array[0] padding_constant_array = onnx.numpy_helper.to_array(padding_constant_initializer) - quantized_padding_constant_array = quantize_nparray(self.quantizer.input_qType, - padding_constant_array, scale_value, zp_value) + quantized_padding_constant_array = quantize_nparray( + self.quantizer.input_qType, + padding_constant_array, + scale_value, + zp_value, + ) quantized_padding_constant_name = node.input[2] + "_quantized" quantized_padding_constant_initializer = onnx.numpy_helper.from_array( - quantized_padding_constant_array, quantized_padding_constant_name) + quantized_padding_constant_array, + quantized_padding_constant_name, + ) # Suppose this padding constant initializer only used by the node self.quantizer.model.remove_initializer(padding_constant_initializer) self.quantizer.model.add_initializer(quantized_padding_constant_initializer) node.input[2] = quantized_padding_constant_name else: # TODO: check quantize_inputs after sub graph is supported - pad_value_qnodes = self.quantizer._get_quantize_input_nodes(node, 2, self.quantizer.input_qType, - quantized_input_value.scale_name, - quantized_input_value.zp_name) + pad_value_qnodes = self.quantizer._get_quantize_input_nodes( + node, + 2, + self.quantizer.input_qType, + quantized_input_value.scale_name, + quantized_input_value.zp_name, + ) self.quantizer.new_nodes += [pad_value_qnodes] node.input[2] = pad_value_qnodes.output[0] else: node.input.extend([quantized_input_value.zp_name]) # pad zero_point for original zero # Create an entry for output quantized value - quantized_output_value = QuantizedValue(node.output[0], node.output[0] + "_quantized", - quantized_input_value.scale_name, quantized_input_value.zp_name, - QuantizedValueType.Input) + quantized_output_value = QuantizedValue( + node.output[0], + node.output[0] + "_quantized", + quantized_input_value.scale_name, + quantized_input_value.zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value node.input[0] = quantized_input_value.q_name diff --git a/onnxruntime/python/tools/quantization/operators/pooling.py b/onnxruntime/python/tools/quantization/operators/pooling.py index c1586740c4f9b..6feb5faad9a49 100644 --- a/onnxruntime/python/tools/quantization/operators/pooling.py +++ b/onnxruntime/python/tools/quantization/operators/pooling.py @@ -1,6 +1,8 @@ import onnx + +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, ms_domain from .base_operator import QuantOperatorBase -from ..quant_utils import attribute_to_kwarg, ms_domain, QuantizedValue, QuantizedValueType + class QLinearPool(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): @@ -10,11 +12,21 @@ def quantize(self): node = self.node # only try to quantize when given quantization parameters for it - data_found, output_scale_name, output_zp_name, _, _ = \ - self.quantizer._get_quantization_params(node.output[0]) + ( + data_found, + output_scale_name, + output_zp_name, + _, + _, + ) = self.quantizer._get_quantization_params(node.output[0]) # get quantized input tensor names, quantize input if needed - quantized_input_names, input_zero_point_names, input_scale_names, nodes = self.quantizer.quantize_inputs(node, [0]) + ( + quantized_input_names, + input_zero_point_names, + input_scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0]) if not data_found or quantized_input_names is None: return super().quantize() @@ -22,7 +34,12 @@ def quantize(self): # Create an entry for output quantized value. qlinear_output_name = node.output[0] + "_quantized" quantized_output_value = QuantizedValue( - node.output[0], qlinear_output_name, output_scale_name, output_zp_name, QuantizedValueType.Input) + node.output[0], + qlinear_output_name, + output_scale_name, + output_zp_name, + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[node.output[0]] = quantized_output_value # Create qlinear pool node for given type (AveragePool, etc) @@ -33,10 +50,17 @@ def quantize(self): qlinear_node_name = node.name + "_quant" if node.name != "" else "" qnode = onnx.helper.make_node( "QLinear" + node.op_type, - [quantized_input_names[0], input_scale_names[0], input_zero_point_names[0], output_scale_name, output_zp_name], + [ + quantized_input_names[0], + input_scale_names[0], + input_zero_point_names[0], + output_scale_name, + output_zp_name, + ], [qlinear_output_name], qlinear_node_name, - **kwargs) + **kwargs + ) # add all newly created nodes nodes.append(qnode) diff --git a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py index f96e57525dd32..58a653052f6df 100644 --- a/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py +++ b/onnxruntime/python/tools/quantization/operators/qdq_base_operator.py @@ -1,14 +1,16 @@ import itertools -from .base_operator import QuantOperatorBase + from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg, quantize_nparray +from .base_operator import QuantOperatorBase class QDQOperatorBase: def __init__(self, onnx_quantizer, onnx_node): self.quantizer = onnx_quantizer self.node = onnx_node - self.disable_qdq_for_node_output = True if onnx_node.op_type in onnx_quantizer.op_types_to_exclude_output_quantization \ - else False + self.disable_qdq_for_node_output = ( + True if onnx_node.op_type in onnx_quantizer.op_types_to_exclude_output_quantization else False + ) def quantize(self): node = self.node diff --git a/onnxruntime/python/tools/quantization/operators/resize.py b/onnxruntime/python/tools/quantization/operators/resize.py index c07cd99068f22..936736425c955 100644 --- a/onnxruntime/python/tools/quantization/operators/resize.py +++ b/onnxruntime/python/tools/quantization/operators/resize.py @@ -7,7 +7,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Resize") + assert node.op_type == "Resize" # if version is less than 11, go to normal quantize. if self.quantizer.opset_version < 11: @@ -24,7 +24,7 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - assert (node.op_type == "Resize") + assert node.op_type == "Resize" # if version is less than 11, just keep this node if self.quantizer.opset_version < 11: diff --git a/onnxruntime/python/tools/quantization/operators/split.py b/onnxruntime/python/tools/quantization/operators/split.py index dd72e9d848032..d5a656da567d0 100644 --- a/onnxruntime/python/tools/quantization/operators/split.py +++ b/onnxruntime/python/tools/quantization/operators/split.py @@ -1,8 +1,9 @@ import onnx -from .base_operator import QuantOperatorBase -from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg from onnx import onnx_pb as onnx_proto +from ..quant_utils import QuantizedValue, QuantizedValueType, attribute_to_kwarg +from .base_operator import QuantOperatorBase + class QSplit(QuantOperatorBase): def __init__(self, onnx_quantizer, onnx_node): @@ -10,7 +11,12 @@ def __init__(self, onnx_quantizer, onnx_node): def quantize(self): node = self.node - quantized_input_names, zero_point_names, scale_names, nodes = self.quantizer.quantize_inputs(node, [0]) + ( + quantized_input_names, + zero_point_names, + scale_names, + nodes, + ) = self.quantizer.quantize_inputs(node, [0]) if quantized_input_names is None: return super().quantize() @@ -26,14 +32,20 @@ def quantize(self): for output_name in node.output: quantized_output_name = output_name + "quantized" quantized_output_names.append(quantized_output_name) - q_output = QuantizedValue(output_name, quantized_output_name, scale_names[0], zero_point_names[0], - QuantizedValueType.Input) + q_output = QuantizedValue( + output_name, + quantized_output_name, + scale_names[0], + zero_point_names[0], + QuantizedValueType.Input, + ) self.quantizer.quantized_value_map[output_name] = q_output if len(node.input) > 1: quantized_input_names = quantized_input_names.extend(node.input[1:]) - quantized_node = onnx.helper.make_node(node.op_type, quantized_input_names, quantized_output_names, - quantized_node_name, **kwargs) + quantized_node = onnx.helper.make_node( + node.op_type, quantized_input_names, quantized_output_names, quantized_node_name, **kwargs + ) nodes.append(quantized_node) self.quantizer.new_nodes += nodes diff --git a/onnxruntime/python/tools/quantization/qdq_quantizer.py b/onnxruntime/python/tools/quantization/qdq_quantizer.py index 1539207be9781..4a8a538053571 100644 --- a/onnxruntime/python/tools/quantization/qdq_quantizer.py +++ b/onnxruntime/python/tools/quantization/qdq_quantizer.py @@ -3,33 +3,72 @@ # Licensed under the MIT License. See License.txt in the project root for # license information. # -------------------------------------------------------------------------- +import logging import os import struct from pathlib import Path -import numpy as np -import logging +import numpy as np import onnx import onnx.numpy_helper -from onnx import onnx_pb as onnx_proto from onnx import TensorProto -from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel - -from .quant_utils import QuantizationMode, QuantizedValueType, QuantizedInitializer, QuantizedValue -from .quant_utils import find_by_name, get_elem_index, get_mul_node, generate_identified_filename, attribute_to_kwarg, type_to_name, quantize_nparray -from .quant_utils import QuantType, onnx_domain, __producer__, __version__ +from onnx import onnx_pb as onnx_proto -from .registry import CreateQDQQuantizer +from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions from .onnx_model import ONNXModel from .onnx_quantizer import ONNXQuantizer +from .quant_utils import ( + QuantizationMode, + QuantizedInitializer, + QuantizedValue, + QuantizedValueType, + QuantType, + __producer__, + __version__, + attribute_to_kwarg, + find_by_name, + generate_identified_filename, + get_elem_index, + get_mul_node, + onnx_domain, + quantize_nparray, + type_to_name, +) +from .registry import CreateQDQQuantizer class QDQQuantizer(ONNXQuantizer): - def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, input_qType, tensors_range, - nodes_to_quantize, nodes_to_exclude, op_types_to_quantize, extra_options={}): - ONNXQuantizer.__init__(self, model, per_channel, reduce_range, mode, static, weight_qType, input_qType, - tensors_range, nodes_to_quantize, nodes_to_exclude, op_types_to_quantize, extra_options) + def __init__( + self, + model, + per_channel, + reduce_range, + mode, + static, + weight_qType, + input_qType, + tensors_range, + nodes_to_quantize, + nodes_to_exclude, + op_types_to_quantize, + extra_options={}, + ): + ONNXQuantizer.__init__( + self, + model, + per_channel, + reduce_range, + mode, + static, + weight_qType, + input_qType, + tensors_range, + nodes_to_quantize, + nodes_to_exclude, + op_types_to_quantize, + extra_options, + ) self.tensors_to_quantize = [] self.tensors_to_quantize_per_channel = [] self.bias_to_quantize = [] @@ -40,23 +79,33 @@ def __init__(self, model, per_channel, reduce_range, mode, static, weight_qType, # because those ops may be followed by nodes that require high resolution inputs. # Adding QDQ for those ops' output may end up with worse accuracy. # So, we don't recommend to add QDQ to node's output under such condition. - self.op_types_to_exclude_output_quantization = [] if 'OpTypesToExcludeOutputQuantizatioin' not in extra_options \ - else extra_options['OpTypesToExcludeOutputQuantizatioin'] + self.op_types_to_exclude_output_quantization = ( + [] + if "OpTypesToExcludeOutputQuantizatioin" not in extra_options + else extra_options["OpTypesToExcludeOutputQuantizatioin"] + ) # We do quantization on Dequantizelinear's input to remove Quantizelinear for weight as an optimization. # In some cases, for example QDQ BERT model for TensorRT, QDQ should always appear as a pair. # Therefore, we need to disable this optimization and add qdq pair to weight. - self.add_qdq_pair_to_weight = False if 'AddQDQPairToWeight' not in extra_options \ - else extra_options['AddQDQPairToWeight'] - - # The default behavior is that multiple nodes can share a QDQ pair as their inputs. - # In TRT, QDQ pair can’t be shared between nodes, so it will create dedicated QDQ pairs for each node. - self.dedicated_qdq_pair = False if 'DedicatedQDQPair' not in extra_options else extra_options['DedicatedQDQPair'] + self.add_qdq_pair_to_weight = ( + False if "AddQDQPairToWeight" not in extra_options else extra_options["AddQDQPairToWeight"] + ) + + # The default behavior is that multiple nodes can share a QDQ pair as their inputs. + # In TRT, QDQ pair can’t be shared between nodes, so it will create dedicated QDQ pairs for each node. + self.dedicated_qdq_pair = ( + False if "DedicatedQDQPair" not in extra_options else extra_options["DedicatedQDQPair"] + ) if self.dedicated_qdq_pair: self.tensor_to_its_receiving_nodes = {} # Let user set channel axis for specific op type and it's effective only when per channel quantization is supported and per_channel is True. - self.qdq_op_type_per_channel_support_to_axis = {} if 'QDQOpTypePerChannelSupportToAxis' not in extra_options else extra_options['QDQOpTypePerChannelSupportToAxis'] + self.qdq_op_type_per_channel_support_to_axis = ( + {} + if "QDQOpTypePerChannelSupportToAxis" not in extra_options + else extra_options["QDQOpTypePerChannelSupportToAxis"] + ) def quantize_tensor(self, tensor_name): weight = find_by_name(tensor_name, self.model.initializer()) @@ -65,12 +114,14 @@ def quantize_tensor(self, tensor_name): self.tensors_to_quantize.append(tensor_name) elif tensor_name in self.value_infos.keys(): vi = self.value_infos[tensor_name] - if vi.type.HasField('tensor_type') and vi.type.tensor_type.elem_type == TensorProto.FLOAT: + if vi.type.HasField("tensor_type") and vi.type.tensor_type.elem_type == TensorProto.FLOAT: self.tensors_to_quantize.append(tensor_name) else: logging.warning( "failed to infer the type of tensor: {}. Skip to quantize it. Please check if it is expected.".format( - tensor_name)) + tensor_name + ) + ) def quantize_tensor_per_channel(self, tensor_name, axis): weight = find_by_name(tensor_name, self.model.initializer()) @@ -80,10 +131,12 @@ def quantize_tensor_per_channel(self, tensor_name, axis): else: logging.warning( "only support per-channel quantization on weight. Quantize tensor: {} with per-tensor instead.".format( - tensor_name)) + tensor_name + ) + ) self.quantize_tensor(tensor_name) - def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta = 1.0): + def quantize_bias_tensor(self, bias_name, input_name, weight_name, beta=1.0): weight = find_by_name(bias_name, self.model.initializer()) if weight is not None: if weight.data_type == onnx_proto.TensorProto.FLOAT: @@ -124,9 +177,11 @@ def quantize_model(self): return self.model.model def try_replacing_upstream_output(self, upstream_output_name, output_name): - if output_name in self.quantization_params.keys() and \ - len(self.model.input_name_to_nodes()[upstream_output_name]) == 1 and \ - not self.model.is_graph_output(upstream_output_name): + if ( + output_name in self.quantization_params.keys() + and len(self.model.input_name_to_nodes()[upstream_output_name]) == 1 + and not self.model.is_graph_output(upstream_output_name) + ): self.model.replace_output_of_all_nodes(upstream_output_name, output_name) self.tensors_to_quantize.remove(upstream_output_name) return True @@ -141,25 +196,34 @@ def quantize_tensors(self): if initializer is not None: if self.add_qdq_pair_to_weight: - q_weight_name, zp_name, scale_name = self.quantize_weight(initializer, - self.weight_qType, - keep_float_weight=True) - qlinear_node = onnx.helper.make_node("QuantizeLinear", [tensor_name, scale_name, zp_name], - [tensor_name + "_QuantizeLinear"], - tensor_name + "_QuantizeLinear") - dequant_node = onnx.helper.make_node("DequantizeLinear", - [tensor_name + "_QuantizeLinear", scale_name, zp_name], - [tensor_name + "_DequantizeLinear"], - tensor_name + "_DequantizeLinear") + q_weight_name, zp_name, scale_name = self.quantize_weight( + initializer, self.weight_qType, keep_float_weight=True + ) + qlinear_node = onnx.helper.make_node( + "QuantizeLinear", + [tensor_name, scale_name, zp_name], + [tensor_name + "_QuantizeLinear"], + tensor_name + "_QuantizeLinear", + ) + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + [tensor_name + "_QuantizeLinear", scale_name, zp_name], + [tensor_name + "_DequantizeLinear"], + tensor_name + "_DequantizeLinear", + ) self.model.replace_input_of_all_nodes(tensor_name, tensor_name + "_DequantizeLinear") self.model.add_nodes([qlinear_node, dequant_node]) else: q_weight_name, zp_name, scale_name = self.quantize_weight(initializer, self.weight_qType) inputs = [q_weight_name, scale_name, zp_name] - output_name = tensor_name + '_DequantizeLinear' - node = onnx.helper.make_node("DequantizeLinear", inputs, [output_name], - tensor_name + '_DequantizeLinear') + output_name = tensor_name + "_DequantizeLinear" + node = onnx.helper.make_node( + "DequantizeLinear", + inputs, + [output_name], + tensor_name + "_DequantizeLinear", + ) self.model.add_node(node) self.model.replace_input_of_all_nodes(tensor_name, tensor_name + "_DequantizeLinear") else: @@ -168,32 +232,49 @@ def quantize_tensors(self): if data_found == False: raise ValueError( "Quantization parameters are not specified for param {}." - "In static mode quantization params for inputs and outputs of nodes to be quantized are required." - .format(tensor_name)) - - if self.dedicated_qdq_pair and tensor_name in self.tensor_to_its_receiving_nodes and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1: + "In static mode quantization params for inputs and outputs of nodes to be quantized are required.".format( + tensor_name + ) + ) + + if ( + self.dedicated_qdq_pair + and tensor_name in self.tensor_to_its_receiving_nodes + and len(self.tensor_to_its_receiving_nodes[tensor_name]) > 1 + ): num_dedicated_qdq_pair = len(self.tensor_to_its_receiving_nodes[tensor_name]) for i in range(num_dedicated_qdq_pair): - postfix = str(i+1) + postfix = str(i + 1) q_input = tensor_name - q_output = tensor_name + "_QuantizeLinear_" + postfix + q_output = tensor_name + "_QuantizeLinear_" + postfix dq_input = q_output dq_output = tensor_name + "_DequantizeLinear_" + postfix quant_node_name = tensor_name + "_QuantizeLinear_" + postfix dequant_node_name = tensor_name + "_DequantizeLinear_" + postfix - qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name], - [q_output], quant_node_name) - dequant_node = onnx.helper.make_node("DequantizeLinear", - [dq_input, scale_name, zp_name], - [dq_output], - dequant_node_name) + qlinear_node = onnx.helper.make_node( + "QuantizeLinear", + [q_input, scale_name, zp_name], + [q_output], + quant_node_name, + ) + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + [dq_input, scale_name, zp_name], + [dq_output], + dequant_node_name, + ) self.model.add_nodes([qlinear_node, dequant_node]) node = self.tensor_to_its_receiving_nodes[tensor_name][i] self.model.replace_node_input(node, tensor_name, dq_output) - quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name, - QuantizedValueType.Input) + quantized_value = QuantizedValue( + tensor_name, + dq_output, + scale_name, + zp_name, + QuantizedValueType.Input, + ) self.quantized_value_map[tensor_name] = quantized_value else: q_input = tensor_name @@ -209,16 +290,27 @@ def quantize_tensors(self): quant_node_name = tensor_name + "_QuantizeLinear" dequant_node_name = tensor_name + "_DequantizeLinear" - qlinear_node = onnx.helper.make_node("QuantizeLinear", [q_input, scale_name, zp_name], - [q_output], quant_node_name) - dequant_node = onnx.helper.make_node("DequantizeLinear", - [dq_input, scale_name, zp_name], - [dq_output], - dequant_node_name) + qlinear_node = onnx.helper.make_node( + "QuantizeLinear", + [q_input, scale_name, zp_name], + [q_output], + quant_node_name, + ) + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + [dq_input, scale_name, zp_name], + [dq_output], + dequant_node_name, + ) self.model.add_nodes([qlinear_node, dequant_node]) - quantized_value = QuantizedValue(tensor_name, dq_output, scale_name, zp_name, - QuantizedValueType.Input) + quantized_value = QuantizedValue( + tensor_name, + dq_output, + scale_name, + zp_name, + QuantizedValueType.Input, + ) self.quantized_value_map[tensor_name] = quantized_value def quantize_bias_tensors(self): @@ -231,13 +323,20 @@ def quantize_bias_tensors(self): quant_value = self.quantized_value_map[bias_name] inputs = [quant_value.q_name, quant_value.scale_name, quant_value.zp_name] if quant_value.axis is not None: - dequant_node = onnx.helper.make_node("DequantizeLinear", - inputs, [bias_name], - bias_name + '_DequantizeLinear', - axis=quant_value.axis) + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + inputs, + [bias_name], + bias_name + "_DequantizeLinear", + axis=quant_value.axis, + ) else: - dequant_node = onnx.helper.make_node("DequantizeLinear", inputs, [bias_name], - bias_name + '_DequantizeLinear') + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + inputs, + [bias_name], + bias_name + "_DequantizeLinear", + ) self.model.add_node(dequant_node) def quantize_weights_per_channel(self): @@ -245,31 +344,44 @@ def quantize_weights_per_channel(self): raise ValueError("Per-Channel support with QDQ format requires onnx opset version 13 or above.") for weight_name, axis in self.tensors_to_quantize_per_channel: if self.add_qdq_pair_to_weight: - q_name, zp_name, scale_name = self.quantize_weight_per_channel(weight_name, onnx_proto.TensorProto.INT8, - axis, keep_float_weight=True) - qlinear_node = onnx.helper.make_node("QuantizeLinear", [weight_name, scale_name, zp_name], - [weight_name + "_QuantizeLinear"], - weight_name + "_QuantizeLinear", - axis=axis) - dequant_node = onnx.helper.make_node("DequantizeLinear", - [weight_name + "_QuantizeLinear", scale_name, zp_name], - [weight_name + "_DequantizeLinear"], - weight_name + "_DequantizeLinear", - axis=axis) + q_name, zp_name, scale_name = self.quantize_weight_per_channel( + weight_name, + onnx_proto.TensorProto.INT8, + axis, + keep_float_weight=True, + ) + qlinear_node = onnx.helper.make_node( + "QuantizeLinear", + [weight_name, scale_name, zp_name], + [weight_name + "_QuantizeLinear"], + weight_name + "_QuantizeLinear", + axis=axis, + ) + dequant_node = onnx.helper.make_node( + "DequantizeLinear", + [weight_name + "_QuantizeLinear", scale_name, zp_name], + [weight_name + "_DequantizeLinear"], + weight_name + "_DequantizeLinear", + axis=axis, + ) self.model.replace_input_of_all_nodes(weight_name, weight_name + "_DequantizeLinear") self.model.add_nodes([qlinear_node, dequant_node]) else: - #q_name, zp_name, scale_name = self.quantize_weight_per_channel(weight_name, self.weight_qType, axis) - q_name, zp_name, scale_name = self.quantize_weight_per_channel(weight_name, onnx_proto.TensorProto.INT8, - axis) + # q_name, zp_name, scale_name = self.quantize_weight_per_channel(weight_name, self.weight_qType, axis) + q_name, zp_name, scale_name = self.quantize_weight_per_channel( + weight_name, onnx_proto.TensorProto.INT8, axis + ) inputs = [q_name, scale_name, zp_name] output_name = weight_name + "_DequantizeLinear" - node = onnx.helper.make_node("DequantizeLinear", - inputs, [output_name], - weight_name + '_DequantizeLinear', - axis=axis) + node = onnx.helper.make_node( + "DequantizeLinear", + inputs, + [output_name], + weight_name + "_DequantizeLinear", + axis=axis, + ) self.model.add_node(node) # Replace weight_name with output of DequantizeLinear diff --git a/onnxruntime/python/tools/quantization/quant_utils.py b/onnxruntime/python/tools/quantization/quant_utils.py index 01a3b7197bd25..3b0aed2734b32 100644 --- a/onnxruntime/python/tools/quantization/quant_utils.py +++ b/onnxruntime/python/tools/quantization/quant_utils.py @@ -1,14 +1,14 @@ import logging -import numpy -import onnx import tempfile - from enum import Enum -from onnx import onnx_pb as onnx_proto -from onnx import external_data_helper from pathlib import Path -from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel +import numpy +import onnx +from onnx import external_data_helper +from onnx import onnx_pb as onnx_proto + +from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions __producer__ = "onnx.quantize" __version__ = "0.1.0" @@ -97,14 +97,17 @@ def from_string(format): except KeyError: raise ValueError() + ONNX_TYPE_TO_NP_TYPE = { - onnx_proto.TensorProto.INT8: numpy.dtype('int8'), - onnx_proto.TensorProto.UINT8: numpy.dtype('uint8') + onnx_proto.TensorProto.INT8: numpy.dtype("int8"), + onnx_proto.TensorProto.UINT8: numpy.dtype("uint8"), } + def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): - assert qType in ONNX_TYPE_TO_NP_TYPE, \ - "Unexpected data type {} requested. Only INT8 and UINT8 are supported.".format(qType) + assert ( + qType in ONNX_TYPE_TO_NP_TYPE + ), "Unexpected data type {} requested. Only INT8 and UINT8 are supported.".format(qType) dtype = ONNX_TYPE_TO_NP_TYPE[qType] cliplow = max(0 if dtype == numpy.uint8 else -127, -127 if low is None else low) cliphigh = min(255 if dtype == numpy.uint8 else 127, 255 if high is None else high) @@ -114,10 +117,10 @@ def quantize_nparray(qType, arr, scale, zero_point, low=None, high=None): def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False): - ''' - Calculate the scale s and zero point z for the quantization relation + """ + Calculate the scale s and zero point z for the quantization relation r = s(q-z), where r are the original values and q are the corresponding - quantized values. + quantized values. r and z are calculated such that every value within [rmin,rmax] has an approximate representation within [qmin,qmax]. In addition, qmin <= z <= @@ -131,8 +134,8 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False): :parameter qmax: maximum value representable by the target quantization data type :return: zero and scale [z, s] - ''' - + """ + # Adjust rmin and rmax such that 0 is included in the range. This is # required to make sure zero can be represented by the quantization data # type (i.e. to make sure qmin <= zero_point <= qmax) @@ -144,21 +147,21 @@ def compute_scale_zp(rmin, rmax, qmin, qmax, symmetric=False): rmin = -absmax rmax = +absmax - scale = (rmax - rmin) / float(qmax-qmin) if rmax!=rmin else 1.0 - zero_point = round(qmin - rmin/scale) + scale = (rmax - rmin) / float(qmax - qmin) if rmax != rmin else 1.0 + zero_point = round(qmin - rmin / scale) return [zero_point, scale] def quantize_data(data, qType, symmetric, reduce_range=False): - ''' + """ :param data: data to quantize :param qType: data type to quantize to. Supported types UINT8 and INT8 :param symmetric: whether symmetric quantization is used or not. This is applied to INT8. :return: minimum, maximum, zero point, scale, and quantized weights To pack weights, we compute a linear transformation - + - when data `type == uint8` mode, from `[rmin, rmax]` -> :math:`[0, 2^{b-1}]` and - when data `type == int8`, from `[-m , m]` -> :math:`[-(2^{b-1}-1), 2^{b-1}-1]` where `m = max(abs(rmin), abs(rmax))` @@ -166,12 +169,12 @@ def quantize_data(data, qType, symmetric, reduce_range=False): and add necessary intermediate nodes to trasnform quantized weight to full weight using the equation :math:`r = S(q-z)`, where - + - *r*: real original value - *q*: quantized value - *S*: scale - *z*: zero point - ''' + """ rmin = 0 rmax = 0 @@ -188,46 +191,52 @@ def quantize_data(data, qType, symmetric, reduce_range=False): return rmin, rmax, zero_point, scale, quantized_data + def get_qmin_qmax_for_qType(qType, reduce_range=False, symmetric=False): - ''' + """ Return qmin and qmax, the minimum and maximum value representable by the given qType :parameter qType: onnx.onnx_pb.TensorProto.UINT8 or onnx.onnx_pb.TensorProto.UINT8 :return: qmin, qmax - ''' + """ if qType == onnx_proto.TensorProto.UINT8: - (qmin, qmax) = (0,127) if reduce_range else (0,255) + (qmin, qmax) = (0, 127) if reduce_range else (0, 255) elif qType == onnx_proto.TensorProto.INT8: if symmetric: - (qmin, qmax) = (-64,64) if reduce_range else (-127,127) + (qmin, qmax) = (-64, 64) if reduce_range else (-127, 127) else: - (qmin, qmax) = (-64,64) if reduce_range else (-128,127) + (qmin, qmax) = (-64, 64) if reduce_range else (-128, 127) else: raise ValueError("Unexpected data type {} requested. Only INT8 and UINT8 are supported.".format(qType)) return qmin, qmax + def get_qrange_for_qType(qType, reduce_range=False, symmetric=False): - ''' + """ Helper function to get the quantization range for a type. parameter qType: quantization type. return: quantization range. - ''' + """ qmin, qmax = get_qmin_qmax_for_qType(qType, reduce_range, symmetric=symmetric) - return qmax - qmin + return qmax - qmin + class QuantizedInitializer: - ''' - Represents a linearly quantized weight input from ONNX operators - ''' - def __init__(self, - name, - initializer, - rmins, - rmaxs, - zero_points, - scales, - data=[], - quantized_data=[], - axis=None): + """ + Represents a linearly quantized weight input from ONNX operators + """ + + def __init__( + self, + name, + initializer, + rmins, + rmaxs, + zero_points, + scales, + data=[], + quantized_data=[], + axis=None, + ): self.name = name self.initializer = initializer # TensorProto initializer in ONNX graph self.rmins = rmins # List of minimum range for each axis @@ -243,16 +252,19 @@ def __init__(self, class QuantizedValue: - ''' + """ Represents a linearly quantized value (input\output\intializer) - ''' - def __init__(self, - name, - new_quantized_name, - scale_name, - zero_point_name, - quantized_value_type, - axis=None): + """ + + def __init__( + self, + name, + new_quantized_name, + scale_name, + zero_point_name, + quantized_value_type, + axis=None, + ): self.original_name = name self.q_name = new_quantized_name self.scale_name = scale_name @@ -262,9 +274,10 @@ def __init__(self, class BiasToQuantize: - ''' + """ Represents a bias to be quantized - ''' + """ + def __init__(self, bias_name, input_name, weight_name): self.bias_name = bias_name self.input_name = input_name @@ -272,57 +285,57 @@ def __init__(self, bias_name, input_name, weight_name): def attribute_to_kwarg(attribute): - ''' + """ Convert attribute to kwarg format for use with onnx.helper.make_node. :parameter attribute: attribute in AttributeProto format. :return: attribute in {key: value} format. - ''' - if (attribute.type == 0): - raise ValueError('attribute {} does not have type specified.'.format(attribute.name)) + """ + if attribute.type == 0: + raise ValueError("attribute {} does not have type specified.".format(attribute.name)) # Based on attribute type definitions from AttributeProto # definition in https://github.com/onnx/onnx/blob/master/onnx/onnx.proto - if (attribute.type == 1): + if attribute.type == 1: value = attribute.f - elif (attribute.type == 2): + elif attribute.type == 2: value = attribute.i - elif (attribute.type == 3): + elif attribute.type == 3: value = attribute.s - elif (attribute.type == 4): + elif attribute.type == 4: value = attribute.t - elif (attribute.type == 5): + elif attribute.type == 5: value = attribute.g - elif (attribute.type == 6): + elif attribute.type == 6: value = attribute.floats - elif (attribute.type == 7): + elif attribute.type == 7: value = attribute.ints - elif (attribute.type == 8): + elif attribute.type == 8: value = attribute.strings - elif (attribute.type == 9): + elif attribute.type == 9: value = attribute.tensors - elif (attribute.type == 10): + elif attribute.type == 10: value = attribute.graphs else: - raise ValueError('attribute {} has unsupported type {}.'.format(attribute.name, attribute.type)) + raise ValueError("attribute {} has unsupported type {}.".format(attribute.name, attribute.type)) return {attribute.name: value} def find_by_name(item_name, item_list): - ''' + """ Helper function to find item by name in a list. parameter item_name: name of the item. parameter item_list: list of items. return: item if found. None otherwise. - ''' + """ items = [item for item in item_list if item.name == item_name] return items[0] if len(items) > 0 else None def get_elem_index(elem_name, elem_list): - ''' + """ Helper function to return index of an item in a node list - ''' + """ elem_idx = -1 for i in range(0, len(elem_list)): if elem_list[i] == elem_name: @@ -331,50 +344,56 @@ def get_elem_index(elem_name, elem_list): def get_mul_node(inputs, output, name): - ''' + """ Helper function to create a Mul node. parameter inputs: list of input names. parameter output: output name. parameter name: name of the node. return: Mul node in NodeProto format. - ''' + """ return onnx.helper.make_node("Mul", inputs, [output], name) def generate_identified_filename(filename: Path, identifier: str) -> Path: - ''' - Helper function to generate a identifiable filepath by concatenating the given identifier as a suffix. - ''' + """ + Helper function to generate a identifiable filepath by concatenating the given identifier as a suffix. + """ return filename.parent.joinpath(filename.stem + identifier).with_suffix(filename.suffix) + def apply_plot(hist, hist_edges): import sys - import numpy + import matplotlib.pyplot as plt + import numpy + numpy.set_printoptions(threshold=sys.maxsize) print("Histogram:") print(hist) print("Histogram Edges:") print(hist_edges) plt.stairs(hist, hist_edges, fill=True) - plt.xlabel('Tensor value') - plt.ylabel('Counts') - plt.title('Tensor value V.S. Counts') + plt.xlabel("Tensor value") + plt.ylabel("Counts") + plt.title("Tensor value V.S. Counts") plt.show() + def write_calibration_table(calibration_cache): - ''' - Helper function to write calibration table to files. - ''' + """ + Helper function to write calibration table to files. + """ import json + import flatbuffers - import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable + import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue + import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable logging.info("calibration cache: {}".format(calibration_cache)) - with open("calibration.json", 'w') as file: + with open("calibration.json", "w") as file: file.write(json.dumps(calibration_cache)) # use `json.loads` to do the reverse # Serialize data using FlatBuffers @@ -406,7 +425,7 @@ def write_calibration_table(calibration_cache): builder.Finish(cal_table) buf = builder.Output() - with open("calibration.flatbuffers", 'wb') as file: + with open("calibration.flatbuffers", "wb") as file: file.write(buf) # Deserialize data (for validation) @@ -419,12 +438,13 @@ def write_calibration_table(calibration_cache): logging.info(key_value.Value()) # write plain text - with open("calibration.cache", 'w') as file: + with open("calibration.cache", "w") as file: for key in sorted(calibration_cache.keys()): value = calibration_cache[key] - s = key + ' ' + str(max(abs(value[0]), abs(value[1]))) + s = key + " " + str(max(abs(value[0]), abs(value[1]))) file.write(s) - file.write('\n') + file.write("\n") + def smooth_distribution(p, eps=0.0001): """Given a discrete distribution (may have not been normalized to 1), @@ -444,7 +464,11 @@ def smooth_distribution(p, eps=0.0001): # raise ValueError('The discrete probability distribution is malformed. All entries are 0.') return -1 eps1 = eps * float(n_zeros) / float(n_nonzeros) - assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1) + assert eps1 < 1.0, "n_zeros=%d, n_nonzeros=%d, eps1=%f" % ( + n_zeros, + n_nonzeros, + eps1, + ) hist = p.astype(np.float32) hist += eps * is_zeros + (-eps1) * is_nonzeros @@ -452,32 +476,36 @@ def smooth_distribution(p, eps=0.0001): return hist -def model_has_external_data(model_path : Path): + +def model_has_external_data(model_path: Path): model = onnx.load(model_path.as_posix(), load_external_data=False) for intializer in model.graph.initializer: if external_data_helper.uses_external_data(intializer): return True return False -def optimize_model(model_path : Path, opt_model_path : Path): - ''' + +def optimize_model(model_path: Path, opt_model_path: Path): + """ Generate model that applies graph optimization (constant folding, etc.) parameter model_path: path to the original onnx model parameter opt_model_path: path to the optimized onnx model :return: optimized onnx model - ''' + """ sess_option = SessionOptions() sess_option.optimized_model_filepath = opt_model_path.as_posix() sess_option.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_BASIC - _ = InferenceSession(model_path.as_posix(), sess_option, providers=['CPUExecutionProvider']) + _ = InferenceSession(model_path.as_posix(), sess_option, providers=["CPUExecutionProvider"]) + def add_infer_metadata(model): metadata_props = {"onnx.infer": "onnxruntime.quant"} if model.metadata_props: for p in model.metadata_props: - metadata_props.update({p.key : p.value}) + metadata_props.update({p.key: p.value}) onnx.helper.set_model_props(model, metadata_props) + def model_has_infer_metadata(model): if model.metadata_props: for p in model.metadata_props: @@ -485,7 +513,8 @@ def model_has_infer_metadata(model): return True return False -def load_model_with_shape_infer(model_path : Path): + +def load_model_with_shape_infer(model_path: Path): inferred_model_path = generate_identified_filename(model_path, "-inferred") onnx.shape_inference.infer_shapes_path(str(model_path), str(inferred_model_path)) model = onnx.load(inferred_model_path.as_posix()) @@ -493,8 +522,8 @@ def load_model_with_shape_infer(model_path : Path): return model -def load_model(model_path : Path, need_optimize : bool): - with tempfile.TemporaryDirectory(prefix='ort.quant.') as quant_tmp_dir: +def load_model(model_path: Path, need_optimize: bool): + with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: if need_optimize and not model_has_external_data(model_path): opt_model_path = Path(quant_tmp_dir).joinpath("model.onnx") optimize_model(model_path, opt_model_path) @@ -504,18 +533,19 @@ def load_model(model_path : Path, need_optimize : bool): add_infer_metadata(model) return model + def save_and_reload_model(model): - with tempfile.TemporaryDirectory(prefix='ort.quant.') as quant_tmp_dir: + with tempfile.TemporaryDirectory(prefix="ort.quant.") as quant_tmp_dir: model_path = Path(quant_tmp_dir).joinpath("model.onnx") - onnx.external_data_helper.convert_model_to_external_data(model, - all_tensors_to_one_file=True) + onnx.external_data_helper.convert_model_to_external_data(model, all_tensors_to_one_file=True) onnx.save_model(model, model_path.as_posix()) return load_model(model_path, False) + def clone_model_with_shape_infer(model): if model_has_infer_metadata(model): cloned_model = onnx_proto.ModelProto() cloned_model.CopyFrom(model) else: cloned_model = save_and_reload_model(model) - return cloned_model \ No newline at end of file + return cloned_model diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py index 82be85f8d2662..55f1003724568 100644 --- a/onnxruntime/python/tools/quantization/quantize.py +++ b/onnxruntime/python/tools/quantization/quantize.py @@ -5,52 +5,63 @@ # -------------------------------------------------------------------------- import logging from pathlib import Path -from onnx import onnx_pb as onnx_proto - -from .quant_utils import QuantizationMode, QuantizedValueType, QuantizedInitializer, QuantizedValue -from .quant_utils import find_by_name, get_elem_index, get_mul_node, generate_identified_filename, attribute_to_kwarg -from .quant_utils import QuantType, QuantFormat -from .quant_utils import load_model -from .registry import QLinearOpsRegistry, IntegerOpsRegistry +from onnx import onnx_pb as onnx_proto +from .calibrate import CalibrationDataReader, CalibrationMethod, create_calibrator from .onnx_model import ONNXModel from .onnx_quantizer import ONNXQuantizer from .qdq_quantizer import QDQQuantizer -from .calibrate import CalibrationDataReader, create_calibrator, CalibrationMethod +from .quant_utils import ( + QuantFormat, + QuantizationMode, + QuantizedInitializer, + QuantizedValue, + QuantizedValueType, + QuantType, + attribute_to_kwarg, + find_by_name, + generate_identified_filename, + get_elem_index, + get_mul_node, + load_model, +) +from .registry import IntegerOpsRegistry, QLinearOpsRegistry -def check_static_quant_arguments(quant_format : QuantFormat, - activation_type : QuantType, - weight_type : QuantType): +def check_static_quant_arguments(quant_format: QuantFormat, activation_type: QuantType, weight_type: QuantType): if activation_type == QuantType.QInt8 and weight_type == QuantType.QUInt8: - raise ValueError("ONNXRuntime quantization doesn't support data format:" - "activation_type=QuantType.QInt8, weight_type = QuantType.QUInt8") - - if activation_type == QuantType.QInt8 and \ - weight_type == QuantType.QInt8 and \ - quant_format != QuantFormat.QDQ: \ - logging.warning("Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. " - "Or it will lead to bad performance on x64.") - - -def quantize_static(model_input, - model_output, - calibration_data_reader: CalibrationDataReader, - quant_format=QuantFormat.QDQ, - op_types_to_quantize=[], - per_channel=False, - reduce_range=False, - activation_type=QuantType.QInt8, - weight_type=QuantType.QInt8, - nodes_to_quantize=[], - nodes_to_exclude=[], - optimize_model=True, - use_external_data_format=False, - calibrate_method=CalibrationMethod.MinMax, - extra_options = {}): - - ''' + raise ValueError( + "ONNXRuntime quantization doesn't support data format:" + "activation_type=QuantType.QInt8, weight_type = QuantType.QUInt8" + ) + + if activation_type == QuantType.QInt8 and weight_type == QuantType.QInt8 and quant_format != QuantFormat.QDQ: + logging.warning( + "Please use QuantFormat.QDQ for activation type QInt8 and weight type QInt8. " + "Or it will lead to bad performance on x64." + ) + + +def quantize_static( + model_input, + model_output, + calibration_data_reader: CalibrationDataReader, + quant_format=QuantFormat.QDQ, + op_types_to_quantize=[], + per_channel=False, + reduce_range=False, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + nodes_to_quantize=[], + nodes_to_exclude=[], + optimize_model=True, + use_external_data_format=False, + calibrate_method=CalibrationMethod.MinMax, + extra_options={}, +): + + """ Given an onnx model and calibration data reader, create a quantized onnx model and save it into a file It is recommended to use QuantFormat.QDQ format from 1.11 with activation_type = QuantType.QInt8 and @@ -81,9 +92,9 @@ def quantize_static(model_input, List of nodes names to exclude. The nodes in this list will be excluded from quantization when it is not None. :param optimize_model: optimize model before quantization. - :param use_external_data_format: option used for large size (>2GB) model. Set to False by default. - :param calibrate_method: - Current calibration methods supported are MinMax and Entropy. + :param use_external_data_format: option used for large size (>2GB) model. Set to False by default. + :param calibrate_method: + Current calibration methods supported are MinMax and Entropy. Please use CalibrationMethod.MinMax or CalibrationMethod.Entropy as options. :param extra_options: key value pair dictionary for various options in different case. Current used: @@ -97,13 +108,13 @@ def quantize_static(model_input, always quantize input and so generate quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude. MatMulConstBOnly = True/False: Default is False for static mode. If enabled, only MatMul with const B will be quantized. - AddQDQPairToWeight = True/False : Default is False which quantizes floating-point weight and feeds it to - soley inserted DeQuantizeLinear node. If True, it remains floating-point weight and + AddQDQPairToWeight = True/False : Default is False which quantizes floating-point weight and feeds it to + soley inserted DeQuantizeLinear node. If True, it remains floating-point weight and inserts both QuantizeLinear/DeQuantizeLinear nodes to weight. - OpTypesToExcludeOutputQuantizatioin = list of op type : Default is []. If any op type is specified, it won't quantize + OpTypesToExcludeOutputQuantizatioin = list of op type : Default is []. If any op type is specified, it won't quantize the output of ops with this specific op types. DedicatedQDQPair = True/False : Default is False. When inserting QDQ pair, multiple nodes can share a single QDQ pair as their inputs. - If True, it will create identical and dedicated QDQ pair for each node. + If True, it will create identical and dedicated QDQ pair for each node. QDQOpTypePerChannelSupportToAxis = dictionary : Default is {}. Set channel axis for specific op type, for example: {'MatMul': 1}, and it's effective only when per channel quantization is supported and per_channel is True. If specific op type supports per channel quantization but not explicitly specified with channel axis, @@ -114,7 +125,7 @@ def quantize_static(model_input, CalibMovingAverageConstant = float : Default is 0.01. Constant smoothing factor to use when computing the moving average of the minimum and maximum values. Effective only when the calibration method selected is MinMax and when CalibMovingAverage is set to True. - ''' + """ mode = QuantizationMode.QLinearOps @@ -124,17 +135,19 @@ def quantize_static(model_input, model = load_model(Path(model_input), optimize_model) calib_extra_options_keys = [ - ('CalibTensorRangeSymmetric', 'symmetric'), - ('CalibMovingAverage', 'moving_average'), - ('CalibMovingAverageConstant', 'averaging_constant') + ("CalibTensorRangeSymmetric", "symmetric"), + ("CalibMovingAverage", "moving_average"), + ("CalibMovingAverageConstant", "averaging_constant"), ] - calib_extra_options = {key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options} + calib_extra_options = { + key: extra_options.get(name) for (name, key) in calib_extra_options_keys if name in extra_options + } calibrator = create_calibrator( model, op_types_to_quantize, calibrate_method=calibrate_method, use_external_data_format=use_external_data_format, - extra_options=calib_extra_options + extra_options=calib_extra_options, ) calibrator.collect_data(calibration_data_reader) tensors_range = calibrator.compute_range() @@ -154,7 +167,8 @@ def quantize_static(model_input, nodes_to_quantize, nodes_to_exclude, op_types_to_quantize, - extra_options) + extra_options, + ) else: quantizer = QDQQuantizer( model, @@ -168,24 +182,27 @@ def quantize_static(model_input, nodes_to_quantize, nodes_to_exclude, op_types_to_quantize, - extra_options) + extra_options, + ) quantizer.quantize_model() quantizer.model.save_model_to_file(model_output, use_external_data_format) -def quantize_dynamic(model_input: Path, - model_output: Path, - op_types_to_quantize=[], - per_channel=False, - reduce_range=False, - weight_type=QuantType.QInt8, - nodes_to_quantize=[], - nodes_to_exclude=[], - optimize_model=True, - use_external_data_format=False, - extra_options = { }): - ''' +def quantize_dynamic( + model_input: Path, + model_output: Path, + op_types_to_quantize=[], + per_channel=False, + reduce_range=False, + weight_type=QuantType.QInt8, + nodes_to_quantize=[], + nodes_to_exclude=[], + optimize_model=True, + use_external_data_format=False, + extra_options={}, +): + """ Given an onnx model, create a quantized onnx model and save it into a file :param model_input: file path of model to quantize :param model_output: file path of quantized model @@ -218,7 +235,7 @@ def quantize_dynamic(model_input: Path, always quantize input and so generate quantized output. Also the True behavior could be disabled per node using the nodes_to_exclude. MatMulConstBOnly = True/False: Default is True for dynamic mode. If enabled, only MatMul with const B will be quantized. - ''' + """ mode = QuantizationMode.IntegerOps @@ -227,22 +244,23 @@ def quantize_dynamic(model_input: Path, model = load_model(Path(model_input), optimize_model) - if 'MatMulConstBOnly' not in extra_options: - extra_options['MatMulConstBOnly'] = True + if "MatMulConstBOnly" not in extra_options: + extra_options["MatMulConstBOnly"] = True quantizer = ONNXQuantizer( model, per_channel, reduce_range, mode, - False, #static + False, # static weight_type, - QuantType.QUInt8, #dynamic activation only supports uint8 + QuantType.QUInt8, # dynamic activation only supports uint8 None, nodes_to_quantize, nodes_to_exclude, op_types_to_quantize, - extra_options) + extra_options, + ) quantizer.quantize_model() quantizer.model.save_model_to_file(model_output, use_external_data_format) diff --git a/onnxruntime/python/tools/quantization/registry.py b/onnxruntime/python/tools/quantization/registry.py index d046cbc6dfa86..139445748dad7 100644 --- a/onnxruntime/python/tools/quantization/registry.py +++ b/onnxruntime/python/tools/quantization/registry.py @@ -1,28 +1,28 @@ -from .quant_utils import QuantizationMode +from .operators.activation import QDQRemovableActivation, QLinearActivation from .operators.argmax import QArgMax -from .operators.base_operator import QuantOperatorBase -from .operators.qdq_base_operator import QDQOperatorBase -from .operators.matmul import MatMulInteger, QLinearMatMul, QDQMatMul from .operators.attention import AttentionQuant +from .operators.base_operator import QuantOperatorBase +from .operators.binary_op import QLinearBinaryOp +from .operators.concat import QDQConcat, QLinearConcat +from .operators.conv import ConvInteger, QDQConv, QLinearConv +from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp from .operators.embed_layernorm import EmbedLayerNormalizationQuant from .operators.gather import GatherQuant -from .operators.conv import QLinearConv, ConvInteger, QDQConv -from .operators.activation import QLinearActivation, QDQRemovableActivation -from .operators.binary_op import QLinearBinaryOp -from .operators.maxpool import QDQMaxPool, QMaxPool from .operators.gavgpool import QGlobalAveragePool +from .operators.gemm import QDQGemm, QLinearGemm from .operators.lstm import LSTMQuant -from .operators.split import QSplit +from .operators.matmul import MatMulInteger, QDQMatMul, QLinearMatMul +from .operators.maxpool import QDQMaxPool, QMaxPool from .operators.pad import QPad -from .operators.direct_q8 import Direct8BitOp, QDQDirect8BitOp -from .operators.resize import QResize, QDQResize from .operators.pooling import QLinearPool -from .operators.concat import QLinearConcat, QDQConcat -from .operators.gemm import QLinearGemm, QDQGemm +from .operators.qdq_base_operator import QDQOperatorBase +from .operators.resize import QDQResize, QResize +from .operators.split import QSplit +from .quant_utils import QuantizationMode CommonOpsRegistry = { "Gather": GatherQuant, - "Transpose" : Direct8BitOp, + "Transpose": Direct8BitOp, "EmbedLayerNormalization": EmbedLayerNormalizationQuant, } @@ -50,10 +50,10 @@ "Split": QSplit, "Pad": QPad, "Reshape": Direct8BitOp, - "Squeeze" : Direct8BitOp, - "Unsqueeze" : Direct8BitOp, + "Squeeze": Direct8BitOp, + "Unsqueeze": Direct8BitOp, "Resize": QResize, - "AveragePool" : QLinearPool, + "AveragePool": QLinearPool, "Concat": QLinearConcat, } QLinearOpsRegistry.update(CommonOpsRegistry) @@ -64,12 +64,12 @@ "Clip": QDQRemovableActivation, "Relu": QDQRemovableActivation, "Reshape": QDQDirect8BitOp, - "Transpose" : QDQDirect8BitOp, - "Squeeze" : QDQDirect8BitOp, - "Unsqueeze" : QDQDirect8BitOp, + "Transpose": QDQDirect8BitOp, + "Squeeze": QDQDirect8BitOp, + "Unsqueeze": QDQDirect8BitOp, "Resize": QDQResize, "MaxPool": QDQMaxPool, - "AveragePool" : QDQDirect8BitOp, + "AveragePool": QDQDirect8BitOp, "Concat": QDQConcat, "MatMul": QDQMatMul, } diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 889adf0c4531d..daece6e8aea4d 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -4,15 +4,16 @@ # -*- coding: UTF-8 -*- import argparse import logging + import numpy as np import onnx -from onnx import helper, numpy_helper, shape_inference import sympy - +from onnx import helper, numpy_helper, shape_inference from packaging import version + assert version.parse(onnx.__version__) >= version.parse("1.8.0") -logger = logging.getLogger(__name__) +logger = logging.getLogger(__name__) def get_attribute(node, attr_name, default_value=None): @@ -23,29 +24,29 @@ def get_attribute(node, attr_name, default_value=None): def get_dim_from_proto(dim): - return getattr(dim, dim.WhichOneof('value')) if type(dim.WhichOneof('value')) == str else None + return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) == str else None def is_sequence(type_proto): - cls_type = type_proto.WhichOneof('value') - assert cls_type in ['tensor_type', 'sequence_type'] - return cls_type == 'sequence_type' + cls_type = type_proto.WhichOneof("value") + assert cls_type in ["tensor_type", "sequence_type"] + return cls_type == "sequence_type" def get_shape_from_type_proto(type_proto): assert not is_sequence(type_proto) - if type_proto.tensor_type.HasField('shape'): + if type_proto.tensor_type.HasField("shape"): return [get_dim_from_proto(d) for d in type_proto.tensor_type.shape.dim] else: return None # note no shape is different from shape without dim (scalar) def get_shape_from_value_info(vi): - cls_type = vi.type.WhichOneof('value') + cls_type = vi.type.WhichOneof("value") if cls_type is None: return None if is_sequence(vi.type): - if 'tensor_type' == vi.type.sequence_type.elem_type.WhichOneof('value'): + if "tensor_type" == vi.type.sequence_type.elem_type.WhichOneof("value"): return get_shape_from_type_proto(vi.type.sequence_type.elem_type) else: return None @@ -64,7 +65,7 @@ def get_shape_from_sympy_shape(sympy_shape): def is_literal(dim): - return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, 'is_number') and dim.is_number) + return type(dim) in [int, np.int64, np.int32, sympy.Integer] or (hasattr(dim, "is_number") and dim.is_number) def handle_negative_axis(axis, rank): @@ -73,7 +74,7 @@ def handle_negative_axis(axis, rank): def get_opset(mp, domain=None): - domain = domain or ['', 'onnx', 'ai.onnx'] + domain = domain or ["", "onnx", "ai.onnx"] if type(domain) != list: domain = [domain] for opset in mp.opset_import: @@ -115,93 +116,93 @@ def sympy_reduce_product(x): class SymbolicShapeInference: - def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=''): + def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): self.dispatcher_ = { - 'Add': self._infer_symbolic_compute_ops, - 'ArrayFeatureExtractor': self._infer_ArrayFeatureExtractor, - 'AveragePool': self._infer_Pool, - 'BatchNormalization': self._infer_BatchNormalization, - 'Cast': self._infer_Cast, - 'CategoryMapper': self._infer_CategoryMapper, - 'Compress': self._infer_Compress, - 'Concat': self._infer_Concat, - 'ConcatFromSequence': self._infer_ConcatFromSequence, - 'Constant': self._infer_Constant, - 'ConstantOfShape': self._infer_ConstantOfShape, - 'Conv': self._infer_Conv, - 'CumSum': self._pass_on_shape_and_type, - 'Div': self._infer_symbolic_compute_ops, - 'Einsum': self._infer_Einsum, - 'Expand': self._infer_Expand, - 'Equal': self._infer_symbolic_compute_ops, - 'Floor': self._infer_symbolic_compute_ops, - 'Gather': self._infer_Gather, - 'GatherElements': self._infer_GatherElements, - 'GatherND': self._infer_GatherND, - 'Gelu': self._pass_on_shape_and_type, - 'If': self._infer_If, - 'Loop': self._infer_Loop, - 'MatMul': self._infer_MatMul, - 'MatMulInteger16': self._infer_MatMulInteger, - 'MaxPool': self._infer_Pool, - 'Max': self._infer_symbolic_compute_ops, - 'Min': self._infer_symbolic_compute_ops, - 'Mul': self._infer_symbolic_compute_ops, - 'NonMaxSuppression': self._infer_NonMaxSuppression, - 'NonZero': self._infer_NonZero, - 'OneHot': self._infer_OneHot, - 'Pad': self._infer_Pad, - 'Range': self._infer_Range, - 'Reciprocal': self._pass_on_shape_and_type, - 'ReduceSum': self._infer_ReduceSum, - 'ReduceProd': self._infer_ReduceProd, - 'Reshape': self._infer_Reshape, - 'Resize': self._infer_Resize, - 'Round': self._pass_on_shape_and_type, - 'Scan': self._infer_Scan, - 'ScatterElements': self._infer_ScatterElements, - 'SequenceAt': self._infer_SequenceAt, - 'SequenceInsert': self._infer_SequenceInsert, - 'Shape': self._infer_Shape, - 'Size': self._infer_Size, - 'Slice': self._infer_Slice, - 'SoftmaxCrossEntropyLoss': self._infer_SoftmaxCrossEntropyLoss, - 'SoftmaxCrossEntropyLossInternal': self._infer_SoftmaxCrossEntropyLoss, - 'NegativeLogLikelihoodLossInternal': self._infer_SoftmaxCrossEntropyLoss, - 'Split': self._infer_Split, - 'SplitToSequence': self._infer_SplitToSequence, - 'Squeeze': self._infer_Squeeze, - 'Sub': self._infer_symbolic_compute_ops, - 'Tile': self._infer_Tile, - 'TopK': self._infer_TopK, - 'Transpose': self._infer_Transpose, - 'Unsqueeze': self._infer_Unsqueeze, - 'Where': self._infer_symbolic_compute_ops, - 'ZipMap': self._infer_ZipMap, - 'Neg': self._infer_symbolic_compute_ops, + "Add": self._infer_symbolic_compute_ops, + "ArrayFeatureExtractor": self._infer_ArrayFeatureExtractor, + "AveragePool": self._infer_Pool, + "BatchNormalization": self._infer_BatchNormalization, + "Cast": self._infer_Cast, + "CategoryMapper": self._infer_CategoryMapper, + "Compress": self._infer_Compress, + "Concat": self._infer_Concat, + "ConcatFromSequence": self._infer_ConcatFromSequence, + "Constant": self._infer_Constant, + "ConstantOfShape": self._infer_ConstantOfShape, + "Conv": self._infer_Conv, + "CumSum": self._pass_on_shape_and_type, + "Div": self._infer_symbolic_compute_ops, + "Einsum": self._infer_Einsum, + "Expand": self._infer_Expand, + "Equal": self._infer_symbolic_compute_ops, + "Floor": self._infer_symbolic_compute_ops, + "Gather": self._infer_Gather, + "GatherElements": self._infer_GatherElements, + "GatherND": self._infer_GatherND, + "Gelu": self._pass_on_shape_and_type, + "If": self._infer_If, + "Loop": self._infer_Loop, + "MatMul": self._infer_MatMul, + "MatMulInteger16": self._infer_MatMulInteger, + "MaxPool": self._infer_Pool, + "Max": self._infer_symbolic_compute_ops, + "Min": self._infer_symbolic_compute_ops, + "Mul": self._infer_symbolic_compute_ops, + "NonMaxSuppression": self._infer_NonMaxSuppression, + "NonZero": self._infer_NonZero, + "OneHot": self._infer_OneHot, + "Pad": self._infer_Pad, + "Range": self._infer_Range, + "Reciprocal": self._pass_on_shape_and_type, + "ReduceSum": self._infer_ReduceSum, + "ReduceProd": self._infer_ReduceProd, + "Reshape": self._infer_Reshape, + "Resize": self._infer_Resize, + "Round": self._pass_on_shape_and_type, + "Scan": self._infer_Scan, + "ScatterElements": self._infer_ScatterElements, + "SequenceAt": self._infer_SequenceAt, + "SequenceInsert": self._infer_SequenceInsert, + "Shape": self._infer_Shape, + "Size": self._infer_Size, + "Slice": self._infer_Slice, + "SoftmaxCrossEntropyLoss": self._infer_SoftmaxCrossEntropyLoss, + "SoftmaxCrossEntropyLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "NegativeLogLikelihoodLossInternal": self._infer_SoftmaxCrossEntropyLoss, + "Split": self._infer_Split, + "SplitToSequence": self._infer_SplitToSequence, + "Squeeze": self._infer_Squeeze, + "Sub": self._infer_symbolic_compute_ops, + "Tile": self._infer_Tile, + "TopK": self._infer_TopK, + "Transpose": self._infer_Transpose, + "Unsqueeze": self._infer_Unsqueeze, + "Where": self._infer_symbolic_compute_ops, + "ZipMap": self._infer_ZipMap, + "Neg": self._infer_symbolic_compute_ops, # contrib ops: - 'Attention': self._infer_Attention, - 'BiasGelu': self._infer_BiasGelu, - 'EmbedLayerNormalization': self._infer_EmbedLayerNormalization, - 'FastGelu': self._infer_FastGelu, - 'Gelu': self._infer_Gelu, - 'LayerNormalization': self._infer_LayerNormalization, - 'LongformerAttention': self._infer_LongformerAttention, - 'PythonOp': self._infer_PythonOp, - 'SkipLayerNormalization': self._infer_SkipLayerNormalization + "Attention": self._infer_Attention, + "BiasGelu": self._infer_BiasGelu, + "EmbedLayerNormalization": self._infer_EmbedLayerNormalization, + "FastGelu": self._infer_FastGelu, + "Gelu": self._infer_Gelu, + "LayerNormalization": self._infer_LayerNormalization, + "LongformerAttention": self._infer_LongformerAttention, + "PythonOp": self._infer_PythonOp, + "SkipLayerNormalization": self._infer_SkipLayerNormalization, } self.aten_op_dispatcher_ = { - 'aten::embedding': self._infer_Gather, - 'aten::bitwise_or': self._infer_aten_bitwise_or, - 'aten::diagonal': self._infer_aten_diagonal, - 'aten::max_pool2d_with_indices': self._infer_aten_pool2d, - 'aten::multinomial': self._infer_aten_multinomial, - 'aten::unfold': self._infer_aten_unfold, - 'aten::argmax': self._infer_aten_argmax, - 'aten::avg_pool2d': self._infer_aten_pool2d, - 'aten::_adaptive_avg_pool2d': self._infer_aten_pool2d, - 'aten::binary_cross_entropy_with_logits': self._infer_aten_bce, - 'aten::numpy_T': self._infer_Transpose, + "aten::embedding": self._infer_Gather, + "aten::bitwise_or": self._infer_aten_bitwise_or, + "aten::diagonal": self._infer_aten_diagonal, + "aten::max_pool2d_with_indices": self._infer_aten_pool2d, + "aten::multinomial": self._infer_aten_multinomial, + "aten::unfold": self._infer_aten_unfold, + "aten::argmax": self._infer_aten_argmax, + "aten::avg_pool2d": self._infer_aten_pool2d, + "aten::_adaptive_avg_pool2d": self._infer_aten_pool2d, + "aten::binary_cross_entropy_with_logits": self._infer_aten_bce, + "aten::numpy_T": self._infer_Transpose, } self.run_ = True self.suggested_merge_ = {} @@ -241,7 +242,7 @@ def _add_suggested_merge(self, symbols, apply=False): # when nothing to map to, use the shorter one if map_to is None: if self.verbose_ > 0: - logger.warning('Potential unsafe merge between symbolic expressions: ({})'.format(','.join(symbols))) + logger.warning("Potential unsafe merge between symbolic expressions: ({})".format(",".join(symbols))) symbols_list = list(symbols) lens = [len(s) for s in symbols_list] map_to = symbols_list[lens.index(min(lens))] @@ -278,8 +279,16 @@ def _preprocess(self, in_mp): self.initializers_ = dict([(i.name, i) for i in self.out_mp_.graph.initializer]) self.known_vi_ = dict([(i.name, i) for i in list(self.out_mp_.graph.input)]) self.known_vi_.update( - dict([(i.name, helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))) - for i in self.out_mp_.graph.initializer])) + dict( + [ + ( + i.name, + helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)), + ) + for i in self.out_mp_.graph.initializer + ] + ) + ) def _merge_symbols(self, dims): if not all([type(d) == str for d in dims]): @@ -290,13 +299,17 @@ def _merge_symbols(self, dims): if sum(is_int) == 1: int_dim = is_int.index(1) if self.verbose_ > 0: - logger.debug('dim {} has been merged with value {}'.format( - unique_dims[:int_dim] + unique_dims[int_dim + 1:], unique_dims[int_dim])) + logger.debug( + "dim {} has been merged with value {}".format( + unique_dims[:int_dim] + unique_dims[int_dim + 1 :], + unique_dims[int_dim], + ) + ) self._check_merged_dims(unique_dims, allow_broadcast=False) return unique_dims[int_dim] else: if self.verbose_ > 0: - logger.debug('dim {} has been mergd with dim {}'.format(unique_dims[1:], unique_dims[0])) + logger.debug("dim {} has been mergd with dim {}".format(unique_dims[1:], unique_dims[0])) return dims[0] else: return None @@ -331,7 +344,7 @@ def _broadcast_shapes(self, shape1, shape2): if self.auto_merge_: self._add_suggested_merge([dim1, dim2], apply=True) else: - logger.warning('unsupported broadcast between ' + str(dim1) + ' ' + str(dim2)) + logger.warning("unsupported broadcast between " + str(dim1) + " " + str(dim2)) new_shape = [new_dim] + new_shape return new_shape @@ -351,8 +364,11 @@ def _get_sympy_shape(self, node, idx): sympy_shape = [] for d in self._get_shape(node, idx): if type(d) == str: - sympy_shape.append(self.symbolic_dims_[d] if d in - self.symbolic_dims_ else sympy.Symbol(d, integer=True, nonnegative=True)) + sympy_shape.append( + self.symbolic_dims_[d] + if d in self.symbolic_dims_ + else sympy.Symbol(d, integer=True, nonnegative=True) + ) else: assert None != d sympy_shape.append(d) @@ -387,16 +403,20 @@ def _update_computed_dims(self, new_sympy_shape): def _onnx_infer_single_node(self, node): # skip onnx shape inference for some ops, as they are handled in _infer_* skip_infer = node.op_type in [ - 'If', 'Loop', 'Scan', 'SplitToSequence', 'ZipMap', \ - # contrib ops - - - 'Attention', 'BiasGelu', \ - 'EmbedLayerNormalization', \ - 'FastGelu', 'Gelu', 'LayerNormalization', \ - 'LongformerAttention', \ - 'SkipLayerNormalization', \ - 'PythonOp' + "If", + "Loop", + "Scan", + "SplitToSequence", + "ZipMap", # contrib ops + "Attention", + "BiasGelu", + "EmbedLayerNormalization", + "FastGelu", + "Gelu", + "LayerNormalization", + "LongformerAttention", + "SkipLayerNormalization", + "PythonOp", ] if not skip_infer: @@ -406,15 +426,21 @@ def _onnx_infer_single_node(self, node): # (2) opset version >= 9. In older version, initializer is required in graph input by onnx spec. # (3) The initializer is not in graph input. The means the node input is "constant" in inference. initializers = [] - if (get_opset(self.out_mp_) >= 9) and node.op_type in ['Unsqueeze']: + if (get_opset(self.out_mp_) >= 9) and node.op_type in ["Unsqueeze"]: initializers = [ - self.initializers_[name] for name in node.input + self.initializers_[name] + for name in node.input if (name in self.initializers_ and name not in self.graph_inputs_) ] # run single node inference with self.known_vi_ shapes - tmp_graph = helper.make_graph([node], 'tmp', [self.known_vi_[i] for i in node.input if i], - [make_named_value_info(i) for i in node.output], initializers) + tmp_graph = helper.make_graph( + [node], + "tmp", + [self.known_vi_[i] for i in node.input if i], + [make_named_value_info(i) for i in node.output], + initializers, + ) self.tmp_mp_.graph.CopyFrom(tmp_graph) @@ -431,26 +457,32 @@ def _onnx_infer_single_node(self, node): def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph_id=True): if self.verbose_ > 2: - logger.debug('Inferencing subgraph of node {} with output({}...): {}'.format(node.name, node.output[0], - node.op_type)) + logger.debug( + "Inferencing subgraph of node {} with output({}...): {}".format(node.name, node.output[0], node.op_type) + ) # node inputs are not passed directly to the subgraph # it's up to the node dispatcher to prepare subgraph input # for example, with Scan/Loop, subgraph input shape would be trimmed from node input shape # besides, inputs in subgraph could shadow implicit inputs subgraph_inputs = set([i.name for i in list(subgraph.initializer) + list(subgraph.input)]) subgraph_implicit_input = set([name for name in self.known_vi_.keys() if not name in subgraph_inputs]) - tmp_graph = helper.make_graph(list(subgraph.node), 'tmp', - list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], - [make_named_value_info(i.name) for i in subgraph.output]) + tmp_graph = helper.make_graph( + list(subgraph.node), + "tmp", + list(subgraph.input) + [self.known_vi_[i] for i in subgraph_implicit_input], + [make_named_value_info(i.name) for i in subgraph.output], + ) tmp_graph.initializer.extend([i for i in self.out_mp_.graph.initializer if i.name in subgraph_implicit_input]) tmp_graph.initializer.extend(subgraph.initializer) self.tmp_mp_.graph.CopyFrom(tmp_graph) - symbolic_shape_inference = SymbolicShapeInference(self.int_max_, - self.auto_merge_, - self.guess_output_rank_, - self.verbose_, - prefix=self.prefix_ + '_' + str(self.subgraph_id_)) + symbolic_shape_inference = SymbolicShapeInference( + self.int_max_, + self.auto_merge_, + self.guess_output_rank_, + self.verbose_, + prefix=self.prefix_ + "_" + str(self.subgraph_id_), + ) if inc_subgraph_id: self.subgraph_id_ += 1 @@ -462,18 +494,19 @@ def _onnx_infer_subgraph(self, node, subgraph, use_node_input=True, inc_subgraph symbolic_shape_inference._update_output_from_vi() if use_node_input: # if subgraph uses node input, it needs to update to merged dims - subgraph.ClearField('input') - subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[:len(node.input)]) - subgraph.ClearField('output') + subgraph.ClearField("input") + subgraph.input.extend(symbolic_shape_inference.out_mp_.graph.input[: len(node.input)]) + subgraph.ClearField("output") subgraph.output.extend(symbolic_shape_inference.out_mp_.graph.output) - subgraph.ClearField('value_info') + subgraph.ClearField("value_info") subgraph.value_info.extend(symbolic_shape_inference.out_mp_.graph.value_info) - subgraph.ClearField('node') + subgraph.ClearField("node") subgraph.node.extend(symbolic_shape_inference.out_mp_.graph.node) # for new symbolic dims from subgraph output, add to main graph symbolic dims subgraph_shapes = [get_shape_from_value_info(o) for o in symbolic_shape_inference.out_mp_.graph.output] subgraph_new_symbolic_dims = set( - [d for s in subgraph_shapes if s for d in s if type(d) == str and not d in self.symbolic_dims_]) + [d for s in subgraph_shapes if s for d in s if type(d) == str and not d in self.symbolic_dims_] + ) new_dims = {} for d in subgraph_new_symbolic_dims: assert d in symbolic_shape_inference.symbolic_dims_ @@ -524,17 +557,25 @@ def _compute_on_sympy_data(self, node, op_func): self.sympy_data_[node.output[0]] = op_func(values) def _pass_on_sympy_data(self, node): - assert len(node.input) == 1 or node.op_type in ['Reshape', 'Unsqueeze', 'Squeeze'] + assert len(node.input) == 1 or node.op_type in [ + "Reshape", + "Unsqueeze", + "Squeeze", + ] self._compute_on_sympy_data(node, lambda x: x[0]) def _pass_on_shape_and_type(self, node): vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - self._get_shape(node, 0))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + self._get_shape(node, 0), + ) + ) def _new_symbolic_dim(self, prefix, dim): - new_dim = '{}_d{}'.format(prefix, dim) + new_dim = "{}_d{}".format(prefix, dim) if new_dim in self.suggested_merge_: v = self.suggested_merge_[new_dim] new_symbolic_dim = sympy.Integer(int(v)) if is_literal(v) else v @@ -545,8 +586,14 @@ def _new_symbolic_dim(self, prefix, dim): def _new_symbolic_dim_from_output(self, node, out_idx=0, dim=0): return self._new_symbolic_dim( - '{}{}_{}_o{}_'.format(node.op_type, self.prefix_, - list(self.out_mp_.graph.node).index(node), out_idx), dim) + "{}{}_{}_o{}_".format( + node.op_type, + self.prefix_, + list(self.out_mp_.graph.node).index(node), + out_idx, + ), + dim, + ) def _new_symbolic_shape(self, rank, node, out_idx=0): return [self._new_symbolic_dim_from_output(node, out_idx, i) for i in range(rank)] @@ -560,7 +607,7 @@ def _compute_conv_pool_shape(self, node): sympy_shape[1] = W_shape[0] else: W_shape = None - kernel_shape = get_attribute(node, 'kernel_shape') + kernel_shape = get_attribute(node, "kernel_shape") rank = len(kernel_shape) assert len(sympy_shape) == rank + 2 @@ -575,14 +622,14 @@ def _compute_conv_pool_shape(self, node): sympy_shape[-rank:] = [sympy.Integer(d) for d in shape[-rank:]] return sympy_shape - dilations = get_attribute(node, 'dilations', [1] * rank) - strides = get_attribute(node, 'strides', [1] * rank) + dilations = get_attribute(node, "dilations", [1] * rank) + strides = get_attribute(node, "strides", [1] * rank) effective_kernel_shape = [(k - 1) * d + 1 for k, d in zip(kernel_shape, dilations)] - pads = get_attribute(node, 'pads') + pads = get_attribute(node, "pads") if pads is None: pads = [0] * (2 * rank) - auto_pad = get_attribute(node, 'auto_pad', b'NOTSET').decode('utf-8') - if auto_pad != 'VALID' and auto_pad != 'NOTSET': + auto_pad = get_attribute(node, "auto_pad", b"NOTSET").decode("utf-8") + if auto_pad != "VALID" and auto_pad != "NOTSET": try: residual = [sympy.Mod(d, s) for d, s in zip(sympy_shape[-rank:], strides)] total_pads = [ @@ -590,9 +637,10 @@ def _compute_conv_pool_shape(self, node): for k, s, r in zip(effective_kernel_shape, strides, residual) ] except TypeError: # sympy may throw TypeError: cannot determine truth value of Relational - total_pads = [max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) - ] # assuming no residual if sympy throws error - elif auto_pad == 'VALID': + total_pads = [ + max(0, (k - s)) for k, s in zip(effective_kernel_shape, strides) + ] # assuming no residual if sympy throws error + elif auto_pad == "VALID": total_pads = [] else: total_pads = [0] * rank @@ -600,14 +648,15 @@ def _compute_conv_pool_shape(self, node): assert len(pads) == 2 * rank total_pads = [p1 + p2 for p1, p2 in zip(pads[:rank], pads[rank:])] - ceil_mode = get_attribute(node, 'ceil_mode', 0) + ceil_mode = get_attribute(node, "ceil_mode", 0) for i in range(rank): effective_input_size = sympy_shape[-rank + i] if len(total_pads) > 0: effective_input_size = effective_input_size + total_pads[i] if ceil_mode: strided_kernel_positions = sympy.ceiling( - (effective_input_size - effective_kernel_shape[i]) / strides[i]) + (effective_input_size - effective_kernel_shape[i]) / strides[i] + ) else: strided_kernel_positions = (effective_input_size - effective_kernel_shape[i]) // strides[i] sympy_shape[-rank + i] = strided_kernel_positions + 1 @@ -640,7 +689,10 @@ def _compute_matmul_shape(self, node, output_dtype=None): rhs_reduce_dim = -2 new_shape = self._broadcast_shapes(lhs_shape[:-2], rhs_shape[:-2]) + [lhs_shape[-2]] + [rhs_shape[-1]] # merge reduce dim - self._check_merged_dims([lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], allow_broadcast=False) + self._check_merged_dims( + [lhs_shape[lhs_reduce_dim], rhs_shape[rhs_reduce_dim]], + allow_broadcast=False, + ) if output_dtype is None: # infer output_dtype from input type when not specified output_dtype = self.known_vi_[node.input[0]].type.tensor_type.elem_type @@ -648,19 +700,23 @@ def _compute_matmul_shape(self, node, output_dtype=None): vi.CopyFrom(helper.make_tensor_value_info(node.output[0], output_dtype, new_shape)) def _fuse_tensor_type(self, node, out_idx, dst_type, src_type): - ''' + """ update dst_tensor_type to be compatible with src_tensor_type when dimension mismatches - ''' - dst_tensor_type = dst_type.sequence_type.elem_type.tensor_type if is_sequence( - dst_type) else dst_type.tensor_type - src_tensor_type = src_type.sequence_type.elem_type.tensor_type if is_sequence( - src_type) else src_type.tensor_type + """ + dst_tensor_type = ( + dst_type.sequence_type.elem_type.tensor_type if is_sequence(dst_type) else dst_type.tensor_type + ) + src_tensor_type = ( + src_type.sequence_type.elem_type.tensor_type if is_sequence(src_type) else src_type.tensor_type + ) if dst_tensor_type.elem_type != src_tensor_type.elem_type: node_id = node.name if node.name else node.op_type - raise ValueError(f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " - f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " - f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}") - if dst_tensor_type.HasField('shape'): + raise ValueError( + f"For node {node_id}, dst_tensor_type.elem_type != src_tensor_type.elem_type: " + f"{onnx.onnx_pb.TensorProto.DataType.Name(dst_tensor_type.elem_type)} vs " + f"{onnx.onnx_pb.TensorProto.DataType.Name(src_tensor_type.elem_type)}" + ) + if dst_tensor_type.HasField("shape"): for di, ds in enumerate(zip(dst_tensor_type.shape.dim, src_tensor_type.shape.dim)): if ds[0] != ds[1]: # create a new symbolic dimension for node/out_idx/mismatch dim id in dst_tensor_type for tensor_type @@ -677,33 +733,29 @@ def _infer_ArrayFeatureExtractor(self, node): indices_shape = self._get_shape(node, 1) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - data_shape[:-1] + indices_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape[:-1] + indices_shape, + ) + ) def _infer_symbolic_compute_ops(self, node): funcs = { - 'Add': - lambda l: l[0] + l[1], - 'Div': - lambda l: l[0] // l[1], # integer div in sympy - 'Equal': - lambda l: l[0] == l[1], - 'Floor': - lambda l: sympy.floor(l[0]), - 'Max': - lambda l: l[1] if is_literal(l[0]) and int(l[0]) < -self.int_max_ else - (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), - 'Min': - lambda l: l[1] if is_literal(l[0]) and int(l[0]) > self.int_max_ else - (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), - 'Mul': - lambda l: l[0] * l[1], - 'Sub': - lambda l: l[0] - l[1], - 'Where': - lambda l: l[1] if l[0] else l[2], - 'Neg': - lambda l: -l[0] + "Add": lambda l: l[0] + l[1], + "Div": lambda l: l[0] // l[1], # integer div in sympy + "Equal": lambda l: l[0] == l[1], + "Floor": lambda l: sympy.floor(l[0]), + "Max": lambda l: l[1] + if is_literal(l[0]) and int(l[0]) < -self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) < -self.int_max_ else sympy.Max(l[0], l[1])), + "Min": lambda l: l[1] + if is_literal(l[0]) and int(l[0]) > self.int_max_ + else (l[0] if is_literal(l[1]) and int(l[1]) > self.int_max_ else sympy.Min(l[0], l[1])), + "Mul": lambda l: l[0] * l[1], + "Sub": lambda l: l[0] - l[1], + "Where": lambda l: l[1] if l[0] else l[2], + "Neg": lambda l: -l[0], } assert node.op_type in funcs self._compute_on_sympy_data(node, funcs[node.op_type]) @@ -724,7 +776,7 @@ def _infer_Compress(self, node): input_shape = self._get_shape(node, 0) # create a new symbolic dimension for Compress output compress_len = str(self._new_symbolic_dim_from_output(node)) - axis = get_attribute(node, 'axis') + axis = get_attribute(node, "axis") if axis == None: # when axis is not specified, input is flattened before compress so output is 1D output_shape = [compress_len] @@ -733,14 +785,18 @@ def _infer_Compress(self, node): output_shape[handle_negative_axis(axis, len(input_shape))] = compress_len vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) def _infer_Concat(self, node): if any([i in self.sympy_data_ or i in self.initializers_ for i in node.input]): values = self._get_int_values(node) if all([v is not None for v in values]): - assert 0 == get_attribute(node, 'axis') + assert 0 == get_attribute(node, "axis") self.sympy_data_[node.output[0]] = [] for i in range(len(node.input)): value = values[i] @@ -750,7 +806,7 @@ def _infer_Concat(self, node): self.sympy_data_[node.output[0]].append(value) sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis(get_attribute(node, 'axis'), len(sympy_shape)) + axis = handle_negative_axis(get_attribute(node, "axis"), len(sympy_shape)) for i_idx in range(1, len(node.input)): input_shape = self._get_sympy_shape(node, i_idx) if input_shape: @@ -770,13 +826,17 @@ def _infer_Concat(self, node): sympy_shape[d] = merged vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) def _infer_ConcatFromSequence(self, node): seq_shape = self._get_shape(node, 0) - new_axis = 1 if get_attribute(node, 'new_axis') else 0 - axis = handle_negative_axis(get_attribute(node, 'axis'), len(seq_shape) + new_axis) + new_axis = 1 if get_attribute(node, "new_axis") else 0 + axis = handle_negative_axis(get_attribute(node, "axis"), len(seq_shape) + new_axis) concat_dim = str(self._new_symbolic_dim_from_output(node, 0, axis)) new_shape = seq_shape if new_axis: @@ -786,11 +846,14 @@ def _infer_ConcatFromSequence(self, node): vi = self.known_vi_[node.output[0]] vi.CopyFrom( helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type, - new_shape)) + node.output[0], + self.known_vi_[node.input[0]].type.sequence_type.elem_type.tensor_type.elem_type, + new_shape, + ) + ) def _infer_Constant(self, node): - t = get_attribute(node, 'value') + t = get_attribute(node, "value") self.sympy_data_[node.output[0]] = numpy_helper.to_array(t) def _infer_ConstantOfShape(self, node): @@ -803,30 +866,38 @@ def _infer_ConstantOfShape(self, node): # update sympy data if output type is int, and shape is known if vi.type.tensor_type.elem_type == onnx.TensorProto.INT64 and all([is_literal(x) for x in sympy_shape]): self.sympy_data_[node.output[0]] = np.ones( - [int(x) - for x in sympy_shape], dtype=np.int64) * numpy_helper.to_array(get_attribute(node, 'value', 0)) + [int(x) for x in sympy_shape], dtype=np.int64 + ) * numpy_helper.to_array(get_attribute(node, "value", 0)) else: # create new dynamic shape # note input0 is a 1D vector of shape, the new symbolic shape has the rank of the shape vector length sympy_shape = self._new_symbolic_shape(self._get_shape(node, 0)[0], node) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) def _infer_Conv(self, node): sympy_shape = self._compute_conv_pool_shape(node) self._update_computed_dims(sympy_shape) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) def _infer_Einsum(self, node): # ref:https://github.com/onnx/onnx/blob/623dfaa0151b2e4ce49779c3ec31cbd78c592b80/onnx/defs/math/defs.cc#L3275 - equation = get_attribute(node, 'equation') - equation = equation.replace(b' ', b'') - mid_index = equation.find(b'->') + equation = get_attribute(node, "equation") + equation = equation.replace(b" ", b"") + mid_index = equation.find(b"->") left_equation = equation[:mid_index] if mid_index != -1 else equation num_operands = 0 @@ -835,9 +906,9 @@ def _infer_Einsum(self, node): letter_to_dim = {} - terms = left_equation.split(b',') + terms = left_equation.split(b",") for term in terms: - ellipsis_index = term.find(b'...') + ellipsis_index = term.find(b"...") shape = self._get_shape(node, num_operands) rank = len(shape) if ellipsis_index != -1: @@ -846,7 +917,7 @@ def _infer_Einsum(self, node): num_ellipsis = num_ellipsis + 1 for i in range(1, rank + 1): letter = term[-i] - if letter != 46: # letter != b'.' + if letter != 46: # letter != b'.' dim = shape[-i] if letter not in letter_to_dim.keys(): letter_to_dim[letter] = dim @@ -856,21 +927,22 @@ def _infer_Einsum(self, node): new_sympy_shape = [] from collections import OrderedDict + num_letter_occurrences = OrderedDict() if mid_index != -1: - right_equation = equation[mid_index + 2:] - right_ellipsis_index = right_equation.find(b'...') + right_equation = equation[mid_index + 2 :] + right_ellipsis_index = right_equation.find(b"...") if right_ellipsis_index != -1: for i in range(num_ellipsis_indices): new_sympy_shape.append(shape[i]) for c in right_equation: - if c != 46: # c != b'.' + if c != 46: # c != b'.' new_sympy_shape.append(letter_to_dim[c]) else: for i in range(num_ellipsis_indices): new_sympy_shape.append(shape[i]) for c in left_equation: - if c != 44 and c != 46: # c != b',' and c != b'.': + if c != 44 and c != 46: # c != b',' and c != b'.': if c in num_letter_occurrences: num_letter_occurrences[c] = num_letter_occurrences[c] + 1 else: @@ -892,19 +964,27 @@ def _infer_Expand(self, node): new_shape = self._broadcast_shapes(shape, get_shape_from_sympy_shape(expand_to_shape)) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - new_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + new_shape, + ) + ) def _infer_Gather(self, node): data_shape = self._get_shape(node, 0) - axis = handle_negative_axis(get_attribute(node, 'axis', 0), len(data_shape)) + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(data_shape)) indices_shape = self._get_shape(node, 1) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - data_shape[:axis] + indices_shape + data_shape[axis + 1:])) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape[:axis] + indices_shape + data_shape[axis + 1 :], + ) + ) # for 1D input, do some sympy compute - if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and 0 == get_attribute(node, 'axis', 0): + if node.input[0] in self.sympy_data_ and len(data_shape) == 1 and 0 == get_attribute(node, "axis", 0): idx = self._try_get_value(node, 1) if idx is not None: data = self.sympy_data_[node.input[0]] @@ -921,8 +1001,12 @@ def _infer_GatherElements(self, node): indices_shape = self._get_shape(node, 1) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - indices_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + indices_shape, + ) + ) def _infer_GatherND(self, node): data_shape = self._get_shape(node, 0) @@ -934,12 +1018,19 @@ def _infer_GatherND(self, node): new_shape = indices_shape[:-1] + data_shape[last_index_dimension:] vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - new_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + new_shape, + ) + ) def _infer_If(self, node): # special case for constant condition, in case there are mismatching shape from the non-executed branch - subgraphs = [get_attribute(node, 'then_branch'), get_attribute(node, 'else_branch')] + subgraphs = [ + get_attribute(node, "then_branch"), + get_attribute(node, "else_branch"), + ] cond = self._try_get_value(node, 0) if cond is not None: if as_scalar(cond) > 0: @@ -963,7 +1054,7 @@ def _infer_If(self, node): self.sympy_data_[vi.name] = subgraph_infer.sympy_data_[subgraph.output[i_out].name] def _infer_Loop(self, node): - subgraph = get_attribute(node, 'body') + subgraph = get_attribute(node, "body") assert len(subgraph.input) == len(node.input) num_loop_carried = len(node.input) - 2 # minus the length and initial loop condition # when sequence_type is used as loop carried input @@ -1002,8 +1093,11 @@ def _infer_Loop(self, node): if need_second_infer: if self.verbose_ > 2: - logger.debug("Rerun Loop: {}({}...), because of sequence in loop carried variables".format( - node.name, node.output[0])) + logger.debug( + "Rerun Loop: {}({}...), because of sequence in loop carried variables".format( + node.name, node.output[0] + ) + ) self._onnx_infer_subgraph(node, subgraph, inc_subgraph_id=False) # create a new symbolic dimension for iteration dependent dimension @@ -1014,7 +1108,7 @@ def _infer_Loop(self, node): if i >= num_loop_carried: assert not is_sequence(vi.type) # TODO: handle loop accumulation in sequence_type subgraph_vi_dim = subgraph.output[i + 1].type.tensor_type.shape.dim - vi.type.tensor_type.shape.ClearField('dim') + vi.type.tensor_type.shape.ClearField("dim") vi_dim = vi.type.tensor_type.shape.dim vi_dim.add().dim_param = loop_iter_dim vi_dim.extend(list(subgraph_vi_dim)) @@ -1041,19 +1135,25 @@ def _infer_NonZero(self, node): def _infer_OneHot(self, node): sympy_shape = self._get_sympy_shape(node, 0) depth = self._try_get_value(node, 1) - axis = get_attribute(node, 'axis', -1) + axis = get_attribute(node, "axis", -1) axis = handle_negative_axis(axis, len(sympy_shape) + 1) new_shape = get_shape_from_sympy_shape( - sympy_shape[:axis] + [self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth] + - sympy_shape[axis:]) + sympy_shape[:axis] + + [self._new_symbolic_dim_from_output(node) if not is_literal(depth) else depth] + + sympy_shape[axis:] + ) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[2]].type.tensor_type.elem_type, - new_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[2]].type.tensor_type.elem_type, + new_shape, + ) + ) def _infer_Pad(self, node): if get_opset(self.out_mp_) <= 10: - pads = get_attribute(node, 'pads') + pads = get_attribute(node, "pads") else: pads = self._try_get_value(node, 1) @@ -1073,7 +1173,8 @@ def _infer_Pad(self, node): vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info(node.output[0], output_tp, get_shape_from_sympy_shape(new_sympy_shape)) + ) def _infer_Pool(self, node): sympy_shape = self._compute_conv_pool_shape(node) @@ -1083,8 +1184,12 @@ def _infer_Pool(self, node): continue vi = self.known_vi_[o] vi.CopyFrom( - helper.make_tensor_value_info(o, vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) + helper.make_tensor_value_info( + o, + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) def _infer_aten_bitwise_or(self, node): shape0 = self._get_shape(node, 0) @@ -1092,9 +1197,7 @@ def _infer_aten_bitwise_or(self, node): new_shape = self._broadcast_shapes(shape0, shape1) t0 = self.known_vi_[node.input[0]] vi = self.known_vi_[node.output[0]] - vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, - new_shape)) + vi.CopyFrom(helper.make_tensor_value_info(node.output[0], t0.type.tensor_type.elem_type, new_shape)) def _infer_aten_diagonal(self, node): sympy_shape = self._get_sympy_shape(node, 0) @@ -1123,21 +1226,29 @@ def _infer_aten_diagonal(self, node): if node.output[0]: vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_shape))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_shape), + ) + ) def _infer_aten_multinomial(self, node): sympy_shape = self._get_sympy_shape(node, 0) rank = len(sympy_shape) - assert rank in [1,2] + assert rank in [1, 2] num_samples = self._try_get_value(node, 1) di = rank - 1 last_dim = num_samples if num_samples else str(self._new_symbolic_dim_from_output(node, 0, di)) output_shape = sympy_shape[:-1] + [last_dim] vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, - get_shape_from_sympy_shape(output_shape))) + helper.make_tensor_value_info( + node.output[0], + onnx.TensorProto.INT64, + get_shape_from_sympy_shape(output_shape), + ) + ) def _infer_aten_pool2d(self, node): sympy_shape = self._get_sympy_shape(node, 0) @@ -1167,12 +1278,16 @@ def _infer_aten_unfold(self, node): if node.output[0]: vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(sympy_shape), + ) + ) def _infer_aten_argmax(self, node): new_shape = None - if node.input[1] == '': + if node.input[1] == "": # The argmax of the flattened input is returned. new_shape = [] else: @@ -1228,11 +1343,15 @@ def _infer_Range(self, node): new_sympy_shape = [self._new_symbolic_dim_from_output(node)] self._update_computed_dims(new_sympy_shape) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) def _infer_ReduceSum(self, node): - keep_dims = get_attribute(node, 'keepdims', 1) + keep_dims = get_attribute(node, "keepdims", 1) if get_opset(self.out_mp_) >= 13 and len(node.input) > 1: # ReduceSum changes axes to input[1] in opset 13 axes = self._try_get_value(node, 1) @@ -1241,8 +1360,11 @@ def _infer_ReduceSum(self, node): assert keep_dims # can only handle keep_dims==True when axes is unknown, by generating new ranks vi.CopyFrom( helper.make_tensor_value_info( - node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)))) + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(self._new_symbolic_shape(self._get_shape_rank(node, 0), node)), + ) + ) else: shape = self._get_shape(node, 0) output_shape = [] @@ -1254,13 +1376,16 @@ def _infer_ReduceSum(self, node): else: output_shape.append(d) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) def _infer_ReduceProd(self, node): - axes = get_attribute(node, 'axes') - keep_dims = get_attribute(node, 'keepdims', 1) + axes = get_attribute(node, "axes") + keep_dims = get_attribute(node, "keepdims", 1) if keep_dims == 0 and axes == [0]: data = self._get_int_values(node)[0] if data is not None: @@ -1275,8 +1400,12 @@ def _infer_Reshape(self, node): shape_rank = shape_shape[0] assert is_literal(shape_rank) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)))) + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(self._new_symbolic_shape(shape_rank, node)), + ) + ) else: input_sympy_shape = self._get_sympy_shape(node, 0) total = int(1) @@ -1305,8 +1434,12 @@ def _infer_Reshape(self, node): self._update_computed_dims(new_sympy_shape) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) self._pass_on_sympy_data(node) @@ -1319,9 +1452,12 @@ def _infer_Resize(self, node): new_sympy_shape = [sympy.simplify(sympy.floor(d * s)) for d, s in zip(input_sympy_shape, scales)] self._update_computed_dims(new_sympy_shape) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], - self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) else: roi = self._try_get_value(node, 1) scales = self._try_get_value(node, 2) @@ -1331,7 +1467,7 @@ def _infer_Resize(self, node): self._update_computed_dims(new_sympy_shape) elif scales is not None: rank = len(scales) - if get_attribute(node, 'coordinate_transformation_mode') == 'tf_crop_and_resize': + if get_attribute(node, "coordinate_transformation_mode") == "tf_crop_and_resize": assert len(roi) == 2 * rank roi_start = list(roi)[:rank] roi_end = list(roi)[rank:] @@ -1348,13 +1484,17 @@ def _infer_Resize(self, node): new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) def _infer_Scan(self, node): - subgraph = get_attribute(node, 'body') - num_scan_inputs = get_attribute(node, 'num_scan_inputs') - scan_input_axes = get_attribute(node, 'scan_input_axes', [0] * num_scan_inputs) + subgraph = get_attribute(node, "body") + num_scan_inputs = get_attribute(node, "num_scan_inputs") + scan_input_axes = get_attribute(node, "scan_input_axes", [0] * num_scan_inputs) num_scan_states = len(node.input) - num_scan_inputs scan_input_axes = [ handle_negative_axis(ax, self._get_shape_rank(node, i + num_scan_states)) @@ -1363,7 +1503,7 @@ def _infer_Scan(self, node): # We may have cases where the subgraph has optionial inputs that appear in both subgraph's input and initializer, # but not in the node's input. In such cases, the input model might be invalid, but let's skip those optional inputs. assert len(subgraph.input) >= len(node.input) - subgraph_inputs = subgraph.input[:len(node.input)] + subgraph_inputs = subgraph.input[: len(node.input)] for i, si in enumerate(subgraph_inputs): subgraph_name = si.name si.CopyFrom(self.known_vi_[node.input[i]]) @@ -1373,7 +1513,7 @@ def _infer_Scan(self, node): si.name = subgraph_name self._onnx_infer_subgraph(node, subgraph) num_scan_outputs = len(node.output) - num_scan_states - scan_output_axes = get_attribute(node, 'scan_output_axes', [0] * num_scan_outputs) + scan_output_axes = get_attribute(node, "scan_output_axes", [0] * num_scan_outputs) scan_input_dim = get_shape_from_type_proto(self.known_vi_[node.input[-1]].type)[scan_input_axes[-1]] for i, o in enumerate(node.output): vi = self.known_vi_[o] @@ -1390,8 +1530,12 @@ def _infer_ScatterElements(self, node): data_shape = self._get_shape(node, 0) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - data_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + data_shape, + ) + ) def _infer_SequenceAt(self, node): # need to create new symbolic dimension if sequence shape has None: @@ -1421,7 +1565,8 @@ def _infer_Size(self, node): sympy_shape = self._get_sympy_shape(node, 0) self.sympy_data_[node.output[0]] = sympy_reduce_product(sympy_shape) self.known_vi_[node.output[0]].CopyFrom( - helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, [])) + helper.make_tensor_value_info(node.output[0], onnx.TensorProto.INT64, []) + ) def _infer_Slice(self, node): def less_equal(x, y): @@ -1444,7 +1589,7 @@ def less_equal(x, y): return bool(y - x >= 0) def handle_negative_index(index, bound): - """ normalizes a negative index to be in [0, bound) """ + """normalizes a negative index to be in [0, bound)""" try: if not less_equal(0, index): if is_literal(index) and index <= -self.int_max_: @@ -1456,9 +1601,9 @@ def handle_negative_index(index, bound): return index if get_opset(self.out_mp_) <= 9: - axes = get_attribute(node, 'axes') - starts = get_attribute(node, 'starts') - ends = get_attribute(node, 'ends') + axes = get_attribute(node, "axes") + starts = get_attribute(node, "starts") + ends = get_attribute(node, "ends") if not axes: axes = list(range(len(starts))) steps = [1] * len(axes) @@ -1497,8 +1642,9 @@ def handle_negative_index(index, bound): e = min(e, new_sympy_shape[i]) else: if e > 0: - e = sympy.Min(e, new_sympy_shape[i] - ) if e > 1 else e #special case for slicing first to make computation easier + e = ( + sympy.Min(e, new_sympy_shape[i]) if e > 1 else e + ) # special case for slicing first to make computation easier else: if is_literal(new_sympy_shape[i]): e = sympy.Min(e, new_sympy_shape[i]) @@ -1507,7 +1653,9 @@ def handle_negative_index(index, bound): if not less_equal(e, new_sympy_shape[i]): e = new_sympy_shape[i] except Exception: - logger.warning('Unable to determine if {} <= {}, treat as equal'.format(e, new_sympy_shape[i])) + logger.warning( + "Unable to determine if {} <= {}, treat as equal".format(e, new_sympy_shape[i]) + ) e = new_sympy_shape[i] s = handle_negative_index(s, new_sympy_shape[i]) @@ -1520,16 +1668,26 @@ def handle_negative_index(index, bound): vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) # handle sympy_data if needed, for slice in shape computation - if (node.input[0] in self.sympy_data_ and [0] == axes and len(starts) == 1 and len(ends) == 1 - and len(steps) == 1): + if ( + node.input[0] in self.sympy_data_ + and [0] == axes + and len(starts) == 1 + and len(ends) == 1 + and len(steps) == 1 + ): input_sympy_data = self.sympy_data_[node.input[0]] - if type(input_sympy_data) == list or (type(input_sympy_data) == np.array - and len(input_sympy_data.shape) == 1): - self.sympy_data_[node.output[0]] = input_sympy_data[starts[0]:ends[0]:steps[0]] + if type(input_sympy_data) == list or ( + type(input_sympy_data) == np.array and len(input_sympy_data.shape) == 1 + ): + self.sympy_data_[node.output[0]] = input_sympy_data[starts[0] : ends[0] : steps[0]] def _infer_SoftmaxCrossEntropyLoss(self, node): vi = self.known_vi_[node.output[0]] @@ -1544,8 +1702,8 @@ def _infer_SoftmaxCrossEntropyLoss(self, node): def _infer_Split_Common(self, node, make_value_info_func): input_sympy_shape = self._get_sympy_shape(node, 0) - axis = handle_negative_axis(get_attribute(node, 'axis', 0), len(input_sympy_shape)) - split = get_attribute(node, 'split') + axis = handle_negative_axis(get_attribute(node, "axis", 0), len(input_sympy_shape)) + split = get_attribute(node, "split") if not split: num_outputs = len(node.output) split = [input_sympy_shape[axis] / sympy.Integer(num_outputs)] * num_outputs @@ -1557,8 +1715,11 @@ def _infer_Split_Common(self, node, make_value_info_func): vi = self.known_vi_[node.output[i_o]] vi.CopyFrom( make_value_info_func( - node.output[i_o], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1:]))) + node.output[i_o], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + get_shape_from_sympy_shape(input_sympy_shape[:axis] + [split[i_o]] + input_sympy_shape[axis + 1 :]), + ) + ) self.known_vi_[vi.name] = vi def _infer_Split(self, node): @@ -1573,11 +1734,11 @@ def _infer_Squeeze(self, node): # Depending on op-version 'axes' are provided as attribute or via 2nd input if op_set < 13: - axes = get_attribute(node, 'axes') + axes = get_attribute(node, "axes") assert self._try_get_value(node, 1) is None else: axes = self._try_get_value(node, 1) - assert get_attribute(node, 'axes') is None + assert get_attribute(node, "axes") is None if axes is None: # No axes have been provided (neither via attribute nor via input). @@ -1587,8 +1748,10 @@ def _infer_Squeeze(self, node): if self.verbose_ > 0: symbolic_dimensions = [s for s in input_shape if type(s) != int] if len(symbolic_dimensions) > 0: - logger.debug(f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + - f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}") + logger.debug( + f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + + f"Assuming the following dimensions are never equal to 1: {symbolic_dimensions}" + ) else: axes = [handle_negative_axis(a, len(input_shape)) for a in axes] output_shape = [] @@ -1598,13 +1761,19 @@ def _infer_Squeeze(self, node): else: assert input_shape[i] == 1 or type(input_shape[i]) != int if self.verbose_ > 0 and type(input_shape[i]) != int: - logger.debug(f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + - f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1.") + logger.debug( + f"Symbolic dimensions in input shape of op: '{node.op_type}' node: '{node.name}'. " + + f"Assuming the dimension '{input_shape[i]}' at index {i} of the input to be equal to 1." + ) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) self._pass_on_sympy_data(node) def _infer_Tile(self, node): @@ -1620,16 +1789,20 @@ def _infer_Tile(self, node): new_sympy_shape = self._new_symbolic_shape(self._get_shape_rank(node, 0), node) vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], vi.type.tensor_type.elem_type, - get_shape_from_sympy_shape(new_sympy_shape))) + helper.make_tensor_value_info( + node.output[0], + vi.type.tensor_type.elem_type, + get_shape_from_sympy_shape(new_sympy_shape), + ) + ) def _infer_TopK(self, node): rank = self._get_shape_rank(node, 0) - axis = handle_negative_axis(get_attribute(node, 'axis', -1), rank) + axis = handle_negative_axis(get_attribute(node, "axis", -1), rank) new_shape = self._get_shape(node, 0) if get_opset(self.out_mp_) <= 9: - k = get_attribute(node, 'k') + k = get_attribute(node, "k") else: k = self._get_int_values(node)[1] @@ -1655,10 +1828,11 @@ def _infer_TopK(self, node): def _infer_Transpose(self, node): if node.input[0] in self.sympy_data_: data_shape = self._get_shape(node, 0) - perm = get_attribute(node, 'perm', reversed(list(range(len(data_shape))))) + perm = get_attribute(node, "perm", reversed(list(range(len(data_shape))))) input_data = self.sympy_data_[node.input[0]] - self.sympy_data_[node.output[0]] = np.transpose(np.array(input_data).reshape(*data_shape), - axes=tuple(perm)).flatten().tolist() + self.sympy_data_[node.output[0]] = ( + np.transpose(np.array(input_data).reshape(*data_shape), axes=tuple(perm)).flatten().tolist() + ) def _infer_Unsqueeze(self, node): input_shape = self._get_shape(node, 0) @@ -1666,11 +1840,11 @@ def _infer_Unsqueeze(self, node): # Depending on op-version 'axes' are provided as attribute or via 2nd input if op_set < 13: - axes = get_attribute(node, 'axes') + axes = get_attribute(node, "axes") assert self._try_get_value(node, 1) is None else: axes = self._try_get_value(node, 1) - assert get_attribute(node, 'axes') is None + assert get_attribute(node, "axes") is None output_rank = len(input_shape) + len(axes) axes = [handle_negative_axis(a, output_rank) for a in axes] @@ -1686,16 +1860,20 @@ def _infer_Unsqueeze(self, node): vi = self.known_vi_[node.output[0]] vi.CopyFrom( - helper.make_tensor_value_info(node.output[0], self.known_vi_[node.input[0]].type.tensor_type.elem_type, - output_shape)) + helper.make_tensor_value_info( + node.output[0], + self.known_vi_[node.input[0]].type.tensor_type.elem_type, + output_shape, + ) + ) self._pass_on_sympy_data(node) def _infer_ZipMap(self, node): map_key_type = None - if get_attribute(node, 'classlabels_int64s') is not None: + if get_attribute(node, "classlabels_int64s") is not None: map_key_type = onnx.TensorProto.INT64 - elif get_attribute(node, 'classlabels_strings') is not None: + elif get_attribute(node, "classlabels_strings") is not None: map_key_type = onnx.TensorProto.STRING assert map_key_type is not None @@ -1710,7 +1888,7 @@ def _infer_Attention(self, node): shape = self._get_shape(node, 0) shape_bias = self._get_shape(node, 2) assert len(shape) == 3 and len(shape_bias) == 1 - qkv_hidden_sizes_attr = get_attribute(node, 'qkv_hidden_sizes') + qkv_hidden_sizes_attr = get_attribute(node, "qkv_hidden_sizes") if qkv_hidden_sizes_attr is not None: assert len(qkv_hidden_sizes_attr) == 3 shape[2] = int(qkv_hidden_sizes_attr[2]) @@ -1777,9 +1955,9 @@ def _infer_SkipLayerNormalization(self, node): self._propagate_shape_and_type(node) def _infer_PythonOp(self, node): - output_tensor_types = get_attribute(node, 'output_tensor_types') + output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types - output_tensor_ranks = get_attribute(node, 'output_tensor_ranks') + output_tensor_ranks = get_attribute(node, "output_tensor_ranks") assert output_tensor_ranks # set the context output seperately. @@ -1811,16 +1989,16 @@ def _is_none_dim(self, dim_value): if dim_value in self.symbolic_dims_.keys(): return False return True - + def _is_shape_contains_none_dim(self, out_shape): for out in out_shape: if self._is_none_dim(out): return out return None - + def _infer_impl(self, start_sympy_data=None): self.sympy_data_ = start_sympy_data or {} - self.out_mp_.graph.ClearField('value_info') + self.out_mp_.graph.ClearField("value_info") self._apply_suggested_merge(graph_input_only=True) self.input_symbols_ = set() for i in self.out_mp_.graph.input: @@ -1853,7 +2031,7 @@ def _infer_impl(self, start_sympy_data=None): # for tensor ops like Reshape/Tile/Expand that read initializer, we need to do sympy computation based inference anyways self.tmp_mp_ = onnx.ModelProto() self.tmp_mp_.CopyFrom(self.out_mp_) - self.tmp_mp_.graph.ClearField('initializer') + self.tmp_mp_.graph.ClearField("initializer") # compute prerequesite for node for topological sort # node with subgraphs may have dependency on implicit inputs, which will affect topological sort @@ -1862,10 +2040,13 @@ def _infer_impl(self, start_sympy_data=None): def get_prereq(node): names = set(i for i in node.input if i) subgraphs = [] - if 'If' == node.op_type: - subgraphs = [get_attribute(node, 'then_branch'), get_attribute(node, 'else_branch')] - elif node.op_type in ['Loop', 'Scan']: - subgraphs = [get_attribute(node, 'body')] + if "If" == node.op_type: + subgraphs = [ + get_attribute(node, "then_branch"), + get_attribute(node, "else_branch"), + ] + elif node.op_type in ["Loop", "Scan"]: + subgraphs = [get_attribute(node, "body")] for g in subgraphs: g_outputs_and_initializers = {i.name for i in g.initializer} g_prereq = set() @@ -1894,12 +2075,14 @@ def get_prereq(node): old_sorted_nodes_len = len(sorted_nodes) for node in self.out_mp_.graph.node: if (node.output[0] not in sorted_known_vi) and all( - [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i]): + [i in sorted_known_vi for i in prereq_for_node[node.output[0]] if i] + ): sorted_known_vi.update(node.output) sorted_nodes.append(node) if old_sorted_nodes_len == len(sorted_nodes) and not all( - [o.name in sorted_known_vi for o in self.out_mp_.graph.output]): - raise Exception('Invalid model with cyclic graph') + [o.name in sorted_known_vi for o in self.out_mp_.graph.output] + ): + raise Exception("Invalid model with cyclic graph") for node in sorted_nodes: assert all([i in self.known_vi_ for i in node.input if i]) @@ -1907,37 +2090,47 @@ def get_prereq(node): known_aten_op = False if node.op_type in self.dispatcher_: self.dispatcher_[node.op_type](node) - elif node.op_type in ['ConvTranspose']: + elif node.op_type in ["ConvTranspose"]: # onnx shape inference ops like ConvTranspose may have empty shape for symbolic input # before adding symbolic compute for them # mark the output type as UNDEFINED to allow guessing of rank vi = self.known_vi_[node.output[0]] if len(vi.type.tensor_type.shape.dim) == 0: vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - elif node.op_type == 'ATen' and node.domain == 'org.pytorch.aten': + elif node.op_type == "ATen" and node.domain == "org.pytorch.aten": for attr in node.attribute: # TODO: Is overload_name needed? - if attr.name == 'operator': - aten_op_name = attr.s.decode('utf-8') if isinstance(attr.s, bytes) else attr.s + if attr.name == "operator": + aten_op_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s if aten_op_name in self.aten_op_dispatcher_: known_aten_op = True self.aten_op_dispatcher_[aten_op_name](node) break if self.verbose_ > 2: - logger.debug(node.op_type + ': ' + node.name) + logger.debug(node.op_type + ": " + node.name) for i, name in enumerate(node.input): - logger.debug(' Input {}: {} {}'.format(i, name, 'initializer' if name in self.initializers_ else '')) + logger.debug( + " Input {}: {} {}".format(i, name, "initializer" if name in self.initializers_ else "") + ) # onnx automatically merge dims with value, i.e. Mul(['aaa', 'bbb'], [1000, 1]) -> [1000, 'bbb'] # symbolic shape inference needs to apply merge of 'aaa' -> 1000 in this case if node.op_type in [ - 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Where', 'Sum' + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Where", + "Sum", ]: vi = self.known_vi_[node.output[0]] out_rank = len(get_shape_from_type_proto(vi.type)) in_shapes = [self._get_shape(node, i) for i in range(len(node.input))] - for d in range(out_rank - (2 if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16'] else 0)): + for d in range(out_rank - (2 if node.op_type in ["MatMul", "MatMulInteger", "MatMulInteger16"] else 0)): in_dims = [s[len(s) - out_rank + d] for s in in_shapes if len(s) + d >= out_rank] if len(in_dims) > 1: self._check_merged_dims(in_dims, allow_broadcast=True) @@ -1945,41 +2138,70 @@ def get_prereq(node): for i_o in range(len(node.output)): vi = self.known_vi_[node.output[i_o]] out_type = vi.type - out_type_kind = out_type.WhichOneof('value') + out_type_kind = out_type.WhichOneof("value") # do not process shape for non-tensors - if out_type_kind not in ['tensor_type', 'sparse_tensor_type', None]: + if out_type_kind not in ["tensor_type", "sparse_tensor_type", None]: if self.verbose_ > 2: - if out_type_kind == 'sequence_type': - seq_cls_type = out_type.sequence_type.elem_type.WhichOneof('value') - if 'tensor_type' == seq_cls_type: - logger.debug(' {}: sequence of {} {}'.format( - node.output[i_o], str(get_shape_from_value_info(vi)), - onnx.TensorProto.DataType.Name( - vi.type.sequence_type.elem_type.tensor_type.elem_type))) + if out_type_kind == "sequence_type": + seq_cls_type = out_type.sequence_type.elem_type.WhichOneof("value") + if "tensor_type" == seq_cls_type: + logger.debug( + " {}: sequence of {} {}".format( + node.output[i_o], + str(get_shape_from_value_info(vi)), + onnx.TensorProto.DataType.Name( + vi.type.sequence_type.elem_type.tensor_type.elem_type + ), + ) + ) else: - logger.debug(' {}: sequence of {}'.format(node.output[i_o], seq_cls_type)) + logger.debug(" {}: sequence of {}".format(node.output[i_o], seq_cls_type)) else: - logger.debug(' {}: {}'.format(node.output[i_o], out_type_kind)) + logger.debug(" {}: {}".format(node.output[i_o], out_type_kind)) continue out_shape = get_shape_from_value_info(vi) out_type_undefined = out_type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED if self.verbose_ > 2: - logger.debug(' {}: {} {}'.format(node.output[i_o], str(out_shape), - onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type))) + logger.debug( + " {}: {} {}".format( + node.output[i_o], + str(out_shape), + onnx.TensorProto.DataType.Name(vi.type.tensor_type.elem_type), + ) + ) if node.output[i_o] in self.sympy_data_: - logger.debug(' Sympy Data: ' + str(self.sympy_data_[node.output[i_o]])) + logger.debug(" Sympy Data: " + str(self.sympy_data_[node.output[i_o]])) # onnx >= 1.11.0, use unk__#index instead of None when the shape dim is uncertain - if (out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape))) or out_type_undefined: + if ( + out_shape is not None and (None in out_shape or self._is_shape_contains_none_dim(out_shape)) + ) or out_type_undefined: if self.auto_merge_: if node.op_type in [ - 'Add', 'Sub', 'Mul', 'Div', 'MatMul', 'MatMulInteger', 'MatMulInteger16', 'Concat', - 'Where', 'Sum', 'Equal', 'Less', 'Greater', 'LessOrEqual', 'GreaterOrEqual' + "Add", + "Sub", + "Mul", + "Div", + "MatMul", + "MatMulInteger", + "MatMulInteger16", + "Concat", + "Where", + "Sum", + "Equal", + "Less", + "Greater", + "LessOrEqual", + "GreaterOrEqual", ]: shapes = [self._get_shape(node, i) for i in range(len(node.input))] - if node.op_type in ['MatMul', 'MatMulInteger', 'MatMulInteger16']: + if node.op_type in [ + "MatMul", + "MatMulInteger", + "MatMulInteger16", + ]: if None in out_shape or self._is_shape_contains_none_dim(out_shape): if None in out_shape: idx = out_shape.index(None) @@ -1989,9 +2211,12 @@ def get_prereq(node): # only support auto merge for MatMul for dim < rank-2 when rank > 2 assert len(shapes[0]) > 2 and dim_idx[0] < len(shapes[0]) - 2 assert len(shapes[1]) > 2 and dim_idx[1] < len(shapes[1]) - 2 - elif node.op_type == 'Expand': + elif node.op_type == "Expand": # auto merge for cases like Expand([min(batch, 1), min(seq, 512)], [batch, seq]) - shapes = [self._get_shape(node, 0), self._get_value(node, 1)] + shapes = [ + self._get_shape(node, 0), + self._get_value(node, 1), + ] else: shapes = [] @@ -2003,10 +2228,13 @@ def get_prereq(node): # if a tensor has a lower rank (dim_idx[idx] < 0), it would automatically broadcast and need no merge dim_idx = [len(s) - len(out_shape) + idx for s in shapes] if len(dim_idx) > 0: - self._add_suggested_merge([ - s[i] if is_literal(s[i]) else str(s[i]) for s, i in zip(shapes, dim_idx) - if i >= 0 - ]) + self._add_suggested_merge( + [ + s[i] if is_literal(s[i]) else str(s[i]) + for s, i in zip(shapes, dim_idx) + if i >= 0 + ] + ) self.run_ = True else: self.run_ = False @@ -2033,30 +2261,42 @@ def get_prereq(node): # otherwise, use original data type out_dtype = vi.type.tensor_type.elem_type vi.CopyFrom( - helper.make_tensor_value_info(vi.name, out_dtype, - get_shape_from_sympy_shape(new_shape))) + helper.make_tensor_value_info( + vi.name, + out_dtype, + get_shape_from_sympy_shape(new_shape), + ) + ) if self.verbose_ > 0: if is_unknown_op: - logger.debug("Possible unknown op: {} node: {}, guessing {} shape".format( - node.op_type, node.name, vi.name)) + logger.debug( + "Possible unknown op: {} node: {}, guessing {} shape".format( + node.op_type, node.name, vi.name + ) + ) if self.verbose_ > 2: - logger.debug(' {}: {} {}'.format(node.output[i_o], str(new_shape), - vi.type.tensor_type.elem_type)) + logger.debug( + " {}: {} {}".format( + node.output[i_o], + str(new_shape), + vi.type.tensor_type.elem_type, + ) + ) self.run_ = True continue # continue the inference after guess, no need to stop as no merge is needed if self.verbose_ > 0 or not self.auto_merge_ or out_type_undefined: - logger.debug('Stopping at incomplete shape inference at ' + node.op_type + ': ' + node.name) - logger.debug('node inputs:') + logger.debug("Stopping at incomplete shape inference at " + node.op_type + ": " + node.name) + logger.debug("node inputs:") for i in node.input: logger.debug(self.known_vi_[i]) - logger.debug('node outputs:') + logger.debug("node outputs:") for o in node.output: logger.debug(self.known_vi_[o]) if self.auto_merge_ and not out_type_undefined: - logger.debug('Merging: ' + str(self.suggested_merge_)) + logger.debug("Merging: " + str(self.suggested_merge_)) return False self.run_ = False @@ -2071,7 +2311,7 @@ def _update_output_from_vi(self): def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=False, verbose=0): onnx_opset = get_opset(in_mp) if (not onnx_opset) or onnx_opset < 7: - logger.warning('Only support models of onnx opset 7 and above.') + logger.warning("Only support models of onnx opset 7 and above.") return None symbolic_shape_inference = SymbolicShapeInference(int_max, auto_merge, guess_output_rank, verbose) all_shapes_inferred = False @@ -2086,35 +2326,48 @@ def infer_shapes(in_mp, int_max=2**31 - 1, auto_merge=False, guess_output_rank=F def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, help='The input model file') - parser.add_argument('--output', help='The output model file') - parser.add_argument('--auto_merge', - help='Automatically merge symbolic dims when confliction happens', - action='store_true', - default=False) - parser.add_argument('--int_max', - help='maximum value for integer to be treated as boundless for ops like slice', - type=int, - default=2**31 - 1) - parser.add_argument('--guess_output_rank', - help='guess output rank to be the same as input 0 for unknown ops', - action='store_true', - default=False) - parser.add_argument('--verbose', - help='Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed', - type=int, - default=0) + parser.add_argument("--input", required=True, help="The input model file") + parser.add_argument("--output", help="The output model file") + parser.add_argument( + "--auto_merge", + help="Automatically merge symbolic dims when confliction happens", + action="store_true", + default=False, + ) + parser.add_argument( + "--int_max", + help="maximum value for integer to be treated as boundless for ops like slice", + type=int, + default=2**31 - 1, + ) + parser.add_argument( + "--guess_output_rank", + help="guess output rank to be the same as input 0 for unknown ops", + action="store_true", + default=False, + ) + parser.add_argument( + "--verbose", + help="Prints detailed logs of inference, 0: turn off, 1: warnings, 3: detailed", + type=int, + default=0, + ) return parser.parse_args() -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() - logger.info('input model: ' + args.input) + logger.info("input model: " + args.input) if args.output: - logger.info('output model ' + args.output) - logger.info('Doing symbolic shape inference...') - out_mp = SymbolicShapeInference.infer_shapes(onnx.load(args.input), args.int_max, args.auto_merge, - args.guess_output_rank, args.verbose) + logger.info("output model " + args.output) + logger.info("Doing symbolic shape inference...") + out_mp = SymbolicShapeInference.infer_shapes( + onnx.load(args.input), + args.int_max, + args.auto_merge, + args.guess_output_rank, + args.verbose, + ) if args.output and out_mp: onnx.save(out_mp, args.output) - logger.info('Done!') + logger.info("Done!") diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark.py b/onnxruntime/python/tools/tensorrt/perf/benchmark.py index 5086f3d0ce16d..9a2572cdfe4e1 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark.py @@ -1,72 +1,78 @@ -import os -import csv -import timeit -from datetime import datetime -import numpy -import logging -import coloredlogs -import numpy as np import argparse import copy +import csv import json +import logging +import os +import pprint import re import sys -import onnxruntime -from onnx import numpy_helper -from perf_utils import * -import pprint import time +import timeit +from datetime import datetime + +import coloredlogs +import numpy +import numpy as np import pandas as pd from float16 import * +from onnx import numpy_helper +from perf_utils import * + +import onnxruntime debug = False -sys.path.append('.') -logger = logging.getLogger('') +sys.path.append(".") +logger = logging.getLogger("") ep_to_provider_list = { cpu: [cpu_ep], - acl: [acl_ep], + acl: [acl_ep], cuda: [cuda_ep], cuda_fp16: [cuda_ep], trt: [trt_ep, cuda_ep], - trt_fp16: [trt_ep, cuda_ep] + trt_fp16: [trt_ep, cuda_ep], } -# latency gain headers -trt_cuda_gain = 'TRT_CUDA_gain(%)' -trt_cuda_fp16_gain = 'TRT_CUDA_fp16_gain(%)' -trt_native_gain = 'TRT_Standalone_gain(%)' -trt_native_fp16_gain = 'TRT_Standalone_fp16_gain(%)' +# latency gain headers +trt_cuda_gain = "TRT_CUDA_gain(%)" +trt_cuda_fp16_gain = "TRT_CUDA_fp16_gain(%)" +trt_native_gain = "TRT_Standalone_gain(%)" +trt_native_fp16_gain = "TRT_Standalone_fp16_gain(%)" # metadata FAIL_MODEL_FILE = ".fail_model_map" LATENCY_FILE = ".latency_map" METRICS_FILE = ".metrics_map" SESSION_FILE = ".session_map" -MEMORY_FILE = './temp_memory.csv' +MEMORY_FILE = "./temp_memory.csv" + def split_and_sort_output(string_list): string_list = string_list.split("\n") string_list.sort() return string_list -def is_dynamic(model): + +def is_dynamic(model): inp = model.graph.input[0] - for dim in inp.type.tensor_type.shape.dim: - if not dim.HasField('dim_value'): + for dim in inp.type.tensor_type.shape.dim: + if not dim.HasField("dim_value"): return True - return False + return False + -def get_model_inputs(model): +def get_model_inputs(model): all_inputs = [node.name for node in model.graph.input] input_initializers = [node.name for node in model.graph.initializer] inputs = list(set(all_inputs) - set(input_initializers)) - return inputs + return inputs + def run_trt_standalone(trtexec, model_name, model_path, all_inputs_shape, fp16, track_memory): logger.info("running standalone trt") onnx_model_path = "--onnx=" + model_path - + # load inputs input_shape = [] loaded_inputs = [] @@ -79,28 +85,34 @@ def run_trt_standalone(trtexec, model_name, model_path, all_inputs_shape, fp16, for i in range(len(ort_inputs)): name = ort_inputs[i] - loaded_input = name + ':' + test_data_dir + '/' + str(i) + '.bin' + loaded_input = name + ":" + test_data_dir + "/" + str(i) + ".bin" logger.info(loaded_input) shape = [] for j in all_inputs_shape[i]: shape.append(str(j)) shape = "x".join(shape) - shape = name + ':' + shape + shape = name + ":" + shape input_shape.append(shape) loaded_inputs.append(loaded_input) - shapes_arg = '--optShapes=' + ','.join(input_shape) - inputs_arg = '--loadInputs=' + ','.join(loaded_inputs) + shapes_arg = "--optShapes=" + ",".join(input_shape) + inputs_arg = "--loadInputs=" + ",".join(loaded_inputs) result = {} - command = [trtexec, onnx_model_path, "--duration=50", "--percentile=90", "--workspace=4096"] + command = [ + trtexec, + onnx_model_path, + "--duration=50", + "--percentile=90", + "--workspace=4096", + ] command.extend([inputs_arg]) - + # add benchmarking flags if is_dynamic(model): command.extend([shapes_arg]) - if fp16: + if fp16: command.extend(["--fp16"]) - + # save engine engine_name = model_name + ".engine" save_command = command + ["--saveEngine=" + engine_name] @@ -114,41 +126,42 @@ def run_trt_standalone(trtexec, model_name, model_path, all_inputs_shape, fp16, mem_usage = None p = None success = False - if track_memory: - p = start_memory_tracking() - try: + if track_memory: + p = start_memory_tracking() + try: out = get_output(load_command) success = True mem_usage = end_memory_tracking(p, success) - except Exception as e: + except Exception as e: end_memory_tracking(p, success) - raise(e) - else: + raise (e) + else: out = get_output(load_command) - + # parse trtexec output tmp = out.split("\n") target_list = [] for t in tmp: - if 'mean' in t: + if "mean" in t: target_list.append(t) - if 'percentile' in t: + if "percentile" in t: target_list.append(t) target = target_list[2] - avg_latency_match = re.search('mean = (.*?) ms', target) + avg_latency_match = re.search("mean = (.*?) ms", target) if avg_latency_match: - result["average_latency_ms"] = avg_latency_match.group(1) # extract number - percentile_match = re.search('percentile\(90%\) = (.*?) ms', target) + result["average_latency_ms"] = avg_latency_match.group(1) # extract number + percentile_match = re.search("percentile\(90%\) = (.*?) ms", target) if percentile_match: - result["latency_90_percentile"] = percentile_match.group(1) # extract number - if mem_usage: + result["latency_90_percentile"] = percentile_match.group(1) # extract number + if mem_usage: result["memory"] = mem_usage - + logger.info(result) return result + def get_latency_result(runtimes, batch_size): latency_ms = sum(runtimes) / float(len(runtimes)) * 1000.0 latency_variance = numpy.var(runtimes, dtype=numpy.float64) * 1000.0 @@ -165,35 +178,38 @@ def get_latency_result(runtimes, batch_size): } return result + def get_ort_session_inputs_and_outputs(name, session, ort_input): sess_inputs = {} sess_outputs = None - if 'bert_squad' in name.lower() or 'bert-squad' in name.lower(): + if "bert_squad" in name.lower() or "bert-squad" in name.lower(): unique_ids_raw_output = ort_input[0] input_ids = ort_input[1] input_mask = ort_input[2] segment_ids = ort_input[3] sess_inputs = { - "unique_ids_raw_output___9:0": unique_ids_raw_output, - "input_ids:0": input_ids[0:1], - "input_mask:0": input_mask[0:1], - "segment_ids:0": segment_ids[0:1]} + "unique_ids_raw_output___9:0": unique_ids_raw_output, + "input_ids:0": input_ids[0:1], + "input_mask:0": input_mask[0:1], + "segment_ids:0": segment_ids[0:1], + } sess_outputs = ["unique_ids:0", "unstack:0", "unstack:1"] - elif 'bidaf' in name.lower(): + elif "bidaf" in name.lower(): sess_inputs = { - "context_word": ort_input[0], - "context_char": ort_input[2], - "query_word": ort_input[1], - "query_char": ort_input[3]} - sess_outputs = ["start_pos","end_pos"] - - elif 'yolov4' in name.lower(): + "context_word": ort_input[0], + "context_char": ort_input[2], + "query_word": ort_input[1], + "query_char": ort_input[3], + } + sess_outputs = ["start_pos", "end_pos"] + + elif "yolov4" in name.lower(): sess_inputs[session.get_inputs()[0].name] = ort_input[0] - sess_outputs = ['Identity:0'] + sess_outputs = ["Identity:0"] else: sess_inputs = {} @@ -204,54 +220,89 @@ def get_ort_session_inputs_and_outputs(name, session, ort_input): sess_outputs.append(session.get_outputs()[i].name) return (sess_inputs, sess_outputs) -def track_ep_memory(ep): - return cpu != ep -def get_trtexec_pid(df, python_pid): - for pid in df['pid'].tolist(): - if pid != python_pid: +def track_ep_memory(ep): + return cpu != ep + + +def get_trtexec_pid(df, python_pid): + for pid in df["pid"].tolist(): + if pid != python_pid: return pid -def get_max_memory(): + +def get_max_memory(): df = pd.read_csv(MEMORY_FILE) - pid = df['pid'].iloc[0] - mem_series = df.loc[df['pid'] == pid, ' used_gpu_memory [MiB]'] - max_mem = max(mem_series.str.replace(' MiB','').astype(int)) + pid = df["pid"].iloc[0] + mem_series = df.loc[df["pid"] == pid, " used_gpu_memory [MiB]"] + max_mem = max(mem_series.str.replace(" MiB", "").astype(int)) return max_mem -def start_memory_tracking(): + +def start_memory_tracking(): logger.info("starting memory tracking process") - p = subprocess.Popen(["nvidia-smi", "--query-compute-apps=pid,used_memory", "--format=csv", "-l", "1", "-f", MEMORY_FILE]) + p = subprocess.Popen( + [ + "nvidia-smi", + "--query-compute-apps=pid,used_memory", + "--format=csv", + "-l", + "1", + "-f", + MEMORY_FILE, + ] + ) return p -def end_memory_tracking(p, success): + +def end_memory_tracking(p, success): logger.info("terminating memory tracking process") p.terminate() p.wait() p.kill() mem_usage = None if success: - mem_usage = get_max_memory() + mem_usage = get_max_memory() if os.path.exists(MEMORY_FILE): os.remove(MEMORY_FILE) return mem_usage + def inference_ort_with_ep(ep, session, repeat_times, sess_outputs, sess_inputs, with_binding, io_binding): - if with_binding and cpu not in ep: # other eps utilize python binding - runtime = timeit.repeat(lambda: session.run_with_iobinding(io_binding), number=1, repeat=repeat_times) + if with_binding and cpu not in ep: # other eps utilize python binding + runtime = timeit.repeat( + lambda: session.run_with_iobinding(io_binding), + number=1, + repeat=repeat_times, + ) else: - runtime = timeit.repeat(lambda: session.run(sess_outputs, sess_inputs), number=1, repeat=repeat_times) + runtime = timeit.repeat( + lambda: session.run(sess_outputs, sess_inputs), + number=1, + repeat=repeat_times, + ) success = True return runtime, success -def inference_ort(args, name, session, ep, ort_inputs, result_template, repeat_times, batch_size, track_memory): + +def inference_ort( + args, + name, + session, + ep, + ort_inputs, + result_template, + repeat_times, + batch_size, + track_memory, +): runtimes = [] if args.input_data == "random": - repeat_times = 1 # warn-up run is included in ort_inputs + repeat_times = 1 # warn-up run is included in ort_inputs else: - repeat_times += 1 # add warn-up run - - mem_usages = [] + repeat_times += 1 # add warn-up run + + mem_usages = [] p = None mem_usage = None success = False @@ -268,34 +319,51 @@ def inference_ort(args, name, session, ep, ort_inputs, result_template, repeat_t if args.io_binding: for name, inp in sess_inputs.items(): io_binding.bind_cpu_input(name, inp) - for out in sess_outputs: + for out in sess_outputs: io_binding.bind_output(out) - + try: - if track_memory: - p = start_memory_tracking() - runtime, success = inference_ort_with_ep(ep, session, repeat_times, sess_outputs, sess_inputs, args.io_binding, io_binding) + if track_memory: + p = start_memory_tracking() + runtime, success = inference_ort_with_ep( + ep, + session, + repeat_times, + sess_outputs, + sess_inputs, + args.io_binding, + io_binding, + ) mem_usage = end_memory_tracking(p, success) - mem_usages.append(mem_usage) - else: - runtime, success = inference_ort_with_ep(ep, session, repeat_times, sess_outputs, sess_inputs, args.io_binding, io_binding) - runtimes += runtime[1:] # remove warmup - + mem_usages.append(mem_usage) + else: + runtime, success = inference_ort_with_ep( + ep, + session, + repeat_times, + sess_outputs, + sess_inputs, + args.io_binding, + io_binding, + ) + runtimes += runtime[1:] # remove warmup + except Exception as e: logger.error(e) if track_memory: end_memory_tracking(p, success) - raise(e) + raise (e) - if len(mem_usages) > 0: + if len(mem_usages) > 0: mem_usage = max(mem_usages) - + result = {} result.update(result_template) result.update({"io_binding": True}) latency_result = get_latency_result(runtimes, batch_size) result.update(latency_result) - return result, mem_usage + return result, mem_usage + def inference_ort_and_get_prediction(name, session, ort_inputs): @@ -309,35 +377,38 @@ def inference_ort_and_get_prediction(name, session, ort_inputs): logger.info(sess_outputs) result = session.run(sess_outputs, sess_inputs) - + if debug: logger.info("ORT session output results:") logger.info(result) # handle shape of output differently - if 'bert_squad' in name.lower(): + if "bert_squad" in name.lower(): ort_outputs.append([result]) - elif 'shufflenet-v2' in name.lower() or 'shufflenet_v2' in name.lower(): + elif "shufflenet-v2" in name.lower() or "shufflenet_v2" in name.lower(): ort_outputs.append(result[0]) else: ort_outputs.append(result) return ort_outputs + def get_acl_version(): from pathlib import Path + home = str(Path.home()) p = subprocess.run(["find", home, "-name", "libarm_compute.so"], check=True, stdout=subprocess.PIPE) libarm_compute_path = p.stdout.decode("ascii").strip() - if libarm_compute_path == '': + if libarm_compute_path == "": return "No Compute Library Found" else: - p = subprocess.run(["strings", libarm_compute_path], check=True, stdout=subprocess.PIPE) + p = subprocess.run(["strings", libarm_compute_path], check=True, stdout=subprocess.PIPE) libarm_so_strings = p.stdout.decode("ascii").strip() - version_match = re.search(r'arm_compute_version.*\n', libarm_so_strings) - version = version_match.group(0).split(' ')[0] + version_match = re.search(r"arm_compute_version.*\n", libarm_so_strings) + version = version_match.group(0).split(" ")[0] return version + ####################################################################################################################################### # The following two lists will be generated. # @@ -372,7 +443,7 @@ def load_onnx_model_zoo_test_data(path, all_inputs_shape, fp16): i = 0 for data in input_data: tensor = onnx.TensorProto() - with open(data, 'rb') as f: + with open(data, "rb") as f: tensor.ParseFromString(f.read()) tensor_to_array = numpy_helper.to_array(tensor) if fp16 and tensor_to_array.dtype == np.dtype(np.float32): @@ -383,18 +454,18 @@ def load_onnx_model_zoo_test_data(path, all_inputs_shape, fp16): all_inputs_shape.append(input_data_pb[-1].shape) logger.info(all_inputs_shape[-1]) inputs.append(input_data_pb) - logger.info('Loaded {} inputs successfully.'.format(len(inputs))) + logger.info("Loaded {} inputs successfully.".format(len(inputs))) # load outputs output = get_output(["find", ".", "-name", "output*"]) output_data = split_and_sort_output(output) - if len(output_data) > 0 and output_data[0] != '': + if len(output_data) > 0 and output_data[0] != "": logger.info(output_data) output_data_pb = [] for data in output_data: tensor = onnx.TensorProto() - with open(data, 'rb') as f: + with open(data, "rb") as f: tensor.ParseFromString(f.read()) tensor_to_array = numpy_helper.to_array(tensor) @@ -405,11 +476,12 @@ def load_onnx_model_zoo_test_data(path, all_inputs_shape, fp16): logger.info(np.array(output_data_pb[-1]).shape) outputs.append(output_data_pb) - logger.info('Loaded {} outputs successfully.'.format(len(outputs))) + logger.info("Loaded {} outputs successfully.".format(len(outputs))) os.chdir(pwd) return inputs, outputs + def generate_onnx_model_random_input(test_times, ref_input): inputs = [] @@ -419,15 +491,17 @@ def generate_onnx_model_random_input(test_times, ref_input): for tensor in ref_input: shape = tensor.shape dtype = tensor.dtype - if dtype == np.int8 or \ - dtype == np.uint8 or \ - dtype == np.int16 or \ - dtype == np.uint16 or \ - dtype == np.int32 or \ - dtype == np.uint32 or \ - dtype == np.int64 or \ - dtype == np.uint64: - new_tensor = np.random.randint(0, np.max(tensor)+1, shape, dtype) + if ( + dtype == np.int8 + or dtype == np.uint8 + or dtype == np.int16 + or dtype == np.uint16 + or dtype == np.int32 + or dtype == np.uint32 + or dtype == np.int64 + or dtype == np.uint64 + ): + new_tensor = np.random.randint(0, np.max(tensor) + 1, shape, dtype) else: new_tensor = np.random.random_sample(shape).astype(dtype) @@ -443,22 +517,24 @@ def generate_onnx_model_random_input(test_times, ref_input): return inputs + def percentage_in_allowed_threshold(e, percent_mismatch): - percent_string = re.search(r'\(([^)]+)', str(e)).group(1) + percent_string = re.search(r"\(([^)]+)", str(e)).group(1) if "%" in percent_string: - percentage_wrong = float(percent_string.replace("%","")) + percentage_wrong = float(percent_string.replace("%", "")) return percentage_wrong < percent_mismatch - else: - return False # error in output + else: + return False # error in output + def validate(all_ref_outputs, all_outputs, rtol, atol, percent_mismatch): if len(all_ref_outputs) == 0: logger.info("No reference output provided.") return True, None - logger.info('Reference {} results.'.format(len(all_ref_outputs))) - logger.info('Predicted {} results.'.format(len(all_outputs))) - logger.info('rtol: {}, atol: {}'.format(rtol, atol)) + logger.info("Reference {} results.".format(len(all_ref_outputs))) + logger.info("Predicted {} results.".format(len(all_outputs))) + logger.info("rtol: {}, atol: {}".format(rtol, atol)) for i in range(len(all_outputs)): ref_outputs = all_ref_outputs[i] @@ -474,14 +550,15 @@ def validate(all_ref_outputs, all_outputs, rtol, atol, percent_mismatch): try: np.testing.assert_allclose(ref_o, o, rtol, atol) except Exception as e: - if percentage_in_allowed_threshold(e, percent_mismatch): + if percentage_in_allowed_threshold(e, percent_mismatch): continue logger.error(e) return False, e - logger.info('ONNX Runtime outputs are similar to reference outputs!') + logger.info("ONNX Runtime outputs are similar to reference outputs!") return True, None + # not use for this script def cleanup_files(): files = [] @@ -504,27 +581,30 @@ def cleanup_files(): if "custom_test_data" in f: logger.info(f) continue - subprocess.Popen(["rm","-rf", f], stdout=subprocess.PIPE) + subprocess.Popen(["rm", "-rf", f], stdout=subprocess.PIPE) + def remove_files(running_mode, path): files = [] - out = "" - if running_mode == "validate": + out = "" + if running_mode == "validate": out = get_output(["find", path, "-name", "onnxruntime_profile*"]) - if running_mode == "benchmark": + if running_mode == "benchmark": logger.info(running_mode) out = get_output(["find", path, "-name", "*.engine"]) + def update_fail_report(fail_results, model, ep, e_type, e): result = {} result["model"] = model result["ep"] = ep result["error type"] = e_type - result["error message"] = re.sub('^\n', '', str(e)) + result["error message"] = re.sub("^\n", "", str(e)) fail_results.append(result) + def update_metrics_map(model_to_metrics, model_name, ep_to_operator): if len(ep_to_operator) <= 0: return @@ -537,14 +617,20 @@ def update_metrics_map(model_to_metrics, model_name, ep_to_operator): model_to_metrics[model_name][ep] = {} if ep == cuda or ep == cuda_fp16: - model_to_metrics[model_name][ep]['ratio_of_ops_in_cuda_not_fallback_cpu'] = calculate_cuda_op_percentage(op_map) - model_to_metrics[model_name][ep]['total_ops'] = get_total_ops(op_map) + model_to_metrics[model_name][ep]["ratio_of_ops_in_cuda_not_fallback_cpu"] = calculate_cuda_op_percentage( + op_map + ) + model_to_metrics[model_name][ep]["total_ops"] = get_total_ops(op_map) else: - total_trt_execution_time, total_execution_time, ratio_of_execution_time_in_trt = calculate_trt_latency_percentage(op_map) - model_to_metrics[model_name][ep]['total_ops'] = get_total_ops(op_map) - model_to_metrics[model_name][ep]['total_trt_execution_time'] = total_trt_execution_time - model_to_metrics[model_name][ep]['total_execution_time'] = total_execution_time - model_to_metrics[model_name][ep]['ratio_of_execution_time_in_trt'] = ratio_of_execution_time_in_trt + ( + total_trt_execution_time, + total_execution_time, + ratio_of_execution_time_in_trt, + ) = calculate_trt_latency_percentage(op_map) + model_to_metrics[model_name][ep]["total_ops"] = get_total_ops(op_map) + model_to_metrics[model_name][ep]["total_trt_execution_time"] = total_trt_execution_time + model_to_metrics[model_name][ep]["total_execution_time"] = total_execution_time + model_to_metrics[model_name][ep]["ratio_of_execution_time_in_trt"] = ratio_of_execution_time_in_trt def update_metrics_map_ori(model_to_metrics, name, ep_to_operator): @@ -566,48 +652,64 @@ def update_metrics_map_ori(model_to_metrics, name, ep_to_operator): elif ep == trt_fp16: trt_fp16_op_map = op_map - if name not in model_to_metrics: model_to_metrics[name] = {} if cuda_op_map: - model_to_metrics[name]['ratio_of_ops_in_cuda_not_fallback_cpu'] = calculate_cuda_op_percentage(cuda_op_map) + model_to_metrics[name]["ratio_of_ops_in_cuda_not_fallback_cpu"] = calculate_cuda_op_percentage(cuda_op_map) if trt_op_map: - total_trt_execution_time, total_execution_time, ratio_of_execution_time_in_trt = calculate_trt_latency_percentage(trt_op_map) - model_to_metrics[name]['total_trt_execution_time'] = total_trt_execution_time - model_to_metrics[name]['total_execution_time'] = total_execution_time - model_to_metrics[name]['ratio_of_execution_time_in_trt'] = ratio_of_execution_time_in_trt + ( + total_trt_execution_time, + total_execution_time, + ratio_of_execution_time_in_trt, + ) = calculate_trt_latency_percentage(trt_op_map) + model_to_metrics[name]["total_trt_execution_time"] = total_trt_execution_time + model_to_metrics[name]["total_execution_time"] = total_execution_time + model_to_metrics[name]["ratio_of_execution_time_in_trt"] = ratio_of_execution_time_in_trt if cuda_op_map: - total_ops_in_trt, total_ops, ratio_of_ops_in_trt = calculate_trt_op_percentage(trt_op_map, cuda_op_map) - model_to_metrics[name]['total_ops_in_trt'] = total_ops_in_trt - model_to_metrics[name]['total_ops'] = total_ops - model_to_metrics[name]['ratio_of_ops_in_trt'] = ratio_of_ops_in_trt + ( + total_ops_in_trt, + total_ops, + ratio_of_ops_in_trt, + ) = calculate_trt_op_percentage(trt_op_map, cuda_op_map) + model_to_metrics[name]["total_ops_in_trt"] = total_ops_in_trt + model_to_metrics[name]["total_ops"] = total_ops + model_to_metrics[name]["ratio_of_ops_in_trt"] = ratio_of_ops_in_trt if trt_fp16_op_map: - total_trt_execution_time, total_execution_time, ratio_of_execution_time_in_trt = calculate_trt_latency_percentage(trt_fp16_op_map) + ( + total_trt_execution_time, + total_execution_time, + ratio_of_execution_time_in_trt, + ) = calculate_trt_latency_percentage(trt_fp16_op_map) name_ = name + " (FP16)" model_to_metrics[name_] = {} - model_to_metrics[name_]['total_trt_execution_time'] = total_trt_execution_time - model_to_metrics[name_]['total_execution_time'] = total_execution_time - model_to_metrics[name_]['ratio_of_execution_time_in_trt'] = ratio_of_execution_time_in_trt + model_to_metrics[name_]["total_trt_execution_time"] = total_trt_execution_time + model_to_metrics[name_]["total_execution_time"] = total_execution_time + model_to_metrics[name_]["ratio_of_execution_time_in_trt"] = ratio_of_execution_time_in_trt if cuda_fp16_op_map: - total_ops_in_trt, total_ops, ratio_of_ops_in_trt = calculate_trt_op_percentage(trt_fp16_op_map, cuda_op_map) - model_to_metrics[name_]['total_ops_in_trt'] = total_ops_in_trt - model_to_metrics[name_]['total_ops'] = total_ops - model_to_metrics[name_]['ratio_of_ops_in_trt'] = ratio_of_ops_in_trt + ( + total_ops_in_trt, + total_ops, + ratio_of_ops_in_trt, + ) = calculate_trt_op_percentage(trt_fp16_op_map, cuda_op_map) + model_to_metrics[name_]["total_ops_in_trt"] = total_ops_in_trt + model_to_metrics[name_]["total_ops"] = total_ops + model_to_metrics[name_]["ratio_of_ops_in_trt"] = ratio_of_ops_in_trt if debug: pp = pprint.PrettyPrinter(indent=4) - logger.info('CUDA operator map:') + logger.info("CUDA operator map:") pp.pprint(cuda_op_map) - logger.info('TRT operator map:') + logger.info("TRT operator map:") pp.pprint(trt_op_map) - logger.info('CUDA FP16 operator map:') + logger.info("CUDA FP16 operator map:") pp.pprint(cuda_fp16_op_map) - logger.info('TRT FP16 operator map:') + logger.info("TRT FP16 operator map:") pp.pprint(trt_fp16_op_map) + ################################################################################################### # # model: {ep1: {error_type: xxx, error_message: xxx}, ep2: {error_type: xx, error_message: xx}} @@ -619,11 +721,11 @@ def update_fail_model_map(model_to_fail_ep, model_name, ep, e_type, e): return if model_name not in model_to_fail_ep: - model_to_fail_ep[model_name] = {} + model_to_fail_ep[model_name] = {} new_map = {} new_map["error_type"] = e_type - new_map["error_message"] = re.sub('^\n', '', str(e)) + new_map["error_message"] = re.sub("^\n", "", str(e)) model_to_fail_ep[model_name][ep] = new_map # If TRT fails, TRT FP16 should fail as well @@ -633,7 +735,8 @@ def update_fail_model_map(model_to_fail_ep, model_name, ep, e_type, e): new_map_1 = {} new_map_1["error_type"] = e_type new_map_1["error_message"] = e_ - model_to_fail_ep[model_name][ep_] = new_map_1 + model_to_fail_ep[model_name][ep_] = new_map_1 + def update_fail_model_map_ori(model_to_fail_ep, fail_results, model_name, ep, e_type, e): @@ -641,8 +744,8 @@ def update_fail_model_map_ori(model_to_fail_ep, fail_results, model_name, ep, e_ return if model_name not in model_to_fail_ep: - model_to_fail_ep[model_name] = {} - + model_to_fail_ep[model_name] = {} + model_to_fail_ep[model_name][ep] = e_type update_fail_report(fail_results, model_name, ep, e_type, e) @@ -653,6 +756,7 @@ def update_fail_model_map_ori(model_to_fail_ep, fail_results, model_name, ep, e_ update_fail_report(fail_results, model_name, ep_, e_type, error_message_) model_to_fail_ep[model_name][ep_] = e_type + def skip_ep(model_name, ep, model_to_fail_ep): if model_name not in model_to_fail_ep: @@ -667,6 +771,7 @@ def skip_ep(model_name, ep, model_to_fail_ep): return False + def read_map_from_file(map_file): with open(map_file) as f: try: @@ -676,42 +781,46 @@ def read_map_from_file(map_file): return data + def write_map_to_file(result, file_name): existed_result = {} if os.path.exists(file_name): existed_result = read_map_from_file(file_name) - + for model, ep_list in result.items(): if model in existed_result: - existed_result[model] = {** existed_result[model], ** result[model]} + existed_result[model] = {**existed_result[model], **result[model]} else: existed_result[model] = result[model] - with open(file_name, 'w') as file: - file.write(json.dumps(existed_result)) # use `json.loads` to do the reverse + with open(file_name, "w") as file: + file.write(json.dumps(existed_result)) # use `json.loads` to do the reverse def get_cuda_version(): - nvidia_strings = get_output(["nvidia-smi"]) - version = re.search(r'CUDA Version: \d\d\.\d', nvidia_strings).group(0) + nvidia_strings = get_output(["nvidia-smi"]) + version = re.search(r"CUDA Version: \d\d\.\d", nvidia_strings).group(0) return version - + + def get_trt_version(workspace): libnvinfer = get_output(["find", workspace, "-name", "libnvinfer.so.*"]) - nvinfer = re.search(r'.*libnvinfer.so.*', libnvinfer).group(0) + nvinfer = re.search(r".*libnvinfer.so.*", libnvinfer).group(0) trt_strings = get_output(["nm", "-D", nvinfer]) - version = re.search(r'tensorrt_version.*', trt_strings).group(0) + version = re.search(r"tensorrt_version.*", trt_strings).group(0) return version - -def get_linux_distro(): + + +def get_linux_distro(): linux_strings = get_output(["cat", "/etc/os-release"]) stdout = linux_strings.split("\n")[:2] infos = [] for row in stdout: - row = re.sub('=', ': ', row) - row = re.sub('"', '', row) + row = re.sub("=", ": ", row) + row = re.sub('"', "", row) infos.append(row) - return infos + return infos + def get_memory_info(): mem_strings = get_output(["cat", "/proc/meminfo"]) @@ -719,35 +828,39 @@ def get_memory_info(): infos = [] for row in stdout: if "Mem" in row: - row = re.sub(': +', ': ', row) + row = re.sub(": +", ": ", row) infos.append(row) return infos -def get_cpu_info(): + +def get_cpu_info(): cpu_strings = get_output(["lscpu"]) stdout = cpu_strings.split("\n") infos = [] for row in stdout: if "mode" in row or "Arch" in row or "name" in row: - row = re.sub(': +', ': ', row) + row = re.sub(": +", ": ", row) infos.append(row) return infos + def get_gpu_info(): info = get_output(["lspci", "-v"]) - infos = re.findall('NVIDIA.*', info) + infos = re.findall("NVIDIA.*", info) return infos + def get_cudnn_version(workspace): cudnn_path = get_output(["whereis", "cudnn_version.h"]) - cudnn_path = re.search(': (.*)', cudnn_path).group(1) + cudnn_path = re.search(": (.*)", cudnn_path).group(1) cudnn_outputs = get_output(["cat", cudnn_path]) - major = re.search('CUDNN_MAJOR (.*)', cudnn_outputs).group(1) - minor = re.search('CUDNN_MINOR (.*)', cudnn_outputs).group(1) - patch = re.search('CUDNN_PATCHLEVEL (.*)', cudnn_outputs).group(1) - cudnn_version = major + '.' + minor + '.' + patch + major = re.search("CUDNN_MAJOR (.*)", cudnn_outputs).group(1) + minor = re.search("CUDNN_MINOR (.*)", cudnn_outputs).group(1) + patch = re.search("CUDNN_PATCHLEVEL (.*)", cudnn_outputs).group(1) + cudnn_version = major + "." + minor + "." + patch return cudnn_version + def get_system_info(args): info = {} info["cuda"] = get_cuda_version() @@ -757,21 +870,25 @@ def get_system_info(args): info["cpu_info"] = get_cpu_info() info["gpu_info"] = get_gpu_info() info["memory"] = get_memory_info() - info["ep_option_overrides"] = {trt_ep: args.trt_ep_options, cuda_ep: args.cuda_ep_options} + info["ep_option_overrides"] = { + trt_ep: args.trt_ep_options, + cuda_ep: args.cuda_ep_options, + } return info + def find_model_path(path): output = get_output(["find", "-L", path, "-name", "*.onnx"]) model_path = split_and_sort_output(output) logger.info(model_path) - if model_path == ['']: + if model_path == [""]: return None target_model_path = [] for m in model_path: - if "by_trt_perf" in m or m.startswith('.'): + if "by_trt_perf" in m or m.startswith("."): continue target_model_path.append(m) @@ -782,121 +899,141 @@ def find_model_path(path): return target_model_path[0] + def find_model_directory(path): - output = get_output(["find", "-L", path, "-maxdepth", "1", "-mindepth", "1", "-name", "*", "-type", "d"]) + output = get_output( + [ + "find", + "-L", + path, + "-maxdepth", + "1", + "-mindepth", + "1", + "-name", + "*", + "-type", + "d", + ] + ) model_dir = split_and_sort_output(output) - if model_dir == ['']: + if model_dir == [""]: return None return model_dir + def find_test_data_directory(path): output = get_output(["find", "-L", path, "-maxdepth", "1", "-name", "test_data*", "-type", "d"]) test_data_dir = split_and_sort_output(output) logger.info(test_data_dir) - if test_data_dir == ['']: + if test_data_dir == [""]: return None return test_data_dir + def parse_models_info_from_directory(path, models): - test_data_dir = find_test_data_directory(path) + test_data_dir = find_test_data_directory(path) if test_data_dir: model_name = os.path.split(path)[-1] - model_name = model_name + '_' + os.path.split(os.path.split(path)[0])[-1] # get opset version as model_name + model_name = model_name + "_" + os.path.split(os.path.split(path)[0])[-1] # get opset version as model_name model_path = find_model_path(path) model = {} model["model_name"] = model_name - model["model_path"] = model_path - model["working_directory"] = path - model["test_data_path"] = path + model["model_path"] = model_path + model["working_directory"] = path + model["test_data_path"] = path - models[model_name] = model + models[model_name] = model logger.info(model) return - + model_dir = find_model_directory(path) - + if model_dir: for dir in model_dir: parse_models_info_from_directory(os.path.join(path, dir), models) - + def parse_models_info_from_file(root_dir, path, models): # default working directory - root_working_directory = root_dir + 'perf/' + root_working_directory = root_dir + "perf/" with open(path) as f: data = json.load(f) for row in data: - if 'root_working_directory' in row: - root_working_directory = row['root_working_directory'] + if "root_working_directory" in row: + root_working_directory = row["root_working_directory"] continue - if 'model_name' in row: - models[row['model_name']] = {} + if "model_name" in row: + models[row["model_name"]] = {} else: - logger.error('Model name must be provided in models_info.json') + logger.error("Model name must be provided in models_info.json") raise - model = models[row['model_name']] + model = models[row["model_name"]] - if 'working_directory' in row: - if os.path.isabs(row['working_directory']): - model['working_directory'] = row['working_directory'] + if "working_directory" in row: + if os.path.isabs(row["working_directory"]): + model["working_directory"] = row["working_directory"] else: - model['working_directory'] = os.path.join(root_working_directory, row['working_directory']) + model["working_directory"] = os.path.join(root_working_directory, row["working_directory"]) else: - logger.error('Model path must be provided in models_info.json') + logger.error("Model path must be provided in models_info.json") raise - if 'model_path' in row: - model['model_path'] = row['model_path'] + if "model_path" in row: + model["model_path"] = row["model_path"] else: - logger.error('Model path must be provided in models_info.json') + logger.error("Model path must be provided in models_info.json") raise - if 'test_data_path' in row: - model['test_data_path'] = row['test_data_path'] + if "test_data_path" in row: + model["test_data_path"] = row["test_data_path"] else: - logger.error('Test data path must be provided in models_info.json') + logger.error("Test data path must be provided in models_info.json") raise - if 'model_path_fp16' in row: - model['model_path_fp16'] = row['model_path_fp16'] + if "model_path_fp16" in row: + model["model_path_fp16"] = row["model_path_fp16"] - if 'test_data_path_fp16' in row: - model['test_data_path_fp16'] = row['test_data_path_fp16'] + if "test_data_path_fp16" in row: + model["test_data_path_fp16"] = row["test_data_path_fp16"] def convert_model_from_float_to_float16(model_path): - from onnxmltools.utils import load_model, save_model from float16 import convert_float_to_float16 + from onnxmltools.utils import load_model, save_model new_model_path = os.path.join(os.getcwd(), "new_fp16_model_by_trt_perf.onnx") if not os.path.exists(new_model_path): onnx_model = load_model(model_path) new_onnx_model = convert_float_to_float16(onnx_model) - save_model(new_onnx_model, 'new_fp16_model_by_trt_perf.onnx') + save_model(new_onnx_model, "new_fp16_model_by_trt_perf.onnx") return new_model_path + def get_test_data(fp16, test_data_dir, all_inputs_shape): inputs = [] ref_outputs = [] inputs, ref_outputs = load_onnx_model_zoo_test_data(test_data_dir, all_inputs_shape, fp16) return inputs, ref_outputs -def run_symbolic_shape_inference(model_path, new_model_path): + +def run_symbolic_shape_inference(model_path, new_model_path): import onnxruntime.tools.symbolic_shape_infer as symbolic_shape_infer + logger.info("run symbolic shape inference") try: out = symbolic_shape_infer.SymbolicShapeInference.infer_shapes(onnx.load(model_path), auto_merge=True) @@ -906,6 +1043,7 @@ def run_symbolic_shape_inference(model_path, new_model_path): logger.error(e) return False, "Symbolic shape inference error" + def get_provider_options(providers, trt_ep_options, cuda_ep_options): provider_options = [] @@ -919,14 +1057,20 @@ def get_provider_options(providers, trt_ep_options, cuda_ep_options): return provider_options + def time_and_create_session(model_path, providers, provider_options, session_options): start = datetime.now() - session = onnxruntime.InferenceSession(model_path, providers=providers, provider_options=provider_options, - sess_options=session_options) + session = onnxruntime.InferenceSession( + model_path, + providers=providers, + provider_options=provider_options, + sess_options=session_options, + ) end = datetime.now() creation_time = (end - start).total_seconds() return session, creation_time + def create_session(model_path, providers, provider_options, session_options): logger.info(model_path) @@ -939,24 +1083,25 @@ def create_session(model_path, providers, provider_options, session_options): new_model_path = model_path[:].replace(".onnx", "_new_by_trt_perf.onnx") if not os.path.exists(new_model_path): status = run_symbolic_shape_inference(model_path, new_model_path) - if not status[0]: # symbolic shape inference error + if not status[0]: # symbolic shape inference error e = status[1] - raise Exception(e) + raise Exception(e) return time_and_create_session(new_model_path, providers, provider_options, session_options) else: - raise Exception(e) + raise Exception(e) + def run_onnxruntime(args, models): success_results = [] - model_to_latency = {} # model -> cuda and tensorrt latency - model_to_metrics = {} # model -> metrics from profiling file - model_to_fail_ep = {} # model -> failing ep - model_to_session = {} # models -> session creation time - - if args.running_mode == "benchmark": + model_to_latency = {} # model -> cuda and tensorrt latency + model_to_metrics = {} # model -> metrics from profiling file + model_to_fail_ep = {} # model -> failing ep + model_to_session = {} # models -> session creation time + + if args.running_mode == "benchmark": model_to_session = read_map_from_file(SESSION_FILE) - + ep_list = [] if args.ep: ep_list.append(args.ep) @@ -968,7 +1113,6 @@ def run_onnxruntime(args, models): validation_exemption = [trt_fp16] - if os.path.exists(FAIL_MODEL_FILE): model_to_fail_ep = read_map_from_file(FAIL_MODEL_FILE) @@ -984,24 +1128,23 @@ def run_onnxruntime(args, models): os.mkdir(path) os.chdir(path) path = os.getcwd() - + inputs = [] ref_outputs = [] - all_inputs_shape = [] # use for standalone trt - ep_to_operator = {} # ep -> { operator -> count } + all_inputs_shape = [] # use for standalone trt + ep_to_operator = {} # ep -> { operator -> count } profile_already_parsed = set() - ####################### # iterate ep ####################### for ep in ep_list: if skip_ep(name, ep, model_to_fail_ep): continue - + if not is_standalone(ep): ep_ = ep_to_provider_list[ep][0] - if (ep_ not in onnxruntime.get_available_providers()): + if ep_ not in onnxruntime.get_available_providers(): logger.error("No {} support".format(ep_)) continue @@ -1010,16 +1153,16 @@ def run_onnxruntime(args, models): logger.info("[Initialize] model = {}, ep = {} ...".format(name, ep)) - # Set environment variables for ort-trt benchmarking + # Set environment variables for ort-trt benchmarking trt_ep_options = copy.deepcopy(args.trt_ep_options) if "ORT-TRT" in ep: trt_ep_options["trt_fp16_enable"] = "True" if "Fp16" in ep else "False" fp16 = False - - # use float16.py for cuda fp16 only - if cuda_fp16 == ep: - + + # use float16.py for cuda fp16 only + if cuda_fp16 == ep: + # handle model if "model_path_fp16" in model_info: model_path = model_info["model_path_fp16"] @@ -1030,71 +1173,78 @@ def run_onnxruntime(args, models): fp16 = True except Exception as e: logger.error(e) - update_fail_model_map(model_to_fail_ep, name, ep, 'script error', e) + update_fail_model_map(model_to_fail_ep, name, ep, "script error", e) continue # handle test data if "test_data_path_fp16" in model_info: test_data_dir = model_info["test_data_path_fp16"] - fp16 = False - - if standalone_trt_fp16 == ep: + fp16 = False + + if standalone_trt_fp16 == ep: fp16 = True - + inputs, ref_outputs = get_test_data(fp16, test_data_dir, all_inputs_shape) # generate random input data if args.input_data == "random": - inputs = generate_onnx_model_random_input(args.test_times+1, inputs[0]) + inputs = generate_onnx_model_random_input(args.test_times + 1, inputs[0]) ####################################### # benchmark or validation ####################################### - if args.running_mode == 'benchmark': + if args.running_mode == "benchmark": logger.info("\n----------------------------- benchmark -------------------------------------") - - # memory tracking variables - p = None + + # memory tracking variables + p = None mem_usage = None result = None # get standalone TensorRT perf - if is_standalone(ep) and args.trtexec: - try: - result = run_trt_standalone(args.trtexec, name, model_path, all_inputs_shape, fp16, args.track_memory) - except Exception as e: + if is_standalone(ep) and args.trtexec: + try: + result = run_trt_standalone( + args.trtexec, + name, + model_path, + all_inputs_shape, + fp16, + args.track_memory, + ) + except Exception as e: logger.error(e) - update_fail_model_map(model_to_fail_ep, name, ep, 'runtime error', e) + update_fail_model_map(model_to_fail_ep, name, ep, "runtime error", e) continue # inference with onnxruntime ep - else: + else: # resolve providers to create session - providers = ep_to_provider_list[ep] + providers = ep_to_provider_list[ep] provider_options = get_provider_options(providers, trt_ep_options, args.cuda_ep_options) options = onnxruntime.SessionOptions() - + enablement = args.graph_enablement if enablement == enable_all: options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - elif enablement == extended: + elif enablement == extended: options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - elif enablement == basic: + elif enablement == basic: options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC - else: # disable + else: # disable options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - + # create onnxruntime inference session try: sess, second_creation_time = create_session(model_path, providers, provider_options, options) except Exception as e: logger.error(e) - update_fail_model_map(model_to_fail_ep, name, ep, 'runtime error', e) + update_fail_model_map(model_to_fail_ep, name, ep, "runtime error", e) continue - if second_creation_time: + if second_creation_time: model_to_session[name] = copy.deepcopy({ep + second: second_creation_time}) - + logger.info("start to inference {} with {} ...".format(name, ep)) logger.info(sess.get_providers()) logger.info(sess.get_provider_options()) @@ -1120,48 +1270,58 @@ def run_onnxruntime(args, models): "inputs": len(sess.get_inputs()), "batch_size": batch_size, "sequence_length": 1, - "datetime": str(datetime.now()),} - + "datetime": str(datetime.now()), + } + # run cpu fewer times - repeat_times = 100 if ep == cpu else args.test_times + repeat_times = 100 if ep == cpu else args.test_times track_memory = False if ep == cpu else args.track_memory - + # inference with ort - try: - result, mem_usage = inference_ort(args, name, sess, ep, inputs, result_template, repeat_times, batch_size, track_memory) + try: + result, mem_usage = inference_ort( + args, + name, + sess, + ep, + inputs, + result_template, + repeat_times, + batch_size, + track_memory, + ) except Exception as e: logger.error(e) - update_fail_model_map(model_to_fail_ep, name, ep, 'runtime error', e) + update_fail_model_map(model_to_fail_ep, name, ep, "runtime error", e) continue - + if result: - + latency_result[ep] = {} latency_result[ep]["average_latency_ms"] = result["average_latency_ms"] latency_result[ep]["latency_90_percentile"] = result["latency_90_percentile"] - if "memory" in result: + if "memory" in result: mem_usage = result["memory"] - if mem_usage: + if mem_usage: latency_result[ep]["memory"] = mem_usage - if not args.trtexec: # skip standalone + if not args.trtexec: # skip standalone success_results.append(result) model_to_latency[name] = copy.deepcopy(latency_result) - - if ep == trt_fp16: # delete engine + + if ep == trt_fp16: # delete engine remove_files(args.running_mode, model_info["working_directory"]) logger.info("---------------------------- benchmark [end] ----------------------------------\n") - - elif args.running_mode == 'validate': + elif args.running_mode == "validate": logger.info("\n----------------------------- validate -------------------------------------") # enable profiling to generate profiling file for analysis options = onnxruntime.SessionOptions() options.enable_profiling = True options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL - time.sleep(1) # avoid to generate same profile file name + time.sleep(1) # avoid to generate same profile file name providers = ep_to_provider_list[ep] provider_options = get_provider_options(providers, trt_ep_options, args.cuda_ep_options) @@ -1172,10 +1332,10 @@ def run_onnxruntime(args, models): except Exception as e: logger.error(e) - update_fail_model_map(model_to_fail_ep, name, ep, 'runtime error', e) + update_fail_model_map(model_to_fail_ep, name, ep, "runtime error", e) continue - if creation_time: + if creation_time: model_to_session[name] = copy.deepcopy({ep: creation_time}) sess.disable_fallback() @@ -1199,14 +1359,26 @@ def run_onnxruntime(args, models): try: ort_outputs = inference_ort_and_get_prediction(name, sess, inputs) - status = validate(ref_outputs, ort_outputs, args.rtol, args.atol, args.percent_mismatch) + status = validate( + ref_outputs, + ort_outputs, + args.rtol, + args.atol, + args.percent_mismatch, + ) if not status[0]: remove_files(args.running_mode, model_info["working_directory"]) - update_fail_model_map(model_to_fail_ep, name, ep, 'result accuracy issue', status[1]) + update_fail_model_map( + model_to_fail_ep, + name, + ep, + "result accuracy issue", + status[1], + ) continue except Exception as e: logger.error(e) - update_fail_model_map(model_to_fail_ep, name, ep, 'runtime error', e) + update_fail_model_map(model_to_fail_ep, name, ep, "runtime error", e) continue # Run inference again. the reason is that some ep like tensorrt @@ -1232,7 +1404,6 @@ def run_onnxruntime(args, models): # end of iterate ep #################### - # get percentage of execution time and operators in TRT update_metrics_map(model_to_metrics, name, ep_to_operator) @@ -1241,14 +1412,22 @@ def run_onnxruntime(args, models): # end of model - return success_results, model_to_latency, model_to_fail_ep, model_to_metrics, model_to_session + return ( + success_results, + model_to_latency, + model_to_fail_ep, + model_to_metrics, + model_to_session, + ) + -def calculate_gain(value, ep1, ep2): - ep1_latency = float(value[ep1]['average_latency_ms']) - ep2_latency = float(value[ep2]['average_latency_ms']) - gain = (ep2_latency - ep1_latency)*100/ep2_latency +def calculate_gain(value, ep1, ep2): + ep1_latency = float(value[ep1]["average_latency_ms"]) + ep2_latency = float(value[ep2]["average_latency_ms"]) + gain = (ep2_latency - ep1_latency) * 100 / ep2_latency return gain + def add_improvement_information(model_to_latency): for key, value in model_to_latency.items(): if "ORT-TRT" in value and "ORT-CUDA" in value: @@ -1264,14 +1443,33 @@ def add_improvement_information(model_to_latency): gain = calculate_gain(value, trt_fp16, standalone_trt_fp16) value[trt_native_fp16_gain] = "{:.2f} %".format(gain) + def output_details(results, csv_filename): - need_write_header = True + need_write_header = True if os.path.exists(csv_filename): - need_write_header = False + need_write_header = False - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = [ - "engine", "version", "device", "fp16", "io_binding", "graph_optimizations", "enable_cache", "model_name", "inputs", "batch_size", "sequence_length", "datetime", "test_times", "QPS", "average_latency_ms", "latency_variance", "latency_90_percentile", "latency_95_percentile", "latency_99_percentile" + "engine", + "version", + "device", + "fp16", + "io_binding", + "graph_optimizations", + "enable_cache", + "model_name", + "inputs", + "batch_size", + "sequence_length", + "datetime", + "test_times", + "QPS", + "average_latency_ms", + "latency_variance", + "latency_90_percentile", + "latency_95_percentile", + "latency_99_percentile", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) @@ -1280,9 +1478,10 @@ def output_details(results, csv_filename): for result in results: csv_writer.writerow(result) + def output_fail(model_to_fail_ep, csv_filename): - with open(csv_filename, mode="w", newline='') as csv_file: + with open(csv_filename, mode="w", newline="") as csv_file: column_names = ["model", "ep", "error type", "error message"] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) @@ -1296,55 +1495,59 @@ def output_fail(model_to_fail_ep, csv_filename): result["error type"] = ep_info["error_type"] result["error message"] = ep_info["error_message"] csv_writer.writerow(result) - + + def read_success_from_file(success_file): success_results = [] with open(success_file) as success: - csv_reader = csv.DictReader(success) - for row in csv_reader: - success_results.append(row) + csv_reader = csv.DictReader(success) + for row in csv_reader: + success_results.append(row) success_json = json.loads(json.dumps(success_results, indent=4)) return success_json -def add_status_dict(status_dict, model_name, ep, status): + +def add_status_dict(status_dict, model_name, ep, status): if model_name not in status_dict: status_dict[model_name] = {} status_dict[model_name][ep] = status + def build_status(status_dict, results, is_fail): - - if is_fail: - for model, model_info in results.items(): - for ep, ep_info in model_info.items(): - model_name = model - ep = ep - status = 'Fail' - add_status_dict(status_dict, model_name, ep, status) - else: - for model, value in results.items(): - for ep, ep_info in value.items(): - model_name = model - ep = ep - status = 'Pass' - add_status_dict(status_dict, model_name, ep, status) - - return status_dict + + if is_fail: + for model, model_info in results.items(): + for ep, ep_info in model_info.items(): + model_name = model + ep = ep + status = "Fail" + add_status_dict(status_dict, model_name, ep, status) + else: + for model, value in results.items(): + for ep, ep_info in value.items(): + model_name = model + ep = ep + status = "Pass" + add_status_dict(status_dict, model_name, ep, status) + + return status_dict + def output_status(results, csv_filename): - - need_write_header = True + + need_write_header = True if os.path.exists(csv_filename): - need_write_header = False + need_write_header = False - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = table_headers csv_writer = csv.writer(csv_file) if need_write_header: csv_writer.writerow(column_names) - + cpu_status = "" cuda_fp32_status = "" trt_fp32_status = "" @@ -1352,69 +1555,80 @@ def output_status(results, csv_filename): cuda_fp16_status = "" trt_fp16_status = "" standalone_fp16_status = "" - for model_name, ep_dict in results.items(): for ep, status in ep_dict.items(): - if ep == cpu: - cpu_status = status - elif ep == cuda: + if ep == cpu: + cpu_status = status + elif ep == cuda: cuda_fp32_status = status - elif ep == trt: + elif ep == trt: trt_fp32_status = status elif ep == standalone_trt: standalone_fp32_status = status - elif ep == cuda_fp16: + elif ep == cuda_fp16: cuda_fp16_status = status elif ep == trt_fp16: trt_fp16_status = status - elif ep == standalone_trt_fp16: + elif ep == standalone_trt_fp16: standalone_fp16_status = status - else: + else: continue - - row = [model_name, - cpu_status, - cuda_fp32_status, - trt_fp32_status, - standalone_fp32_status, - cuda_fp16_status, - trt_fp16_status, - standalone_fp16_status] - csv_writer.writerow(row) + row = [ + model_name, + cpu_status, + cuda_fp32_status, + trt_fp32_status, + standalone_fp32_status, + cuda_fp16_status, + trt_fp16_status, + standalone_fp16_status, + ] + csv_writer.writerow(row) -def output_specs(info, csv_filename): - cpu_version = info['cpu_info'][2] - gpu_version = info['gpu_info'][0] - tensorrt_version = info['trt'] + ' , *All ORT-TRT and TRT are run in Mixed Precision mode (Fp16 and Fp32).' - cuda_version = info['cuda'] - cudnn_version = info['cudnn'] - ep_option_overrides = json.dumps(info['ep_option_overrides']) - table = pd.DataFrame({'.': [1, 2, 3, 4, 5, 6], - 'Spec': ['CPU', 'GPU', 'TensorRT', 'CUDA', 'CuDNN', 'EPOptionOverrides'], - 'Version': [cpu_version, gpu_version, tensorrt_version, cuda_version, cudnn_version, - ep_option_overrides]}) +def output_specs(info, csv_filename): + cpu_version = info["cpu_info"][2] + gpu_version = info["gpu_info"][0] + tensorrt_version = info["trt"] + " , *All ORT-TRT and TRT are run in Mixed Precision mode (Fp16 and Fp32)." + cuda_version = info["cuda"] + cudnn_version = info["cudnn"] + ep_option_overrides = json.dumps(info["ep_option_overrides"]) + + table = pd.DataFrame( + { + ".": [1, 2, 3, 4, 5, 6], + "Spec": ["CPU", "GPU", "TensorRT", "CUDA", "CuDNN", "EPOptionOverrides"], + "Version": [ + cpu_version, + gpu_version, + tensorrt_version, + cuda_version, + cudnn_version, + ep_option_overrides, + ], + } + ) table.to_csv(csv_filename, index=False) + def output_session_creation(results, csv_filename): - need_write_header = True + need_write_header = True if os.path.exists(csv_filename): - need_write_header = False + need_write_header = False - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: session_1 = [p + session_ending for p in ort_provider_list] session_2 = [p + second_session_ending for p in ort_provider_list] - column_names = [model_title] + session_1 + session_2 + column_names = [model_title] + session_1 + session_2 csv_writer = csv.writer(csv_file) - csv_writer = csv.writer(csv_file) if need_write_header: csv_writer.writerow(column_names) - + cpu_time = "" cuda_fp32_time = "" trt_fp32_time = "" @@ -1428,51 +1642,53 @@ def output_session_creation(results, csv_filename): for model_name, ep_dict in results.items(): for ep, time in ep_dict.items(): - if ep == cpu: - cpu_time = time - elif ep == cuda: + if ep == cpu: + cpu_time = time + elif ep == cuda: cuda_fp32_time = time - elif ep == trt: + elif ep == trt: trt_fp32_time = time - elif ep == cuda_fp16: + elif ep == cuda_fp16: cuda_fp16_time = time elif ep == trt_fp16: trt_fp16_time = time - if ep == cpu + second: - cpu_time_2 = time - elif ep == cuda + second: + if ep == cpu + second: + cpu_time_2 = time + elif ep == cuda + second: cuda_fp32_time_2 = time - elif ep == trt + second: + elif ep == trt + second: trt_fp32_time_2 = time - elif ep == cuda_fp16 + second: + elif ep == cuda_fp16 + second: cuda_fp16_time_2 = time elif ep == trt_fp16 + second: trt_fp16_time_2 = time - else: + else: continue - - row = [model_name, - cpu_time, - cuda_fp32_time, - trt_fp32_time, - cuda_fp16_time, - trt_fp16_time, - cpu_time_2, - cuda_fp32_time_2, - trt_fp32_time_2, - cuda_fp16_time_2, - trt_fp16_time_2] + + row = [ + model_name, + cpu_time, + cuda_fp32_time, + trt_fp32_time, + cuda_fp16_time, + trt_fp16_time, + cpu_time_2, + cuda_fp32_time_2, + trt_fp32_time_2, + cuda_fp16_time_2, + trt_fp16_time_2, + ] csv_writer.writerow(row) def output_latency(results, csv_filename): - need_write_header = True + need_write_header = True if os.path.exists(csv_filename): - need_write_header = False + need_write_header = False - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = [model_title] - for provider in provider_list: + for provider in provider_list: column_names.append(provider + avg_ending) column_names.append(provider + percentile_ending) if cpu not in provider: @@ -1484,7 +1700,7 @@ def output_latency(results, csv_filename): csv_writer.writerow(column_names) for key, value in results.items(): - cpu_average = "" + cpu_average = "" if cpu in value and "average_latency_ms" in value[cpu]: cpu_average = value[cpu]["average_latency_ms"] @@ -1493,113 +1709,117 @@ def output_latency(results, csv_filename): cpu_90_percentile = value[cpu]["latency_90_percentile"] cuda_average = "" - if cuda in value and 'average_latency_ms' in value[cuda]: - cuda_average = value[cuda]['average_latency_ms'] + if cuda in value and "average_latency_ms" in value[cuda]: + cuda_average = value[cuda]["average_latency_ms"] cuda_90_percentile = "" - if cuda in value and 'latency_90_percentile' in value[cuda]: - cuda_90_percentile = value[cuda]['latency_90_percentile'] + if cuda in value and "latency_90_percentile" in value[cuda]: + cuda_90_percentile = value[cuda]["latency_90_percentile"] cuda_memory = "" - if cuda in value and 'memory' in value[cuda]: - cuda_memory = value[cuda]['memory'] - + if cuda in value and "memory" in value[cuda]: + cuda_memory = value[cuda]["memory"] + trt_average = "" - if trt in value and 'average_latency_ms' in value[trt]: - trt_average = value[trt]['average_latency_ms'] + if trt in value and "average_latency_ms" in value[trt]: + trt_average = value[trt]["average_latency_ms"] trt_90_percentile = "" - if trt in value and 'latency_90_percentile' in value[trt]: - trt_90_percentile = value[trt]['latency_90_percentile'] - + if trt in value and "latency_90_percentile" in value[trt]: + trt_90_percentile = value[trt]["latency_90_percentile"] + trt_memory = "" - if trt in value and 'memory' in value[trt]: - trt_memory = value[trt]['memory'] + if trt in value and "memory" in value[trt]: + trt_memory = value[trt]["memory"] standalone_trt_average = "" - if standalone_trt in value and 'average_latency_ms' in value[standalone_trt]: - standalone_trt_average = value[standalone_trt]['average_latency_ms'] + if standalone_trt in value and "average_latency_ms" in value[standalone_trt]: + standalone_trt_average = value[standalone_trt]["average_latency_ms"] standalone_trt_90_percentile = "" - if standalone_trt in value and 'latency_90_percentile' in value[standalone_trt]: - standalone_trt_90_percentile = value[standalone_trt]['latency_90_percentile'] - + if standalone_trt in value and "latency_90_percentile" in value[standalone_trt]: + standalone_trt_90_percentile = value[standalone_trt]["latency_90_percentile"] + standalone_trt_memory = "" - if standalone_trt in value and 'memory' in value[standalone_trt]: - standalone_trt_memory = value[standalone_trt]['memory'] + if standalone_trt in value and "memory" in value[standalone_trt]: + standalone_trt_memory = value[standalone_trt]["memory"] cuda_fp16_average = "" - if cuda_fp16 in value and 'average_latency_ms' in value[cuda_fp16]: - cuda_fp16_average = value[cuda_fp16]['average_latency_ms'] + if cuda_fp16 in value and "average_latency_ms" in value[cuda_fp16]: + cuda_fp16_average = value[cuda_fp16]["average_latency_ms"] cuda_fp16_memory = "" - if cuda_fp16 in value and 'memory' in value[cuda_fp16]: - cuda_fp16_memory = value[cuda_fp16]['memory'] - + if cuda_fp16 in value and "memory" in value[cuda_fp16]: + cuda_fp16_memory = value[cuda_fp16]["memory"] + cuda_fp16_90_percentile = "" - if cuda_fp16 in value and 'latency_90_percentile' in value[cuda_fp16]: - cuda_fp16_90_percentile = value[cuda_fp16]['latency_90_percentile'] + if cuda_fp16 in value and "latency_90_percentile" in value[cuda_fp16]: + cuda_fp16_90_percentile = value[cuda_fp16]["latency_90_percentile"] trt_fp16_average = "" - if trt_fp16 in value and 'average_latency_ms' in value[trt_fp16]: - trt_fp16_average = value[trt_fp16]['average_latency_ms'] + if trt_fp16 in value and "average_latency_ms" in value[trt_fp16]: + trt_fp16_average = value[trt_fp16]["average_latency_ms"] trt_fp16_90_percentile = "" - if trt_fp16 in value and 'latency_90_percentile' in value[trt_fp16]: - trt_fp16_90_percentile = value[trt_fp16]['latency_90_percentile'] + if trt_fp16 in value and "latency_90_percentile" in value[trt_fp16]: + trt_fp16_90_percentile = value[trt_fp16]["latency_90_percentile"] trt_fp16_memory = "" - if trt_fp16 in value and 'memory' in value[trt_fp16]: - trt_fp16_memory = value[trt_fp16]['memory'] - + if trt_fp16 in value and "memory" in value[trt_fp16]: + trt_fp16_memory = value[trt_fp16]["memory"] + standalone_trt_fp16_average = "" - if standalone_trt_fp16 in value and 'average_latency_ms' in value[standalone_trt_fp16]: - standalone_trt_fp16_average = value[standalone_trt_fp16]['average_latency_ms'] + if standalone_trt_fp16 in value and "average_latency_ms" in value[standalone_trt_fp16]: + standalone_trt_fp16_average = value[standalone_trt_fp16]["average_latency_ms"] standalone_trt_fp16_90_percentile = "" - if standalone_trt_fp16 in value and 'latency_90_percentile' in value[standalone_trt_fp16]: - standalone_trt_fp16_90_percentile = value[standalone_trt_fp16]['latency_90_percentile'] - + if standalone_trt_fp16 in value and "latency_90_percentile" in value[standalone_trt_fp16]: + standalone_trt_fp16_90_percentile = value[standalone_trt_fp16]["latency_90_percentile"] + standalone_trt_fp16_memory = "" - if standalone_trt_fp16 in value and 'memory' in value[standalone_trt_fp16]: - standalone_trt_fp16_memory = value[standalone_trt_fp16]['memory'] - - row = [key, - cpu_average, - cpu_90_percentile, - cuda_average, - cuda_90_percentile, - cuda_memory, - trt_average, - trt_90_percentile, - trt_memory, - standalone_trt_average, - standalone_trt_90_percentile, - standalone_trt_memory, - cuda_fp16_average, - cuda_fp16_90_percentile, - cuda_fp16_memory, - trt_fp16_average, - trt_fp16_90_percentile, - trt_fp16_memory, - standalone_trt_fp16_average, - standalone_trt_fp16_90_percentile, - standalone_trt_fp16_memory, - ] + if standalone_trt_fp16 in value and "memory" in value[standalone_trt_fp16]: + standalone_trt_fp16_memory = value[standalone_trt_fp16]["memory"] + + row = [ + key, + cpu_average, + cpu_90_percentile, + cuda_average, + cuda_90_percentile, + cuda_memory, + trt_average, + trt_90_percentile, + trt_memory, + standalone_trt_average, + standalone_trt_90_percentile, + standalone_trt_memory, + cuda_fp16_average, + cuda_fp16_90_percentile, + cuda_fp16_memory, + trt_fp16_average, + trt_fp16_90_percentile, + trt_fp16_memory, + standalone_trt_fp16_average, + standalone_trt_fp16_90_percentile, + standalone_trt_fp16_memory, + ] csv_writer.writerow(row) logger.info(f"CUDA/TRT latency comparison are saved to csv file: {csv_filename}") + def output_metrics(model_to_metrics, csv_filename): - with open(csv_filename, mode="w", newline='') as csv_file: - column_names = ["Model", - "% CUDA operators (not fall back to CPU)", - "Total TRT operators", - "Total operators", - "% TRT operator", - "Total TRT execution time", - "Total execution time", - "% TRT execution time"] + with open(csv_filename, mode="w", newline="") as csv_file: + column_names = [ + "Model", + "% CUDA operators (not fall back to CPU)", + "Total TRT operators", + "Total operators", + "% TRT operator", + "Total TRT execution time", + "Total execution time", + "% TRT execution time", + ] csv_writer = csv.writer(csv_file) csv_writer.writerow(column_names) @@ -1612,70 +1832,71 @@ def output_metrics(model_to_metrics, csv_filename): result_fp16["model_name"] = model + " (FP16)" if cuda in ep_info: - result['ratio_of_ops_in_cuda_not_fallback_cpu'] = ep_info[cuda]['ratio_of_ops_in_cuda_not_fallback_cpu'] + result["ratio_of_ops_in_cuda_not_fallback_cpu"] = ep_info[cuda]["ratio_of_ops_in_cuda_not_fallback_cpu"] if trt in ep_info: - result['total_trt_execution_time'] = ep_info[trt]['total_trt_execution_time'] - result['total_execution_time'] = ep_info[trt]['total_execution_time'] - result['ratio_of_execution_time_in_trt'] = ep_info[trt]['ratio_of_execution_time_in_trt'] + result["total_trt_execution_time"] = ep_info[trt]["total_trt_execution_time"] + result["total_execution_time"] = ep_info[trt]["total_execution_time"] + result["ratio_of_execution_time_in_trt"] = ep_info[trt]["ratio_of_execution_time_in_trt"] - if cuda in ep_info and trt in ep_info: + if cuda in ep_info and trt in ep_info: ######################################################################################## # equation of % TRT ops: # (total ops in cuda json - cuda and cpu ops in trt json)/ total ops in cuda json ######################################################################################## - total_ops_in_cuda = ep_info[cuda]["total_ops"] + total_ops_in_cuda = ep_info[cuda]["total_ops"] cuda_cpu_ops_in_trt = ep_info[trt]["total_ops"] - result['total_ops_in_trt'] = total_ops_in_cuda - cuda_cpu_ops_in_trt - result['total_ops'] = total_ops_in_cuda - result['ratio_of_ops_in_trt'] = (total_ops_in_cuda - cuda_cpu_ops_in_trt) / total_ops_in_cuda + result["total_ops_in_trt"] = total_ops_in_cuda - cuda_cpu_ops_in_trt + result["total_ops"] = total_ops_in_cuda + result["ratio_of_ops_in_trt"] = (total_ops_in_cuda - cuda_cpu_ops_in_trt) / total_ops_in_cuda if cuda_fp16 in ep_info: - result_fp16['ratio_of_ops_in_cuda_not_fallback_cpu'] = ep_info[cuda_fp16]['ratio_of_ops_in_cuda_not_fallback_cpu'] + result_fp16["ratio_of_ops_in_cuda_not_fallback_cpu"] = ep_info[cuda_fp16][ + "ratio_of_ops_in_cuda_not_fallback_cpu" + ] if trt_fp16 in ep_info: - result_fp16['total_trt_execution_time'] = ep_info[trt_fp16]['total_trt_execution_time'] - result_fp16['total_execution_time'] = ep_info[trt_fp16]['total_execution_time'] - result_fp16['ratio_of_execution_time_in_trt'] = ep_info[trt_fp16]['ratio_of_execution_time_in_trt'] + result_fp16["total_trt_execution_time"] = ep_info[trt_fp16]["total_trt_execution_time"] + result_fp16["total_execution_time"] = ep_info[trt_fp16]["total_execution_time"] + result_fp16["ratio_of_execution_time_in_trt"] = ep_info[trt_fp16]["ratio_of_execution_time_in_trt"] - if cuda_fp16 in ep_info and trt_fp16 in ep_info: + if cuda_fp16 in ep_info and trt_fp16 in ep_info: ######################################################################################## # equation of % TRT ops: # (total ops in cuda json - cuda and cpu ops in trt json)/ total ops in cuda json ######################################################################################## - total_ops_in_cuda = ep_info[cuda_fp16]["total_ops"] + total_ops_in_cuda = ep_info[cuda_fp16]["total_ops"] cuda_cpu_ops_in_trt = ep_info[trt_fp16]["total_ops"] - result_fp16['total_ops_in_trt'] = total_ops_in_cuda - cuda_cpu_ops_in_trt - result_fp16['total_ops'] = total_ops_in_cuda - result_fp16['ratio_of_ops_in_trt'] = (total_ops_in_cuda - cuda_cpu_ops_in_trt) / total_ops_in_cuda + result_fp16["total_ops_in_trt"] = total_ops_in_cuda - cuda_cpu_ops_in_trt + result_fp16["total_ops"] = total_ops_in_cuda + result_fp16["ratio_of_ops_in_trt"] = (total_ops_in_cuda - cuda_cpu_ops_in_trt) / total_ops_in_cuda - results.append(result) results.append(result_fp16) - - for value in results: - row = [value['model_name'], - value['ratio_of_ops_in_cuda_not_fallback_cpu'] if 'ratio_of_ops_in_cuda_not_fallback_cpu' in value else " ", - value['total_ops_in_trt'] if 'total_ops_in_trt' in value else " ", - value['total_ops'] if 'total_ops' in value else " ", - value['ratio_of_ops_in_trt'] if 'ratio_of_ops_in_trt' in value else " ", - value['total_trt_execution_time'] if 'total_trt_execution_time' in value else " ", - value['total_execution_time'] if 'total_execution_time' in value else " ", - value['ratio_of_execution_time_in_trt'] if 'ratio_of_execution_time_in_trt' in value else " ", - ] + row = [ + value["model_name"], + value["ratio_of_ops_in_cuda_not_fallback_cpu"] + if "ratio_of_ops_in_cuda_not_fallback_cpu" in value + else " ", + value["total_ops_in_trt"] if "total_ops_in_trt" in value else " ", + value["total_ops"] if "total_ops" in value else " ", + value["ratio_of_ops_in_trt"] if "ratio_of_ops_in_trt" in value else " ", + value["total_trt_execution_time"] if "total_trt_execution_time" in value else " ", + value["total_execution_time"] if "total_execution_time" in value else " ", + value["ratio_of_execution_time_in_trt"] if "ratio_of_execution_time_in_trt" in value else " ", + ] csv_writer.writerow(row) logger.info(f"Tensorrt ratio metrics are saved to csv file: {csv_filename}") + def output_system_info(result, csv_filename): - with open(csv_filename, mode="a", newline='') as csv_file: - column_names = [ - "cpu_info", "cuda", "gpu_info", "linux_distro", "memory", "trt" - ] + with open(csv_filename, mode="a", newline="") as csv_file: + column_names = ["cpu_info", "cuda", "gpu_info", "linux_distro", "memory", "trt"] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) csv_writer.writeheader() @@ -1683,15 +1904,17 @@ def output_system_info(result, csv_filename): logger.info(f"System information are saved to csv file: {csv_filename}") + def str2bool(v): if isinstance(v, bool): - return v - if v.lower() in ('yes', 'true', 't', 'y', '1'): + return v + if v.lower() in ("yes", "true", "t", "y", "1"): return True - elif v.lower() in ('no', 'false', 'f', 'n', '0'): + elif v.lower() in ("no", "false", "f", "n", "0"): return False else: - raise argparse.ArgumentTypeError('Boolean value expected.') + raise argparse.ArgumentTypeError("Boolean value expected.") + class ParseDictArgAction(argparse.Action): def __call__(self, parser, namespace, values, option_string): @@ -1704,61 +1927,161 @@ def __call__(self, parser, namespace, values, option_string): parser.error("argument {opt_str}: Expected '=' between key and value".format(opt_str=option_string)) if k in dict_arg: - parser.error("argument {opt_str}: Specified duplicate key '{dup_key}'".format(opt_str=option_string, dup_key=k)) + parser.error( + "argument {opt_str}: Specified duplicate key '{dup_key}'".format(opt_str=option_string, dup_key=k) + ) dict_arg[k] = v setattr(namespace, self.dest, dict_arg) + def parse_arguments(): # Used by argparse to display usage information for custom inputs. dict_arg_metavar = "Opt1=Val1,Opt2=Val2..." parser = argparse.ArgumentParser() - - parser.add_argument("-c", "--comparison", required=False, default="cuda_trt", choices=["cuda_trt", "acl"], help="EPs to compare: CPU vs. CUDA vs. TRT or CPU vs. ACL") - - parser.add_argument("-m", "--model_source", required=False, default="model_list.json", help="Model source: (1) model list file (2) model directory.") - - parser.add_argument("-r", "--running_mode", required=False, default="benchmark", choices=["validate", "benchmark"], help="Testing mode.") - - parser.add_argument("-i", "--input_data", required=False, default="fix", choices=["fix", "random"], help="Type of input data.") - - parser.add_argument("-o", "--perf_result_path", required=False, default="result", help="Directory for perf result.") - - parser.add_argument("-w", "--workspace", required=False, default="/", help="Workspace to find tensorrt and perf script (with models if parsing with model file)") - - parser.add_argument("-e", "--ep_list", nargs="+", required=False, default=None, help="Specify ORT Execution Providers list.") - parser.add_argument("--trt_ep_options", required=False, default={"trt_engine_cache_enable": "True", "trt_max_workspace_size": "4294967296"}, - action=ParseDictArgAction, metavar=dict_arg_metavar, help="Specify options for the ORT TensorRT Execution Provider") - - parser.add_argument("--cuda_ep_options", required=False, default={}, action=ParseDictArgAction, metavar=dict_arg_metavar, - help="Specify options for the ORT CUDA Execution Provider") - - parser.add_argument("-z", "--track_memory", required=False, default=True, help="Track CUDA and TRT Memory Usage") + parser.add_argument( + "-c", + "--comparison", + required=False, + default="cuda_trt", + choices=["cuda_trt", "acl"], + help="EPs to compare: CPU vs. CUDA vs. TRT or CPU vs. ACL", + ) + + parser.add_argument( + "-m", + "--model_source", + required=False, + default="model_list.json", + help="Model source: (1) model list file (2) model directory.", + ) + + parser.add_argument( + "-r", + "--running_mode", + required=False, + default="benchmark", + choices=["validate", "benchmark"], + help="Testing mode.", + ) + + parser.add_argument( + "-i", + "--input_data", + required=False, + default="fix", + choices=["fix", "random"], + help="Type of input data.", + ) + + parser.add_argument( + "-o", + "--perf_result_path", + required=False, + default="result", + help="Directory for perf result.", + ) + + parser.add_argument( + "-w", + "--workspace", + required=False, + default="/", + help="Workspace to find tensorrt and perf script (with models if parsing with model file)", + ) + + parser.add_argument( + "-e", + "--ep_list", + nargs="+", + required=False, + default=None, + help="Specify ORT Execution Providers list.", + ) + + parser.add_argument( + "--trt_ep_options", + required=False, + default={ + "trt_engine_cache_enable": "True", + "trt_max_workspace_size": "4294967296", + }, + action=ParseDictArgAction, + metavar=dict_arg_metavar, + help="Specify options for the ORT TensorRT Execution Provider", + ) + + parser.add_argument( + "--cuda_ep_options", + required=False, + default={}, + action=ParseDictArgAction, + metavar=dict_arg_metavar, + help="Specify options for the ORT CUDA Execution Provider", + ) + + parser.add_argument( + "-z", + "--track_memory", + required=False, + default=True, + help="Track CUDA and TRT Memory Usage", + ) parser.add_argument("-b", "--io_binding", required=False, default=False, help="Bind Inputs") - - parser.add_argument("-g", "--graph_enablement", required=False, default=enable_all, choices=[disable, basic, extended, enable_all], help="Choose graph optimization enablement.") - - parser.add_argument("--ep", required=False, default=None, help="Specify ORT Execution Provider.") - parser.add_argument("--fp16", required=False, default=True, action="store_true", help="Inlcude Float16 into benchmarking.") + parser.add_argument( + "-g", + "--graph_enablement", + required=False, + default=enable_all, + choices=[disable, basic, extended, enable_all], + help="Choose graph optimization enablement.", + ) + + parser.add_argument("--ep", required=False, default=None, help="Specify ORT Execution Provider.") + + parser.add_argument( + "--fp16", + required=False, + default=True, + action="store_true", + help="Inlcude Float16 into benchmarking.", + ) parser.add_argument("--trtexec", required=False, default=None, help="trtexec executable path.") # Validation options - parser.add_argument("--percent_mismatch", required=False, default=20.0, help="Allowed percentage of mismatched elements in validation.") - parser.add_argument("--rtol", required=False, default=0, help="Relative tolerance for validating outputs.") - parser.add_argument("--atol", required=False, default=20, help="Absolute tolerance for validating outputs.") - - parser.add_argument("-t", - "--test_times", - required=False, - default=1, - type=int, - help="Number of repeat times to get average inference latency.") + parser.add_argument( + "--percent_mismatch", + required=False, + default=20.0, + help="Allowed percentage of mismatched elements in validation.", + ) + parser.add_argument( + "--rtol", + required=False, + default=0, + help="Relative tolerance for validating outputs.", + ) + parser.add_argument( + "--atol", + required=False, + default=20, + help="Absolute tolerance for validating outputs.", + ) + + parser.add_argument( + "-t", + "--test_times", + required=False, + default=1, + type=int, + help="Number of repeat times to get average inference latency.", + ) parser.add_argument("--write_test_result", type=str2bool, required=False, default=True, help="") parser.add_argument("--benchmark_fail_csv", required=False, default=None, help="") @@ -1771,14 +2094,19 @@ def parse_arguments(): return args + def setup_logger(verbose): if verbose: - coloredlogs.install(level='DEBUG', fmt='[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s') + coloredlogs.install( + level="DEBUG", + fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + ) else: - coloredlogs.install(fmt='%(message)s') + coloredlogs.install(fmt="%(message)s") logging.getLogger("transformers").setLevel(logging.WARNING) -def parse_models_helper(args, models): + +def parse_models_helper(args, models): model_source = os.path.join(args.workspace, args.model_source) if ".json" in model_source: logger.info("Parsing model information from file ...") @@ -1787,18 +2115,25 @@ def parse_models_helper(args, models): logger.info("Parsing model information from directory ...") parse_models_info_from_directory(model_source, models) + def main(): args = parse_arguments() setup_logger(False) pp = pprint.PrettyPrinter(indent=4) - + logger.info("\n\nStart perf run ...\n") models = {} parse_models_helper(args, models) perf_start_time = datetime.now() - success_results, model_to_latency, model_to_fail_ep, model_to_metrics, model_to_session = run_onnxruntime(args, models) + ( + success_results, + model_to_latency, + model_to_fail_ep, + model_to_metrics, + model_to_session, + ) = run_onnxruntime(args, models) perf_end_time = datetime.now() logger.info("Done running the perf.") @@ -1806,20 +2141,22 @@ def main(): logger.info(list(models.keys())) logger.info("\nTotal models: {}".format(len(models))) - + fail_model_cnt = 0 for key, value in models.items(): - if key in model_to_fail_ep: fail_model_cnt += 1 + if key in model_to_fail_ep: + fail_model_cnt += 1 logger.info("Fail models: {}".format(fail_model_cnt)) - logger.info("Success models: {}".format(len(models) - fail_model_cnt )) + logger.info("Success models: {}".format(len(models) - fail_model_cnt)) path = os.path.join(os.getcwd(), args.perf_result_path) if not os.path.exists(path): from pathlib import Path + Path(path).mkdir(parents=True, exist_ok=True) time_stamp = datetime.now().strftime("%Y%m%d-%H%M%S") - + if len(model_to_fail_ep) > 0: logger.info("\n============================================") logger.info("========== Failing Models/EPs ==============") @@ -1840,12 +2177,16 @@ def main(): pretty_print(pp, model_to_latency) write_map_to_file(model_to_latency, LATENCY_FILE) if args.write_test_result: - csv_filename = args.benchmark_latency_csv if args.benchmark_latency_csv else f"benchmark_latency_{time_stamp}.csv" + csv_filename = ( + args.benchmark_latency_csv if args.benchmark_latency_csv else f"benchmark_latency_{time_stamp}.csv" + ) csv_filename = os.path.join(path, csv_filename) output_latency(model_to_latency, csv_filename) - + if success_results: - csv_filename = args.benchmark_success_csv if args.benchmark_success_csv else f"benchmark_success_{time_stamp}.csv" + csv_filename = ( + args.benchmark_success_csv if args.benchmark_success_csv else f"benchmark_success_{time_stamp}.csv" + ) csv_filename = os.path.join(path, csv_filename) output_details(success_results, csv_filename) @@ -1857,12 +2198,15 @@ def main(): write_map_to_file(model_to_metrics, METRICS_FILE) if args.write_test_result: - csv_filename = args.benchmark_metrics_csv if args.benchmark_metrics_csv else f"benchmark_metrics_{time_stamp}.csv" + csv_filename = ( + args.benchmark_metrics_csv if args.benchmark_metrics_csv else f"benchmark_metrics_{time_stamp}.csv" + ) csv_filename = os.path.join(path, csv_filename) output_metrics(model_to_metrics, csv_filename) if len(model_to_session) > 0: write_map_to_file(model_to_session, SESSION_FILE) + if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py index 11c858a8b4df0..31eeaab14a729 100644 --- a/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py +++ b/onnxruntime/python/tools/tensorrt/perf/benchmark_wrapper.py @@ -1,35 +1,49 @@ -import os -import csv -import logging -import coloredlogs import argparse import copy +import csv import json -import re +import logging +import os import pprint -from perf_utils import * +import re + +import coloredlogs from benchmark import * +from perf_utils import * + def write_model_info_to_file(model, path): - with open(path, 'w') as file: - file.write(json.dumps(model)) # use `json.loads` to do the reverse + with open(path, "w") as file: + file.write(json.dumps(model)) # use `json.loads` to do the reverse -def get_ep_list(comparison): - if comparison == 'acl': + +def get_ep_list(comparison): + if comparison == "acl": ep_list = [cpu, acl] - else: + else: # test with cuda and trt - ep_list = [cpu, cuda, trt, standalone_trt, cuda_fp16, trt_fp16, standalone_trt_fp16] + ep_list = [ + cpu, + cuda, + trt, + standalone_trt, + cuda_fp16, + trt_fp16, + standalone_trt_fp16, + ] return ep_list -def resolve_trtexec_path(workspace): + +def resolve_trtexec_path(workspace): trtexec_options = get_output(["find", workspace, "-name", "trtexec"]) - trtexec_path = re.search(r'.*/bin/trtexec', trtexec_options).group(0) + trtexec_path = re.search(r".*/bin/trtexec", trtexec_options).group(0) logger.info("using trtexec {}".format(trtexec_path)) return trtexec_path + def dict_to_args(dct): - return ','.join(["{}={}".format(k, v) for k, v in dct.items()]) + return ",".join(["{}={}".format(k, v) for k, v in dct.items()]) + def main(): args = parse_arguments() @@ -42,7 +56,7 @@ def main(): else: ep_list = get_ep_list(args.comparison) - if standalone_trt in ep_list or standalone_trt_fp16 in ep_list: + if standalone_trt in ep_list or standalone_trt_fp16 in ep_list: trtexec = resolve_trtexec_path(args.workspace) models = {} @@ -59,28 +73,35 @@ def main(): specs_csv = specs_name + csv_ending for model, model_info in models.items(): - logger.info("\n" + "="*40 + "="*len(model)) - logger.info("="*20 + model +"="*20) - logger.info("="*40 + "="*len(model)) + logger.info("\n" + "=" * 40 + "=" * len(model)) + logger.info("=" * 20 + model + "=" * 20) + logger.info("=" * 40 + "=" * len(model)) + + model_info["model_name"] = model - model_info["model_name"] = model - - model_list_file = os.path.join(os.getcwd(), model +'.json') + model_list_file = os.path.join(os.getcwd(), model + ".json") write_model_info_to_file([model_info], model_list_file) for ep in ep_list: - - command = ["python3", - "benchmark.py", - "-r", args.running_mode, - "-m", model_list_file, - "-o", args.perf_result_path, - "--ep", ep, - "--write_test_result", "false"] - - if ep == standalone_trt or ep == standalone_trt_fp16: - if args.running_mode == "validate": - continue + + command = [ + "python3", + "benchmark.py", + "-r", + args.running_mode, + "-m", + model_list_file, + "-o", + args.perf_result_path, + "--ep", + ep, + "--write_test_result", + "false", + ] + + if ep == standalone_trt or ep == standalone_trt_fp16: + if args.running_mode == "validate": + continue else: command.extend(["--trtexec", trtexec]) @@ -92,20 +113,30 @@ def main(): if args.running_mode == "validate": command.extend(["--benchmark_metrics_csv", benchmark_metrics_csv]) - + elif args.running_mode == "benchmark": - command.extend(["-t", str(args.test_times), - "-o", args.perf_result_path, - "--write_test_result", "false", - "--benchmark_fail_csv", benchmark_fail_csv, - "--benchmark_latency_csv", benchmark_latency_csv, - "--benchmark_success_csv", benchmark_success_csv]) - + command.extend( + [ + "-t", + str(args.test_times), + "-o", + args.perf_result_path, + "--write_test_result", + "false", + "--benchmark_fail_csv", + benchmark_fail_csv, + "--benchmark_latency_csv", + benchmark_latency_csv, + "--benchmark_success_csv", + benchmark_success_csv, + ] + ) + p = subprocess.run(command) logger.info(p) if p.returncode != 0: - error_type = "runtime error" + error_type = "runtime error" error_message = "Benchmark script exited with returncode = " + str(p.returncode) logger.error(error_message) update_fail_model_map(model_to_fail_ep, model, ep, error_type, error_message) @@ -117,6 +148,7 @@ def main(): path = os.path.join(os.getcwd(), args.perf_result_path) if not os.path.exists(path): from pathlib import Path + Path(path).mkdir(parents=True, exist_ok=True) if args.running_mode == "validate": @@ -127,8 +159,8 @@ def main(): if os.path.exists(METRICS_FILE): model_to_metrics = read_map_from_file(METRICS_FILE) output_metrics(model_to_metrics, os.path.join(path, benchmark_metrics_csv)) - logger.info("\nSaved model metrics results to {}".format(benchmark_metrics_csv)) - + logger.info("\nSaved model metrics results to {}".format(benchmark_metrics_csv)) + elif args.running_mode == "benchmark": logger.info("\n=========================================") logger.info("======= Models/EPs session creation =======") @@ -138,8 +170,8 @@ def main(): model_to_session = read_map_from_file(SESSION_FILE) pretty_print(pp, model_to_session) output_session_creation(model_to_session, os.path.join(path, benchmark_session_csv)) - logger.info("\nSaved session creation results to {}".format(benchmark_session_csv)) - + logger.info("\nSaved session creation results to {}".format(benchmark_session_csv)) + logger.info("\n=========================================================") logger.info("========== Failing Models/EPs (accumulated) ==============") logger.info("==========================================================") @@ -148,7 +180,7 @@ def main(): model_to_fail_ep = read_map_from_file(FAIL_MODEL_FILE) output_fail(model_to_fail_ep, os.path.join(path, benchmark_fail_csv)) logger.info(model_to_fail_ep) - logger.info("\nSaved model failing results to {}".format(benchmark_fail_csv)) + logger.info("\nSaved model failing results to {}".format(benchmark_fail_csv)) logger.info("\n=======================================================") logger.info("=========== Models/EPs Status (accumulated) ===========") @@ -163,11 +195,11 @@ def main(): model_fail = read_map_from_file(FAIL_MODEL_FILE) is_fail = True model_status = build_status(model_status, model_fail, is_fail) - + pretty_print(pp, model_status) - - output_status(model_status, os.path.join(path, benchmark_status_csv)) - logger.info("\nSaved model status results to {}".format(benchmark_status_csv)) + + output_status(model_status, os.path.join(path, benchmark_status_csv)) + logger.info("\nSaved model status results to {}".format(benchmark_status_csv)) logger.info("\n=========================================================") logger.info("=========== Models/EPs latency (accumulated) ===========") @@ -176,11 +208,11 @@ def main(): if os.path.exists(LATENCY_FILE): model_to_latency = read_map_from_file(LATENCY_FILE) add_improvement_information(model_to_latency) - + pretty_print(pp, model_to_latency) - + output_latency(model_to_latency, os.path.join(path, benchmark_latency_csv)) - logger.info("\nSaved model latency results to {}".format(benchmark_latency_csv)) + logger.info("\nSaved model latency results to {}".format(benchmark_latency_csv)) logger.info("\n===========================================") logger.info("=========== System information ===========") @@ -189,7 +221,8 @@ def main(): pretty_print(pp, info) logger.info("\n") output_specs(info, os.path.join(path, specs_csv)) - logger.info("\nSaved hardware specs to {}".format(specs_csv)) + logger.info("\nSaved hardware specs to {}".format(specs_csv)) + if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/compare_latency.py b/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/compare_latency.py index ae7b229590b17..93df53c9825db 100644 --- a/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/compare_latency.py +++ b/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/compare_latency.py @@ -1,46 +1,62 @@ -import pandas as pd -import numpy as np import argparse -ep_map = {"cpu": "CPU", "cuda":"CUDA","trt": "TRT EP","native": "Standalone TRT"} +import numpy as np +import pandas as pd -def parse_arguments(): +ep_map = {"cpu": "CPU", "cuda": "CUDA", "trt": "TRT EP", "native": "Standalone TRT"} + + +def parse_arguments(): # create parser parser = argparse.ArgumentParser() parser.add_argument("-p", "--prev", required=True, help="previous csv") parser.add_argument("-c", "--current", required=True, help="current csv") parser.add_argument("-o", "--output_csv", required=True, help="output different csv") - parser.add_argument("--ep", required=False, default="trt", choices=["cpu", "cuda", "trt", "native"], help="ep to capture regressions on") - parser.add_argument("--tolerance", required=False, default=0, help="allowed tolerance for latency comparison") + parser.add_argument( + "--ep", + required=False, + default="trt", + choices=["cpu", "cuda", "trt", "native"], + help="ep to capture regressions on", + ) + parser.add_argument( + "--tolerance", + required=False, + default=0, + help="allowed tolerance for latency comparison", + ) args = parser.parse_args() - return args + return args + -def get_table_condition(table, fp, ep, tol): +def get_table_condition(table, fp, ep, tol): ep = ep_map[ep] col1 = ep + " " + fp + " \nmean (ms)_x" col2 = ep + " " + fp + " \nmean (ms)_y" condition = table[col1] > (table[col2] + tol) return condition + def main(): args = parse_arguments() a = pd.read_csv(args.prev) b = pd.read_csv(args.current) - - common = a.merge(b, on=['Model']) - + + common = a.merge(b, on=["Model"]) + condition_fp32 = get_table_condition(common, "fp32", args.ep, args.tolerance) condition_fp16 = get_table_condition(common, "fp16", args.ep, args.tolerance) - - common['greater'] = np.where((condition_fp32 | condition_fp16), True, False) - greater = common[common['greater'] == True].drop(['greater'], axis=1) - + + common["greater"] = np.where((condition_fp32 | condition_fp16), True, False) + greater = common[common["greater"] == True].drop(["greater"], axis=1) + # arrange columns keys = list(greater.keys().sort_values()) - keys.insert(0, keys.pop(keys.index('Model'))) + keys.insert(0, keys.pop(keys.index("Model"))) greater = greater[keys] - + greater.to_csv(args.output_csv) -if __name__=='__main__': + +if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/new_failures.py b/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/new_failures.py index e5a81d2dbdadc..b702458066e90 100644 --- a/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/new_failures.py +++ b/onnxruntime/python/tools/tensorrt/perf/comparison_scripts/new_failures.py @@ -1,22 +1,30 @@ -import pandas as pd import argparse -def parse_arguments(): +import pandas as pd + + +def parse_arguments(): # create parser parser = argparse.ArgumentParser() parser.add_argument("-p", "--prev", required=True, help="previous csv") parser.add_argument("-c", "--current", required=True, help="current csv") parser.add_argument("-o", "--output_csv", required=True, help="output different csv") args = parser.parse_args() - return args + return args + def main(): args = parse_arguments() a = pd.read_csv(args.prev) b = pd.read_csv(args.current) - common = b.merge(a, on=['model','ep','error type','error message']) - diff = b.append(common, ignore_index=True).drop_duplicates(['model', 'ep', 'error type', 'error message'], keep=False).loc[:b.index.max()] + common = b.merge(a, on=["model", "ep", "error type", "error message"]) + diff = ( + b.append(common, ignore_index=True) + .drop_duplicates(["model", "ep", "error type", "error message"], keep=False) + .loc[: b.index.max()] + ) diff.to_csv(args.output_csv) -if __name__=='__main__': + +if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/tensorrt/perf/perf_utils.py b/onnxruntime/python/tools/tensorrt/perf/perf_utils.py index 62e10220bad76..fedad5b8f3bf5 100644 --- a/onnxruntime/python/tools/tensorrt/perf/perf_utils.py +++ b/onnxruntime/python/tools/tensorrt/perf/perf_utils.py @@ -1,21 +1,22 @@ -import subprocess import json -import pprint import logging -import coloredlogs +import pprint import re +import subprocess import sys +import coloredlogs + debug = False -debug_verbose = False +debug_verbose = False -# ORT ep names +# ORT ep names cpu_ep = "CPUExecutionProvider" cuda_ep = "CUDAExecutionProvider" trt_ep = "TensorrtExecutionProvider" acl_ep = "ACLExecutionProvider" -# provider names +# provider names cpu = "ORT-CPUFp32" cuda = "ORT-CUDAFp32" cuda_fp16 = "ORT-CUDAFp16" @@ -26,56 +27,70 @@ acl = "ORT-ACLFp32" # table names -metrics_name = 'metrics' -success_name = 'success' -fail_name = 'fail' -memory_name = 'memory' -latency_name = 'latency' -status_name = 'status' -latency_over_time_name = 'latency_over_time' -specs_name = 'specs' -session_name = 'session' - -# column names -model_title = 'Model' -group_title = 'Group' - -# endings +metrics_name = "metrics" +success_name = "success" +fail_name = "fail" +memory_name = "memory" +latency_name = "latency" +status_name = "status" +latency_over_time_name = "latency_over_time" +specs_name = "specs" +session_name = "session" + +# column names +model_title = "Model" +group_title = "Group" + +# endings second = "_second" -csv_ending = '.csv' -avg_ending = ' \nmean (ms)' -percentile_ending = ' \n90th percentile (ms)' -memory_ending = ' \npeak memory usage (MiB)' -session_ending = ' \n session creation time (s)' -second_session_ending = ' \n second session creation time (s)' +csv_ending = ".csv" +avg_ending = " \nmean (ms)" +percentile_ending = " \n90th percentile (ms)" +memory_ending = " \npeak memory usage (MiB)" +session_ending = " \n session creation time (s)" +second_session_ending = " \n second session creation time (s)" ort_provider_list = [cpu, cuda, trt, cuda_fp16, trt_fp16] -provider_list = [cpu, cuda, trt, standalone_trt, cuda_fp16, trt_fp16, standalone_trt_fp16] +provider_list = [ + cpu, + cuda, + trt, + standalone_trt, + cuda_fp16, + trt_fp16, + standalone_trt_fp16, +] table_headers = [model_title] + provider_list -# graph options -disable = 'disable' -basic = 'basic' -extended = 'extended' -enable_all = 'all' +# graph options +disable = "disable" +basic = "basic" +extended = "extended" +enable_all = "all" + def is_standalone(ep): return ep == standalone_trt or ep == standalone_trt_fp16 + def get_output(command): p = subprocess.run(command, check=True, stdout=subprocess.PIPE) output = p.stdout.decode("ascii").strip() return output -def find(regex_string): + +def find(regex_string): import glob + results = glob.glob(regex_string) results.sort() return results + def pretty_print(pp, json_object): pp.pprint(json_object) sys.stdout.flush() + def parse_single_file(f): try: @@ -86,7 +101,7 @@ def parse_single_file(f): model_run_flag = False first_run_flag = True provider_op_map = {} # ep -> map of operator to duration - provider_op_map_first_run = {} # ep -> map of operator to duration + provider_op_map_first_run = {} # ep -> map of operator to duration for row in data: if not "cat" in row: @@ -134,20 +149,19 @@ def parse_single_file(f): op_map[row["name"]] = row["dur"] provider_op_map[provider] = op_map - if debug_verbose: - pprint._sorted = lambda x:x + pprint._sorted = lambda x: x pprint.sorted = lambda x, key=None: x pp = pprint.PrettyPrinter(indent=4) print("------First run ops map (START)------") for key, map in provider_op_map_first_run.items(): - print(key) + print(key) pp.pprint({k: v for k, v in sorted(map.items(), key=lambda item: item[1], reverse=True)}) print("------First run ops map (END) ------") print("------Second run ops map (START)------") for key, map in provider_op_map.items(): - print(key) + print(key) pp.pprint({k: v for k, v in sorted(map.items(), key=lambda item: item[1], reverse=True)}) print("------Second run ops map (END) ------") @@ -156,6 +170,7 @@ def parse_single_file(f): return None + def calculate_cuda_op_percentage(cuda_op_map): if not cuda_op_map or len(cuda_op_map) == 0: return 0 @@ -163,14 +178,15 @@ def calculate_cuda_op_percentage(cuda_op_map): cuda_ops = 0 cpu_ops = 0 for key, value in cuda_op_map.items(): - if key == 'CUDAExecutionProvider': + if key == "CUDAExecutionProvider": cuda_ops += len(value) - if key == 'CPUExecutionProvider': + if key == "CPUExecutionProvider": cpu_ops += len(value) return cuda_ops / (cuda_ops + cpu_ops) + ########################################## # Return: total ops executed in TRT, # total ops, @@ -208,6 +224,7 @@ def calculate_trt_op_percentage(trt_op_map, cuda_op_map): return ((total_ops - total_cuda_and_cpu_ops), total_ops, ratio_of_ops_in_trt) + def get_total_ops(op_map): total_ops = 0 @@ -227,7 +244,11 @@ def calculate_trt_latency_percentage(trt_op_map): # % of TRT execution time total_execution_time = 0 total_trt_execution_time = 0 - for ep in ["TensorrtExecutionProvider", "CUDAExecutionProvider", "CPUExecutionProvider"]: + for ep in [ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ]: if ep in trt_op_map: op_map = trt_op_map[ep] @@ -240,8 +261,6 @@ def calculate_trt_latency_percentage(trt_op_map): total_execution_time += total_time - - if total_execution_time == 0: ratio_of_trt_execution_time = 0 else: @@ -257,7 +276,10 @@ def calculate_trt_latency_percentage(trt_op_map): def get_profile_metrics(path, profile_already_parsed, logger=None): logger.info("Parsing/Analyzing profiling files in {} ...".format(path)) - p1 = subprocess.Popen(["find", path, "-name", "onnxruntime_profile*", "-printf", "%T+\t%p\n"], stdout=subprocess.PIPE) + p1 = subprocess.Popen( + ["find", path, "-name", "onnxruntime_profile*", "-printf", "%T+\t%p\n"], + stdout=subprocess.PIPE, + ) p2 = subprocess.Popen(["sort"], stdin=p1.stdout, stdout=subprocess.PIPE) stdout, sterr = p2.communicate() stdout = stdout.decode("ascii").strip() @@ -266,7 +288,7 @@ def get_profile_metrics(path, profile_already_parsed, logger=None): data = [] for profile in profiling_files: - profile = profile.split('\t')[1] + profile = profile.split("\t")[1] if profile in profile_already_parsed: continue profile_already_parsed.add(profile) diff --git a/onnxruntime/python/tools/tensorrt/perf/post.py b/onnxruntime/python/tools/tensorrt/perf/post.py index 661acd51bb10d..f127f491b3671 100644 --- a/onnxruntime/python/tools/tensorrt/perf/post.py +++ b/onnxruntime/python/tools/tensorrt/perf/post.py @@ -1,137 +1,171 @@ import argparse -import sys import os -import pandas as pd +import sys import time + +import pandas as pd from azure.kusto.data import KustoConnectionStringBuilder from azure.kusto.data.data_format import DataFormat -from azure.kusto.data.helpers import dataframe_from_result_table -from azure.kusto.ingest import ( - IngestionProperties, - ReportLevel, - QueuedIngestClient, -) +from azure.kusto.data.helpers import dataframe_from_result_table +from azure.kusto.ingest import IngestionProperties, QueuedIngestClient, ReportLevel from perf_utils import * -# database connection strings +# database connection strings cluster_ingest = "https://ingest-onnxruntimedashboarddb.southcentralus.kusto.windows.net" database = "ep_perf_dashboard" + def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument( - "-r", "--report_folder", help="Path to the local file report", required=True) - parser.add_argument( - "-c", "--commit_hash", help="Commit id", required=True) - parser.add_argument( - "-u", "--report_url", help="Report Url", required=True) - parser.add_argument( - "-t", "--trt_version", help="Tensorrt Version", required=True) - parser.add_argument( - "-b", "--branch", help="Branch", required=True) - parser.add_argument( - "-d", "--datetime", help="Commit Datetime", required=True) + parser.add_argument("-r", "--report_folder", help="Path to the local file report", required=True) + parser.add_argument("-c", "--commit_hash", help="Commit id", required=True) + parser.add_argument("-u", "--report_url", help="Report Url", required=True) + parser.add_argument("-t", "--trt_version", help="Tensorrt Version", required=True) + parser.add_argument("-b", "--branch", help="Branch", required=True) + parser.add_argument("-d", "--datetime", help="Commit Datetime", required=True) return parser.parse_args() -def adjust_columns(table, columns, db_columns, model_group): + +def adjust_columns(table, columns, db_columns, model_group): table = table[columns] table = table.set_axis(db_columns, axis=1) table = table.assign(Group=model_group) - return table + return table + def get_latency_over_time(commit_hash, report_url, branch, latency_table): if not latency_table.empty: over_time = latency_table - over_time = over_time.melt(id_vars=[model_title, group_title], var_name='Ep', value_name='Latency') + over_time = over_time.melt(id_vars=[model_title, group_title], var_name="Ep", value_name="Latency") over_time = over_time.assign(CommitId=commit_hash) over_time = over_time.assign(ReportUrl=report_url) over_time = over_time.assign(Branch=branch) - over_time = over_time[['CommitId', model_title, 'Ep', 'Latency', 'ReportUrl', group_title, 'Branch']] - over_time.fillna('', inplace=True) + over_time = over_time[ + [ + "CommitId", + model_title, + "Ep", + "Latency", + "ReportUrl", + group_title, + "Branch", + ] + ] + over_time.fillna("", inplace=True) return over_time - + + def get_failures(fail, model_group): fail_columns = fail.keys() - fail_db_columns = [model_title, 'Ep', 'ErrorType', 'ErrorMessage'] + fail_db_columns = [model_title, "Ep", "ErrorType", "ErrorMessage"] fail = adjust_columns(fail, fail_columns, fail_db_columns, model_group) return fail -def get_memory(memory, model_group): + +def get_memory(memory, model_group): memory_columns = [model_title] - for provider in provider_list: + for provider in provider_list: if cpu not in provider: memory_columns.append(provider + memory_ending) - memory_db_columns = [model_title, cuda, trt, standalone_trt, cuda_fp16, trt_fp16, standalone_trt_fp16] + memory_db_columns = [ + model_title, + cuda, + trt, + standalone_trt, + cuda_fp16, + trt_fp16, + standalone_trt_fp16, + ] memory = adjust_columns(memory, memory_columns, memory_db_columns, model_group) return memory + def get_latency(latency, model_group): latency_columns = [model_title] - for provider in provider_list: + for provider in provider_list: latency_columns.append(provider + avg_ending) latency_db_columns = table_headers latency = adjust_columns(latency, latency_columns, latency_db_columns, model_group) return latency - + + def get_status(status, model_group): status_columns = status.keys() status_db_columns = table_headers status = adjust_columns(status, status_columns, status_db_columns, model_group) return status + def get_specs(specs, branch, commit_id, date_time): - init_id = int(specs.tail(1).get('.', 0)) + 1 - specs_additional = pd.DataFrame({'.': [init_id, init_id + 1, init_id + 2], - 'Spec': ['Branch', 'CommitId', 'CommitTime'], - 'Version': [branch, commit_id, date_time]}) + init_id = int(specs.tail(1).get(".", 0)) + 1 + specs_additional = pd.DataFrame( + { + ".": [init_id, init_id + 1, init_id + 2], + "Spec": ["Branch", "CommitId", "CommitTime"], + "Version": [branch, commit_id, date_time], + } + ) return pd.concat([specs, specs_additional], ignore_index=True) + def get_session(session, model_group): session_columns = session.keys() session_db_columns = [model_title] + ort_provider_list + [p + second for p in ort_provider_list] session = adjust_columns(session, session_columns, session_db_columns, model_group) return session + def write_table(ingest_client, table, table_name, commit_time, identifier): if table.empty: return - table = table.assign(UploadTime=commit_time) # add Commit DateTime - table = table.assign(Identifier=identifier) # add Identifier + table = table.assign(UploadTime=commit_time) # add Commit DateTime + table = table.assign(Identifier=identifier) # add Identifier ingestion_props = IngestionProperties( - database=database, - table=table_name, - data_format=DataFormat.CSV, - report_level=ReportLevel.FailuresAndSuccesses + database=database, + table=table_name, + data_format=DataFormat.CSV, + report_level=ReportLevel.FailuresAndSuccesses, ) # append rows ingest_client.ingest_from_dataframe(table, ingestion_properties=ingestion_props) -def get_time(): + +def get_time(): date_time = time.strftime(time_string_format) return date_time + def get_identifier(date_time, commit_id, trt_version, branch): - date = date_time.split('T')[0] # extract date only - return date + '_' + commit_id + '_' + trt_version + '_' + branch + date = date_time.split("T")[0] # extract date only + return date + "_" + commit_id + "_" + trt_version + "_" + branch + def main(): - + args = parse_arguments() - + # connect to database kcsb_ingest = KustoConnectionStringBuilder.with_az_cli_authentication(cluster_ingest) ingest_client = QueuedIngestClient(kcsb_ingest) date_time = args.datetime identifier = get_identifier(date_time, args.commit_hash, args.trt_version, args.branch) - + try: result_file = args.report_folder folders = os.listdir(result_file) os.chdir(result_file) - tables = [fail_name, memory_name, latency_name, status_name, latency_over_time_name, specs_name, session_name] + tables = [ + fail_name, + memory_name, + latency_name, + status_name, + latency_over_time_name, + specs_name, + session_name, + ] table_results = {} for table_name in tables: table_results[table_name] = pd.DataFrame() @@ -142,26 +176,54 @@ def main(): for csv in csv_filenames: table = pd.read_csv(csv) if session_name in csv: - table_results[session_name] = table_results[session_name].append(get_session(table, model_group), ignore_index=True) + table_results[session_name] = table_results[session_name].append( + get_session(table, model_group), ignore_index=True + ) elif specs_name in csv: - table_results[specs_name] = table_results[specs_name].append(get_specs(table, args.branch, args.commit_hash, date_time), ignore_index=True) + table_results[specs_name] = table_results[specs_name].append( + get_specs(table, args.branch, args.commit_hash, date_time), + ignore_index=True, + ) elif fail_name in csv: - table_results[fail_name] = table_results[fail_name].append(get_failures(table, model_group), ignore_index=True) + table_results[fail_name] = table_results[fail_name].append( + get_failures(table, model_group), ignore_index=True + ) elif latency_name in csv: - table_results[memory_name] = table_results[memory_name].append(get_memory(table, model_group), ignore_index=True) - table_results[latency_name] = table_results[latency_name].append(get_latency(table, model_group), ignore_index=True) - table_results[latency_over_time_name] = table_results[latency_over_time_name].append(get_latency_over_time(args.commit_hash, args.report_url, args.branch, table_results[latency_name]), ignore_index=True) + table_results[memory_name] = table_results[memory_name].append( + get_memory(table, model_group), ignore_index=True + ) + table_results[latency_name] = table_results[latency_name].append( + get_latency(table, model_group), ignore_index=True + ) + table_results[latency_over_time_name] = table_results[latency_over_time_name].append( + get_latency_over_time( + args.commit_hash, + args.report_url, + args.branch, + table_results[latency_name], + ), + ignore_index=True, + ) elif status_name in csv: - table_results[status_name] = table_results[status_name].append(get_status(table, model_group), ignore_index=True) + table_results[status_name] = table_results[status_name].append( + get_status(table, model_group), ignore_index=True + ) os.chdir(result_file) - for table in tables: - print('writing ' + table + ' to database') - db_table_name = 'ep_model_' + table - write_table(ingest_client, table_results[table], db_table_name, date_time, identifier) - - except BaseException as e: + for table in tables: + print("writing " + table + " to database") + db_table_name = "ep_model_" + table + write_table( + ingest_client, + table_results[table], + db_table_name, + date_time, + identifier, + ) + + except BaseException as e: print(str(e)) sys.exit(1) + if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/tensorrt/perf/setup_scripts/setup_onnx_zoo.py b/onnxruntime/python/tools/tensorrt/perf/setup_scripts/setup_onnx_zoo.py index e99fbd028ea04..b54f315c77a6d 100644 --- a/onnxruntime/python/tools/tensorrt/perf/setup_scripts/setup_onnx_zoo.py +++ b/onnxruntime/python/tools/tensorrt/perf/setup_scripts/setup_onnx_zoo.py @@ -1,17 +1,21 @@ +import json import os -import wget import tarfile -import json + +import wget + def get_tar_file(link): file_name = link.split("/")[-1] return file_name + def create_model_folder(model): os.mkdir(model) + def extract_and_get_files(file_name): - model_folder = file_name.replace(".tar.gz", "") + '/' + model_folder = file_name.replace(".tar.gz", "") + "/" create_model_folder(model_folder) model_tar = tarfile.open(file_name) model_tar.extractall(model_folder) @@ -20,21 +24,25 @@ def extract_and_get_files(file_name): model_tar.close() return model_folder, file_list + def download_model(link): file_name = get_tar_file(link) wget.download(link) model_folder, file_list = extract_and_get_files(file_name) return model_folder, file_list + def get_model_path(file_list): for file_name in file_list: if ".onnx" in file_name: return file_name -def get_test_path(model_path): - model_filename = os.path.basename(model_path) + +def get_test_path(model_path): + model_filename = os.path.basename(model_path) test_path = model_path.split(model_filename)[0] - return test_path + return test_path + def create_model_object(model, folder, model_file_path, test_path): model_dict = {} @@ -44,6 +52,7 @@ def create_model_object(model, folder, model_file_path, test_path): model_dict["test_data_path"] = "./" + test_path return model_dict + def get_model_info(link): model_folder, file_list = download_model(link) model = model_folder[:-1] @@ -52,20 +61,23 @@ def get_model_info(link): model_info = create_model_object(model, model_folder, model_file_path, test_path) return model_info -def write_json(models): - model_json = json.dumps(models, indent=4) - with open('model_list.json', 'w') as fp: + +def write_json(models): + model_json = json.dumps(models, indent=4) + with open("model_list.json", "w") as fp: fp.write(model_json) + def main(): links = [] - with open('links.txt', 'r') as fh: + with open("links.txt", "r") as fh: links = [link.rstrip() for link in fh.readlines()] - + model_list = [] for link in links: model_list.append(get_model_info(link)) write_json(model_list) + if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/transformers/__init__.py b/onnxruntime/python/tools/transformers/__init__.py index 8302e1b7c44d4..89ed09bcca66d 100644 --- a/onnxruntime/python/tools/transformers/__init__.py +++ b/onnxruntime/python/tools/transformers/__init__.py @@ -1,13 +1,14 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -import sys import os +import sys -sys.path.append(os.path.join(os.path.dirname(__file__), 'models', 'gpt2')) +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) + +import convert_to_onnx # added for backward compatible import gpt2_helper -import convert_to_onnx diff --git a/onnxruntime/python/tools/transformers/affinity_helper.py b/onnxruntime/python/tools/transformers/affinity_helper.py index 8fb3e3b5713d9..28b9c4cd6fb75 100644 --- a/onnxruntime/python/tools/transformers/affinity_helper.py +++ b/onnxruntime/python/tools/transformers/affinity_helper.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Get/Set cpu affinity. Currently only support part of Unix system import logging @@ -10,11 +10,11 @@ logger = logging.getLogger(__name__) -class AffinitySetting(): +class AffinitySetting: def __init__(self): self.pid = os.getpid() self.affinity = None - self.is_os_supported = hasattr(os, 'sched_getaffinity') and hasattr(os, 'sched_setaffinity') + self.is_os_supported = hasattr(os, "sched_getaffinity") and hasattr(os, "sched_setaffinity") if not self.is_os_supported: logger.warning("Current OS does not support os.get_affinity() and os.set_affinity()") @@ -25,12 +25,16 @@ def get_affinity(self): def set_affinity(self): if self.is_os_supported: current_affinity = os.sched_getaffinity(self.pid) - if (self.affinity != current_affinity): - logger.warning("Replacing affinity setting %s with %s", str(current_affinity), str(self.affinity)) + if self.affinity != current_affinity: + logger.warning( + "Replacing affinity setting %s with %s", + str(current_affinity), + str(self.affinity), + ) os.sched_setaffinity(self.pid, self.affinity) -if __name__ == '__main__': +if __name__ == "__main__": affi_helper = AffinitySetting() affi_helper.get_affinity() affi_helper.set_affinity() diff --git a/onnxruntime/python/tools/transformers/benchmark.py b/onnxruntime/python/tools/transformers/benchmark.py index ba8694da4d51e..c7c729a4f54ac 100644 --- a/onnxruntime/python/tools/transformers/benchmark.py +++ b/onnxruntime/python/tools/transformers/benchmark.py @@ -42,24 +42,40 @@ import argparse import logging +import os import timeit from datetime import datetime -import numpy +from enum import Enum -import os -import psutil +import numpy import onnx -from enum import Enum -from benchmark_helper import (OptimizerInfo, create_onnxruntime_session, Precision, setup_logger, get_latency_result, - output_details, output_summary, output_fusion_statistics, inference_ort, - inference_ort_with_io_binding, allocateOutputBuffers, ConfigModifier) +import psutil +from benchmark_helper import ( + ConfigModifier, + OptimizerInfo, + Precision, + allocateOutputBuffers, + create_onnxruntime_session, + get_latency_result, + inference_ort, + inference_ort_with_io_binding, + output_details, + output_fusion_statistics, + output_summary, + setup_logger, +) from fusion_options import FusionOptions +from onnx_exporter import ( + create_onnxruntime_input, + export_onnx_model_from_pt, + export_onnx_model_from_tf, + load_pretrained_model, +) from quantize_helper import QuantizeHelper -from onnx_exporter import create_onnxruntime_input, load_pretrained_model, export_onnx_model_from_pt, export_onnx_model_from_tf -logger = logging.getLogger('') +logger = logging.getLogger("") -from huggingface_models import MODELS, MODEL_CLASSES +from huggingface_models import MODEL_CLASSES, MODELS cpu_count = psutil.cpu_count(logical=False) @@ -68,35 +84,60 @@ os.environ["OMP_NUM_THREADS"] = str(cpu_count) import torch -from transformers import (AutoConfig, AutoTokenizer, AutoModel, GPT2Model, LxmertConfig) - - -def run_onnxruntime(use_gpu, provider, model_names, model_class, config_modifier, precision, num_threads, batch_sizes, - sequence_lengths, repeat_times, input_counts, optimizer_info, validate_onnx, cache_dir, onnx_dir, - verbose, overwrite, disable_ort_io_binding, use_raw_attention_mask, model_fusion_statistics, - model_source, args): +from transformers import AutoConfig, AutoModel, AutoTokenizer, GPT2Model, LxmertConfig + + +def run_onnxruntime( + use_gpu, + provider, + model_names, + model_class, + config_modifier, + precision, + num_threads, + batch_sizes, + sequence_lengths, + repeat_times, + input_counts, + optimizer_info, + validate_onnx, + cache_dir, + onnx_dir, + verbose, + overwrite, + disable_ort_io_binding, + use_raw_attention_mask, + model_fusion_statistics, + model_source, + args, +): import onnxruntime results = [] - if (use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()) - and ('ROCMExecutionProvider' not in onnxruntime.get_available_providers())): + if ( + use_gpu + and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers()) + and ("ROCMExecutionProvider" not in onnxruntime.get_available_providers()) + ): logger.error( "Please install onnxruntime-gpu package instead of onnxruntime, and use a machine with GPU for testing gpu performance." ) return results warm_up_repeat = 0 - if provider == 'tensorrt': + if provider == "tensorrt": optimizer_info = OptimizerInfo.NOOPT warm_up_repeat = 5 - if 'TensorrtExecutionProvider' not in onnxruntime.get_available_providers(): + if "TensorrtExecutionProvider" not in onnxruntime.get_available_providers(): logger.error( "Please install onnxruntime-gpu-tensorrt package, and use a machine with GPU for testing gpu performance." ) return results if optimizer_info == OptimizerInfo.NOOPT: - logger.warning(f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied.") + logger.warning( + f"OptimizerInfo is set to {optimizer_info}, graph optimizations specified in FusionOptions are not applied." + ) for model_name in model_names: all_input_names = MODELS[model_name][0] @@ -108,27 +149,64 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, config_modifier args.model_type = MODELS[model_name][3] fusion_options = FusionOptions.parse(args) - if 'pt' in model_source: + if "pt" in model_source: with torch.no_grad(): - onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length = export_onnx_model_from_pt( - model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], model_class, - config_modifier, cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info, - validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics, fusion_options) - if 'tf' in model_source: - onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length = export_onnx_model_from_tf( - model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], model_class, - config_modifier, cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info, - validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics, fusion_options) + ( + onnx_model_file, + is_valid_onnx_model, + vocab_size, + max_sequence_length, + ) = export_onnx_model_from_pt( + model_name, + MODELS[model_name][1], + MODELS[model_name][2], + MODELS[model_name][3], + model_class, + config_modifier, + cache_dir, + onnx_dir, + input_names, + use_gpu, + precision, + optimizer_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + model_fusion_statistics, + fusion_options, + ) + if "tf" in model_source: + (onnx_model_file, is_valid_onnx_model, vocab_size, max_sequence_length,) = export_onnx_model_from_tf( + model_name, + MODELS[model_name][1], + MODELS[model_name][2], + MODELS[model_name][3], + model_class, + config_modifier, + cache_dir, + onnx_dir, + input_names, + use_gpu, + precision, + optimizer_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + model_fusion_statistics, + fusion_options, + ) if not is_valid_onnx_model: continue - ort_session = create_onnxruntime_session(onnx_model_file, - use_gpu, - provider, - enable_all_optimization=True, - num_threads=num_threads, - verbose=verbose) + ort_session = create_onnxruntime_session( + onnx_model_file, + use_gpu, + provider, + enable_all_optimization=True, + num_threads=num_threads, + verbose=verbose, + ) if ort_session is None: continue @@ -137,8 +215,12 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, config_modifier device = "cuda" if use_gpu else "cpu" config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) max_last_state_size = numpy.prod( - [max(batch_sizes), max(sequence_lengths), - max(vocab_size, config.hidden_size)]) + [ + max(batch_sizes), + max(sequence_lengths), + max(vocab_size, config.hidden_size), + ] + ) max_pooler_size = numpy.prod([max(batch_sizes), config.hidden_size]) for batch_size in batch_sizes: if batch_size <= 0: @@ -147,9 +229,15 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, config_modifier if max_sequence_length is not None and sequence_length > max_sequence_length: continue - input_value_type = numpy.int64 if 'pt' in model_source else numpy.int32 - ort_inputs = create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, - input_value_type) + input_value_type = numpy.int64 if "pt" in model_source else numpy.int32 + ort_inputs = create_onnxruntime_input( + vocab_size, + batch_size, + sequence_length, + input_names, + config, + input_value_type, + ) result_template = { "engine": "onnxruntime", "version": onnxruntime.__version__, @@ -167,12 +255,19 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, config_modifier "datetime": str(datetime.now()), } - logger.info("Run onnxruntime on {} with input shape {}".format(model_name, - [batch_size, sequence_length])) + logger.info( + "Run onnxruntime on {} with input shape {}".format(model_name, [batch_size, sequence_length]) + ) if disable_ort_io_binding: - result = inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_size, - warm_up_repeat) + result = inference_ort( + ort_session, + ort_inputs, + result_template, + repeat_times, + batch_size, + warm_up_repeat, + ) else: # Get output sizes from a dummy ort run ort_outputs = ort_session.run(ort_output_names, ort_inputs) @@ -184,19 +279,41 @@ def run_onnxruntime(use_gpu, provider, model_names, model_class, config_modifier else: output_buffer_max_sizes.append(max_last_state_size) - data_type = numpy.longlong if 'pt' in model_source else numpy.intc - result = inference_ort_with_io_binding(ort_session, ort_inputs, result_template, repeat_times, - ort_output_names, ort_outputs, output_buffers, - output_buffer_max_sizes, batch_size, device, data_type, - warm_up_repeat) + data_type = numpy.longlong if "pt" in model_source else numpy.intc + result = inference_ort_with_io_binding( + ort_session, + ort_inputs, + result_template, + repeat_times, + ort_output_names, + ort_outputs, + output_buffers, + output_buffer_max_sizes, + batch_size, + device, + data_type, + warm_up_repeat, + ) logger.info(result) results.append(result) return results -def run_pytorch(use_gpu, model_names, model_class, config_modifier, precision, num_threads, batch_sizes, - sequence_lengths, repeat_times, torchscript, cache_dir, verbose): +def run_pytorch( + use_gpu, + model_names, + model_class, + config_modifier, + precision, + num_threads, + batch_sizes, + sequence_lengths, + repeat_times, + torchscript, + cache_dir, + verbose, +): results = [] if use_gpu and not torch.cuda.is_available(): logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.") @@ -207,11 +324,17 @@ def run_pytorch(use_gpu, model_names, model_class, config_modifier, precision, n for model_name in model_names: config = AutoConfig.from_pretrained(model_name, torchscript=torchscript, cache_dir=cache_dir) config_modifier.modify(config) - model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class) + model = load_pretrained_model( + model_name, + config=config, + cache_dir=cache_dir, + custom_model_class=model_class, + ) tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes[ - model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + max_input_size = ( + tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + ) logger.debug(f"Model {model}") logger.debug(f"Number of parameters {model.num_parameters()}") @@ -234,11 +357,13 @@ def run_pytorch(use_gpu, model_names, model_class, config_modifier, precision, n continue logger.info("Run PyTorch on {} with input shape {}".format(model_name, [batch_size, sequence_length])) - input_ids = torch.randint(low=0, - high=config.vocab_size - 1, - size=(batch_size, sequence_length), - dtype=torch.long, - device=device) + input_ids = torch.randint( + low=0, + high=config.vocab_size - 1, + size=(batch_size, sequence_length), + dtype=torch.long, + device=device, + ) try: inference = torch.jit.trace(model, input_ids) if torchscript else model inference(input_ids) @@ -272,9 +397,10 @@ def run_pytorch(use_gpu, model_names, model_class, config_modifier, precision, n def run_with_tf_optimizations(do_eager_mode: bool, use_xla: bool): - import tensorflow as tf from functools import wraps + import tensorflow as tf + def run_func(func): @wraps(func) def run_in_eager_mode(*args, **kwargs): @@ -296,26 +422,38 @@ def run_in_graph_mode(*args, **kwargs): return run_func -def run_tensorflow(use_gpu, model_names, model_class, config_modifier, precision, num_threads, batch_sizes, - sequence_lengths, repeat_times, cache_dir, verbose): +def run_tensorflow( + use_gpu, + model_names, + model_class, + config_modifier, + precision, + num_threads, + batch_sizes, + sequence_lengths, + repeat_times, + cache_dir, + verbose, +): results = [] import tensorflow as tf + tf.config.threading.set_intra_op_parallelism_threads(num_threads) if not use_gpu: - tf.config.set_visible_devices([], 'GPU') + tf.config.set_visible_devices([], "GPU") if use_gpu and not tf.test.is_built_with_cuda(): logger.error("Please install Tensorflow-gpu, and use a machine with GPU for testing gpu performance.") return results if use_gpu: # Restrict TensorFlow to only use the first GPU - physical_devices = tf.config.list_physical_devices('GPU') + physical_devices = tf.config.list_physical_devices("GPU") try: - tf.config.set_visible_devices(physical_devices[0], 'GPU') + tf.config.set_visible_devices(physical_devices[0], "GPU") tf.config.experimental.set_memory_growth(physical_devices[0], True) - tf.distribute.OneDeviceStrategy(device='/gpu:0') + tf.distribute.OneDeviceStrategy(device="/gpu:0") except RuntimeError as e: logger.exception(e) @@ -326,16 +464,19 @@ def run_tensorflow(use_gpu, model_names, model_class, config_modifier, precision config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) config_modifier.modify(config) - model = load_pretrained_model(model_name, - config=config, - cache_dir=cache_dir, - custom_model_class=model_class, - is_tf_model=True) + model = load_pretrained_model( + model_name, + config=config, + cache_dir=cache_dir, + custom_model_class=model_class, + is_tf_model=True, + ) tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes[ - model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + max_input_size = ( + tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + ) for batch_size in batch_sizes: if batch_size <= 0: @@ -345,10 +486,12 @@ def run_tensorflow(use_gpu, model_names, model_class, config_modifier, precision if max_input_size is not None and sequence_length > max_input_size: continue - logger.info("Run Tensorflow on {} with input shape {}".format(model_name, - [batch_size, sequence_length])) + logger.info( + "Run Tensorflow on {} with input shape {}".format(model_name, [batch_size, sequence_length]) + ) import random + rng = random.Random() values = [rng.randint(0, config.vocab_size - 1) for i in range(batch_size * sequence_length)] input_ids = tf.constant(values, shape=(batch_size, sequence_length), dtype=tf.int32) @@ -367,7 +510,12 @@ def encoder_decoder_forward(): def lxmert_forward(): feats = tf.random.normal([1, 1, config.visual_feat_dim]) pos = tf.random.normal([1, 1, config.visual_pos_dim]) - return model(input_ids, visual_feats=feats, visual_pos=pos, training=False) + return model( + input_ids, + visual_feats=feats, + visual_pos=pos, + training=False, + ) inference = encoder_forward if config.is_encoder_decoder: @@ -401,6 +549,7 @@ def lxmert_forward(): except RuntimeError as e: logger.exception(e) from numba import cuda + device = cuda.get_current_device() device.reset() @@ -410,55 +559,73 @@ def lxmert_forward(): def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument("-m", - "--models", - required=False, - nargs="+", - type=str, - default=["bert-base-cased", "roberta-base", "gpt2"], - choices=list(MODELS.keys()), - help="Pre-trained models in the list: " + ", ".join(MODELS.keys())) - - parser.add_argument("--model_source", - required=False, - nargs=1, - type=str, - default='pt', - choices=['pt', 'tf'], - help="Export onnx from pt or tf") - - parser.add_argument('--model_class', - required=False, - type=str, - default=None, - choices=list(MODEL_CLASSES), - help='Model type selected in the list: ' + ', '.join(MODEL_CLASSES)) - - parser.add_argument("-e", - "--engines", - required=False, - nargs="+", - type=str, - default=['onnxruntime'], - choices=['onnxruntime', 'torch', 'torchscript', 'tensorflow'], - help="Engines to benchmark") - - parser.add_argument("-c", - "--cache_dir", - required=False, - type=str, - default=os.path.join('.', 'cache_models'), - help="Directory to cache pre-trained models") - - parser.add_argument("--onnx_dir", - required=False, - type=str, - default=os.path.join('.', 'onnx_models'), - help="Directory to store onnx models") + parser.add_argument( + "-m", + "--models", + required=False, + nargs="+", + type=str, + default=["bert-base-cased", "roberta-base", "gpt2"], + choices=list(MODELS.keys()), + help="Pre-trained models in the list: " + ", ".join(MODELS.keys()), + ) + + parser.add_argument( + "--model_source", + required=False, + nargs=1, + type=str, + default="pt", + choices=["pt", "tf"], + help="Export onnx from pt or tf", + ) + + parser.add_argument( + "--model_class", + required=False, + type=str, + default=None, + choices=list(MODEL_CLASSES), + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES), + ) + + parser.add_argument( + "-e", + "--engines", + required=False, + nargs="+", + type=str, + default=["onnxruntime"], + choices=["onnxruntime", "torch", "torchscript", "tensorflow"], + help="Engines to benchmark", + ) + + parser.add_argument( + "-c", + "--cache_dir", + required=False, + type=str, + default=os.path.join(".", "cache_models"), + help="Directory to cache pre-trained models", + ) + + parser.add_argument( + "--onnx_dir", + required=False, + type=str, + default=os.path.join(".", "onnx_models"), + help="Directory to store onnx models", + ) parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="Run on gpu device") - parser.add_argument("--provider", required=False, type=str, default=None, help="Execution provider to use") + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use", + ) parser.add_argument( "-p", @@ -466,11 +633,17 @@ def parse_arguments(): type=Precision, default=Precision.FLOAT32, choices=list(Precision), - help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization") + help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization", + ) parser.add_argument("--verbose", required=False, action="store_true", help="Print more information") - parser.add_argument("--overwrite", required=False, action="store_true", help="Overwrite existing models") + parser.add_argument( + "--overwrite", + required=False, + action="store_true", + help="Overwrite existing models", + ) parser.add_argument( "-o", @@ -478,54 +651,96 @@ def parse_arguments(): type=OptimizerInfo, default=OptimizerInfo.BYSCRIPT, choices=list(OptimizerInfo), - help="Optimizer info: Use optimizer.py to optimize onnx model as default. Can also choose from by_ort and no_opt" + help="Optimizer info: Use optimizer.py to optimize onnx model as default. Can also choose from by_ort and no_opt", ) - parser.add_argument("-v", "--validate_onnx", required=False, action="store_true", help="Validate ONNX model") + parser.add_argument( + "-v", + "--validate_onnx", + required=False, + action="store_true", + help="Validate ONNX model", + ) - parser.add_argument("-f", - "--fusion_csv", - required=False, - default=None, - help="CSV file for saving summary results of graph optimization.") + parser.add_argument( + "-f", + "--fusion_csv", + required=False, + default=None, + help="CSV file for saving summary results of graph optimization.", + ) - parser.add_argument("-d", "--detail_csv", required=False, default=None, help="CSV file for saving detail results.") + parser.add_argument( + "-d", + "--detail_csv", + required=False, + default=None, + help="CSV file for saving detail results.", + ) - parser.add_argument("-r", "--result_csv", required=False, default=None, help="CSV file for saving summary results.") + parser.add_argument( + "-r", + "--result_csv", + required=False, + default=None, + help="CSV file for saving summary results.", + ) - parser.add_argument("-i", - "--input_counts", - required=False, - nargs="+", - default=[1], - type=int, - choices=[1, 2, 3], - help="Number of ONNX model inputs. Please use 1 for fair comparison with Torch or TorchScript.") + parser.add_argument( + "-i", + "--input_counts", + required=False, + nargs="+", + default=[1], + type=int, + choices=[1, 2, 3], + help="Number of ONNX model inputs. Please use 1 for fair comparison with Torch or TorchScript.", + ) - parser.add_argument("-t", - "--test_times", - required=False, - default=100, - type=int, - help="Number of repeat times to get average inference latency.") + parser.add_argument( + "-t", + "--test_times", + required=False, + default=100, + type=int, + help="Number of repeat times to get average inference latency.", + ) parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1]) - parser.add_argument("-s", "--sequence_lengths", nargs="+", type=int, default=[4, 8, 16, 32, 64, 128, 256]) + parser.add_argument( + "-s", + "--sequence_lengths", + nargs="+", + type=int, + default=[4, 8, 16, 32, 64, 128, 256], + ) - parser.add_argument('--disable_ort_io_binding', - required=False, - action='store_true', - help='Disable running ONNX Runtime with binded inputs and outputs. ') + parser.add_argument( + "--disable_ort_io_binding", + required=False, + action="store_true", + help="Disable running ONNX Runtime with binded inputs and outputs. ", + ) parser.set_defaults(disable_ort_io_binding=False) - parser.add_argument("-n", "--num_threads", required=False, nargs="+", type=int, default=[0], help="Threads to use") + parser.add_argument( + "-n", + "--num_threads", + required=False, + nargs="+", + type=int, + default=[0], + help="Threads to use", + ) - parser.add_argument("--force_num_layers", - required=False, - type=int, - default=None, - help="Manually set the model's layer number") + parser.add_argument( + "--force_num_layers", + required=False, + type=int, + default=None, + help="Manually set the model's layer number", + ) FusionOptions.add_arguments(parser) @@ -573,30 +788,80 @@ def main(): logger.warning("--input_counts is not implemented for torch or torchscript engine.") if enable_torchscript: - results += run_pytorch(args.use_gpu, args.models, args.model_class, config_modifier, args.precision, - num_threads, args.batch_sizes, args.sequence_lengths, args.test_times, True, - args.cache_dir, args.verbose) + results += run_pytorch( + args.use_gpu, + args.models, + args.model_class, + config_modifier, + args.precision, + num_threads, + args.batch_sizes, + args.sequence_lengths, + args.test_times, + True, + args.cache_dir, + args.verbose, + ) if enable_torch: - results += run_pytorch(args.use_gpu, args.models, args.model_class, config_modifier, args.precision, - num_threads, args.batch_sizes, args.sequence_lengths, args.test_times, False, - args.cache_dir, args.verbose) + results += run_pytorch( + args.use_gpu, + args.models, + args.model_class, + config_modifier, + args.precision, + num_threads, + args.batch_sizes, + args.sequence_lengths, + args.test_times, + False, + args.cache_dir, + args.verbose, + ) if enable_tensorflow: - results += run_tensorflow(args.use_gpu, args.models, args.model_class, config_modifier, args.precision, - num_threads, args.batch_sizes, args.sequence_lengths, args.test_times, - args.cache_dir, args.verbose) + results += run_tensorflow( + args.use_gpu, + args.models, + args.model_class, + config_modifier, + args.precision, + num_threads, + args.batch_sizes, + args.sequence_lengths, + args.test_times, + args.cache_dir, + args.verbose, + ) model_fusion_statistics = {} if enable_onnxruntime: try: use_raw_attention_mask = True - results += run_onnxruntime(args.use_gpu, args.provider, args.models, args.model_class, config_modifier, - args.precision, num_threads, args.batch_sizes, args.sequence_lengths, - args.test_times, args.input_counts, args.optimizer_info, args.validate_onnx, - args.cache_dir, args.onnx_dir, args.verbose, args.overwrite, - args.disable_ort_io_binding, use_raw_attention_mask, model_fusion_statistics, - args.model_source, args) + results += run_onnxruntime( + args.use_gpu, + args.provider, + args.models, + args.model_class, + config_modifier, + args.precision, + num_threads, + args.batch_sizes, + args.sequence_lengths, + args.test_times, + args.input_counts, + args.optimizer_info, + args.validate_onnx, + args.cache_dir, + args.onnx_dir, + args.verbose, + args.overwrite, + args.disable_ort_io_binding, + use_raw_attention_mask, + model_fusion_statistics, + args.model_source, + args, + ) except: logger.error(f"Exception", exc_info=True) diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 8fe6287ba52a1..f88f7c906c0d6 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -4,28 +4,29 @@ # license information. # -------------------------------------------------------------------------- +import argparse +import csv +import logging import os import sys -import csv -import numpy import time import timeit from datetime import datetime -import argparse -import logging +from enum import Enum + import coloredlogs -import torch +import numpy import onnx -from enum import Enum +import torch from packaging import version logger = logging.getLogger(__name__) class Precision(Enum): - FLOAT32 = 'fp32' - FLOAT16 = 'fp16' - INT8 = 'int8' + FLOAT32 = "fp32" + FLOAT16 = "fp16" + INT8 = "int8" def __str__(self): return self.value @@ -34,28 +35,28 @@ def __str__(self): class OptimizerInfo(Enum): # no_opt means using the raw ONNX model, but OnnxRuntime might still apply optimization as long as # graph optimization level is not 0 (disable all). - NOOPT = 'no_opt' - BYORT = 'by_ort' - BYSCRIPT = 'by_script' + NOOPT = "no_opt" + BYORT = "by_ort" + BYSCRIPT = "by_script" def __str__(self): return self.value -class ConfigModifier(): +class ConfigModifier: def __init__(self, num_layers): self.num_layers = num_layers def modify(self, config): if self.num_layers is None: return - if hasattr(config, 'num_hidden_layers'): + if hasattr(config, "num_hidden_layers"): config.num_hidden_layers = self.num_layers logger.info(f"Modifying pytorch model's number of hidden layers to: {self.num_layers}") - if hasattr(config, 'encoder_layers'): + if hasattr(config, "encoder_layers"): config.encoder_layers = self.num_layers logger.info(f"Modifying pytorch model's number of encoder layers to: {self.num_layers}") - if hasattr(config, 'decoder_layers '): + if hasattr(config, "decoder_layers "): config.decoder_layers = self.num_layers logger.info(f"Modifying pytorch model's number of decoder layers to: {self.num_layers}") @@ -69,16 +70,20 @@ def get_layer_num(self): } -def create_onnxruntime_session(onnx_model_path, - use_gpu, - provider=None, - enable_all_optimization=True, - num_threads=-1, - enable_profiling=False, - verbose=False): +def create_onnxruntime_session( + onnx_model_path, + use_gpu, + provider=None, + enable_all_optimization=True, + num_threads=-1, + enable_profiling=False, + verbose=False, +): session = None try: - from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel, __version__ as onnxruntime_version + from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions + from onnxruntime import __version__ as onnxruntime_version + sess_options = SessionOptions() if enable_all_optimization: @@ -100,20 +105,28 @@ def create_onnxruntime_session(onnx_model_path, logger.debug(f"Create session for onnx model: {onnx_model_path}") if use_gpu: - if provider == 'dml': - execution_providers = ['DmlExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'rocm': - execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'migraphx': - execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'cuda': - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'tensorrt': - execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] + if provider == "dml": + execution_providers = ["DmlExecutionProvider", "CPUExecutionProvider"] + elif provider == "rocm": + execution_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + elif provider == "migraphx": + execution_providers = [ + "MIGraphXExecutionProvider", + "ROCMExecutionProvider", + "CPUExecutionProvider", + ] + elif provider == "cuda": + execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + elif provider == "tensorrt": + execution_providers = [ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ] else: - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: - execution_providers = ['CPUExecutionProvider'] + execution_providers = ["CPUExecutionProvider"] session = InferenceSession(onnx_model_path, sess_options, providers=execution_providers) except: logger.error(f"Exception", exc_info=True) @@ -123,9 +136,12 @@ def create_onnxruntime_session(onnx_model_path, def setup_logger(verbose=True): if verbose: - coloredlogs.install(level='DEBUG', fmt='[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s') + coloredlogs.install( + level="DEBUG", + fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + ) else: - coloredlogs.install(fmt='%(message)s') + coloredlogs.install(fmt="%(message)s") logging.getLogger("transformers").setLevel(logging.WARNING) @@ -137,25 +153,30 @@ def prepare_environment(cache_dir, output_dir, use_gpu, provider=None): os.makedirs(output_dir) import onnxruntime + if use_gpu: - if provider == 'dml': - assert 'DmlExecutionProvider' in onnxruntime.get_available_providers( + if provider == "dml": + assert ( + "DmlExecutionProvider" in onnxruntime.get_available_providers() ), "Please install onnxruntime-directml package to test GPU inference." else: - assert 'CUDAExecutionProvider' in onnxruntime.get_available_providers( + assert ( + "CUDAExecutionProvider" in onnxruntime.get_available_providers() ), "Please install onnxruntime-gpu package to test GPU inference." import transformers - logger.info(f'PyTorch Version:{torch.__version__}') - logger.info(f'Transformers Version:{transformers.__version__}') - logger.info(f'Onnxruntime Version:{onnxruntime.__version__}') + + logger.info(f"PyTorch Version:{torch.__version__}") + logger.info(f"Transformers Version:{transformers.__version__}") + logger.info(f"Onnxruntime Version:{onnxruntime.__version__}") # Support three major versions of PyTorch and OnnxRuntime, and up to 6 months of transformers. from packaging import version - assert version.parse(torch.__version__) >= version.parse('1.5.0') - assert version.parse(transformers.__version__) >= version.parse('3.0.0') - assert version.parse(onnxruntime.__version__) >= version.parse('1.4.0') + + assert version.parse(torch.__version__) >= version.parse("1.5.0") + assert version.parse(transformers.__version__) >= version.parse("3.0.0") + assert version.parse(onnxruntime.__version__) >= version.parse("1.4.0") def get_latency_result(runtimes, batch_size): @@ -175,12 +196,29 @@ def get_latency_result(runtimes, batch_size): def output_details(results, csv_filename): - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = [ - "engine", "version", "providers", "device", "precision", "optimizer", "io_binding", "model_name", "inputs", - "threads", "batch_size", "sequence_length", "custom_layer_num", "datetime", "test_times", "QPS", - "average_latency_ms", "latency_variance", "latency_90_percentile", "latency_95_percentile", - "latency_99_percentile" + "engine", + "version", + "providers", + "device", + "precision", + "optimizer", + "io_binding", + "model_name", + "inputs", + "threads", + "batch_size", + "sequence_length", + "custom_layer_num", + "datetime", + "test_times", + "QPS", + "average_latency_ms", + "latency_variance", + "latency_90_percentile", + "latency_95_percentile", + "latency_99_percentile", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) @@ -192,10 +230,19 @@ def output_details(results, csv_filename): def output_summary(results, csv_filename, args): - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: header_names = [ - "model_name", "inputs", "custom_layer_num", "engine", "version", "providers", "device", "precision", - "optimizer", "io_binding", "threads" + "model_name", + "inputs", + "custom_layer_num", + "engine", + "version", + "providers", + "device", + "precision", + "optimizer", + "io_binding", + "threads", ] data_names = [] for batch_size in args.batch_sizes: @@ -211,9 +258,13 @@ def output_summary(results, csv_filename, args): for threads in args.num_threads: row = {} for result in results: - if result["model_name"] == model_name and result["inputs"] == input_count and result[ - "engine"] == engine_name and result["io_binding"] == io_binding and result[ - "threads"] == threads: + if ( + result["model_name"] == model_name + and result["inputs"] == input_count + and result["engine"] == engine_name + and result["io_binding"] == io_binding + and result["threads"] == threads + ): headers = {k: v for k, v in result.items() if k in header_names} if not row: row.update(headers) @@ -232,9 +283,11 @@ def output_summary(results, csv_filename, args): def output_fusion_statistics(model_fusion_statistics, csv_filename): from transformers import __version__ as transformers_version - with open(csv_filename, mode="a", newline='') as csv_file: + + with open(csv_filename, mode="a", newline="") as csv_file: column_names = ["model_filename", "datetime", "transformers", "torch"] + list( - next(iter(model_fusion_statistics.values())).keys()) + next(iter(model_fusion_statistics.values())).keys() + ) csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) csv_writer.writeheader() for key in model_fusion_statistics.keys(): @@ -256,18 +309,20 @@ def inference_ort(ort_session, ort_inputs, result_template, repeat_times, batch_ return result -def inference_ort_with_io_binding(ort_session, - ort_inputs, - result_template, - repeat_times, - ort_output_names, - ort_outputs, - output_buffers, - output_buffer_max_sizes, - batch_size, - device, - data_type=numpy.longlong, - warm_up_repeat=0): +def inference_ort_with_io_binding( + ort_session, + ort_inputs, + result_template, + repeat_times, + ort_output_names, + ort_outputs, + output_buffers, + output_buffer_max_sizes, + batch_size, + device, + data_type=numpy.longlong, + warm_up_repeat=0, +): result = {} # Bind inputs and outputs to onnxruntime session @@ -275,18 +330,42 @@ def inference_ort_with_io_binding(ort_session, # Bind inputs to device for name in ort_inputs.keys(): np_input = torch.from_numpy(ort_inputs[name]).to(device) - input_type = IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] if str( - ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP else data_type - io_binding.bind_input(name, np_input.device.type, 0, input_type, np_input.shape, np_input.data_ptr()) + input_type = ( + IO_BINDING_DATA_TYPE_MAP[str(ort_inputs[name].dtype)] + if str(ort_inputs[name].dtype) in IO_BINDING_DATA_TYPE_MAP + else data_type + ) + io_binding.bind_input( + name, + np_input.device.type, + 0, + input_type, + np_input.shape, + np_input.data_ptr(), + ) # Bind outputs buffers with the sizes needed if not allocated already if len(output_buffers) == 0: allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device) for i in range(len(ort_output_names)): - io_binding.bind_output(ort_output_names[i], output_buffers[i].device.type, 0, numpy.float32, - ort_outputs[i].shape, output_buffers[i].data_ptr()) - timeit.repeat(lambda: ort_session.run_with_iobinding(io_binding), number=1, repeat=warm_up_repeat) # Dry run - runtimes = timeit.repeat(lambda: ort_session.run_with_iobinding(io_binding), number=1, repeat=repeat_times) + io_binding.bind_output( + ort_output_names[i], + output_buffers[i].device.type, + 0, + numpy.float32, + ort_outputs[i].shape, + output_buffers[i].data_ptr(), + ) + timeit.repeat( + lambda: ort_session.run_with_iobinding(io_binding), + number=1, + repeat=warm_up_repeat, + ) # Dry run + runtimes = timeit.repeat( + lambda: ort_session.run_with_iobinding(io_binding), + number=1, + repeat=repeat_times, + ) result.update(result_template) result.update({"io_binding": True}) result.update(get_latency_result(runtimes, batch_size)) @@ -304,21 +383,23 @@ def allocateOutputBuffers(output_buffers, output_buffer_max_sizes, device): def set_random_seed(seed=123): """Set random seed manully to get deterministic results""" import random + random.seed(seed) numpy.random.seed(seed) torch.manual_seed(seed) torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) - #torch.backends.cudnn.enabled = False - #torch.backends.cudnn.benchmark = False - #torch.backends.cudnn.deterministic = True + # torch.backends.cudnn.enabled = False + # torch.backends.cudnn.benchmark = False + # torch.backends.cudnn.deterministic = True def measure_memory(is_gpu, func): import os - import psutil from time import sleep + import psutil + class MemoryMonitor: def __init__(self, keep_measuring=True): self.keep_measuring = keep_measuring @@ -333,8 +414,16 @@ def measure_cpu_usage(self): return max_usage def measure_gpu_usage(self): - from py3nvml.py3nvml import nvmlInit, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, \ - nvmlDeviceGetMemoryInfo, nvmlDeviceGetName, nvmlShutdown, NVMLError + from py3nvml.py3nvml import ( + NVMLError, + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlDeviceGetName, + nvmlInit, + nvmlShutdown, + ) + max_gpu_usage = [] gpu_name = [] try: @@ -350,11 +439,14 @@ def measure_gpu_usage(self): if not self.keep_measuring: break nvmlShutdown() - return [{ - "device_id": i, - "name": gpu_name[i], - "max_used_MB": max_gpu_usage[i] - } for i in range(deviceCount)] + return [ + { + "device_id": i, + "name": gpu_name[i], + "max_used_MB": max_gpu_usage[i], + } + for i in range(deviceCount) + ] except NVMLError as error: if not self.silent: self.logger.error("Error fetching GPU information using nvml: %s", error) @@ -365,6 +457,7 @@ def measure_gpu_usage(self): memory_before_test = monitor.measure_gpu_usage() if is_gpu else monitor.measure_cpu_usage() from concurrent.futures import ThreadPoolExecutor + with ThreadPoolExecutor() as executor: monitor = MemoryMonitor() mem_thread = executor.submit(monitor.measure_gpu_usage if is_gpu else monitor.measure_cpu_usage) diff --git a/onnxruntime/python/tools/transformers/bert_perf_test.py b/onnxruntime/python/tools/transformers/bert_perf_test.py index 7c13ca3c8d945..0d5e18e8fc58c 100644 --- a/onnxruntime/python/tools/transformers/bert_perf_test.py +++ b/onnxruntime/python/tools/transformers/bert_perf_test.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # This tool measures the inference performance of onnxruntime or onnxruntime-gpu python package on Bert model. @@ -12,22 +12,22 @@ # Example command to run test on batch_size 1 and 2 for a model on GPU: # python bert_perf_test.py --model bert.onnx --batch_size 1 2 --sequence_length 128 --use_gpu --samples 1000 --test_times 1 -import sys import argparse -import os -from pathlib import Path -import timeit -import statistics -import psutil import csv -import numpy as np -import torch +import multiprocessing +import os import random +import statistics +import sys +import timeit +from dataclasses import dataclass from datetime import datetime -import multiprocessing -from bert_test_data import get_bert_inputs, generate_test_data +from pathlib import Path -from dataclasses import dataclass +import numpy as np +import psutil +import torch +from bert_test_data import generate_test_data, get_bert_inputs @dataclass @@ -56,7 +56,7 @@ class ModelSetting: def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_optimization_level=None): import onnxruntime - if use_gpu and ('CUDAExecutionProvider' not in onnxruntime.get_available_providers()): + if use_gpu and ("CUDAExecutionProvider" not in onnxruntime.get_available_providers()): print( "Warning: Please install onnxruntime-gpu package instead of onnxruntime, and use a machine with GPU for testing gpu performance." ) @@ -65,20 +65,28 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op session = onnxruntime.InferenceSession(model_path) else: if use_gpu: - if provider == 'dml': - execution_providers = ['DmlExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'rocm': - execution_providers = ['ROCMExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'migraphx': - execution_providers = ['MIGraphXExecutionProvider', 'ROCMExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'cuda': - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] - elif provider == 'tensorrt': - execution_providers = ['TensorrtExecutionProvider', 'CUDAExecutionProvider', 'CPUExecutionProvider'] + if provider == "dml": + execution_providers = ["DmlExecutionProvider", "CPUExecutionProvider"] + elif provider == "rocm": + execution_providers = ["ROCMExecutionProvider", "CPUExecutionProvider"] + elif provider == "migraphx": + execution_providers = [ + "MIGraphXExecutionProvider", + "ROCMExecutionProvider", + "CPUExecutionProvider", + ] + elif provider == "cuda": + execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] + elif provider == "tensorrt": + execution_providers = [ + "TensorrtExecutionProvider", + "CUDAExecutionProvider", + "CPUExecutionProvider", + ] else: - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] + execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] else: - execution_providers = ['CPUExecutionProvider'] + execution_providers = ["CPUExecutionProvider"] sess_options = onnxruntime.SessionOptions() sess_options.execution_mode = onnxruntime.ExecutionMode.ORT_SEQUENTIAL @@ -102,55 +110,69 @@ def create_session(model_path, use_gpu, provider, intra_op_num_threads, graph_op session = onnxruntime.InferenceSession(model_path, sess_options, providers=execution_providers) if use_gpu: - if provider == 'dml': - assert 'DmlExecutionProvider' in session.get_providers() - elif provider == 'rocm': - assert 'ROCMExecutionProvider' in session.get_providers() - elif provider == 'migraphx': - assert 'MIGraphXExecutionProvider' in session.get_providers() - assert 'ROCMExecutionProvider' in session.get_providers() - elif provider == 'cuda': - assert 'CUDAExecutionProvider' in session.get_providers() - elif provider == 'tensorrt': - assert 'TensorrtExecutionProvider' in session.get_providers() - assert 'CUDAExecutionProvider' in session.get_providers() + if provider == "dml": + assert "DmlExecutionProvider" in session.get_providers() + elif provider == "rocm": + assert "ROCMExecutionProvider" in session.get_providers() + elif provider == "migraphx": + assert "MIGraphXExecutionProvider" in session.get_providers() + assert "ROCMExecutionProvider" in session.get_providers() + elif provider == "cuda": + assert "CUDAExecutionProvider" in session.get_providers() + elif provider == "tensorrt": + assert "TensorrtExecutionProvider" in session.get_providers() + assert "CUDAExecutionProvider" in session.get_providers() else: - assert 'CUDAExecutionProvider' in session.get_providers() + assert "CUDAExecutionProvider" in session.get_providers() else: - assert 'CPUExecutionProvider' in session.get_providers() + assert "CPUExecutionProvider" in session.get_providers() return session + def numpy_type(torch_type): - type_map = {torch.float32: np.float32, - torch.float16: np.float16, - torch.int32: np.int32, - torch.int64: np.longlong} + type_map = { + torch.float32: np.float32, + torch.float16: np.float16, + torch.int32: np.int32, + torch.int64: np.longlong, + } return type_map[torch_type] + def create_input_output_tensors(inputs, outputs, device): - input_tensors = {name: torch.from_numpy(array).to(device) - for name, array in inputs.items()} - output_tensors = {name: torch.from_numpy(array).to(device) - for name, array in outputs.items()} + input_tensors = {name: torch.from_numpy(array).to(device) for name, array in inputs.items()} + output_tensors = {name: torch.from_numpy(array).to(device) for name, array in outputs.items()} return input_tensors, output_tensors + def create_io_binding(sess, input_tensors, output_tensors): io_binding = sess.io_binding() for name, tensor in input_tensors.items(): - io_binding.bind_input(name, tensor.device.type, 0, - numpy_type(tensor.dtype), tensor.shape, - tensor.data_ptr()) + io_binding.bind_input( + name, + tensor.device.type, + 0, + numpy_type(tensor.dtype), + tensor.shape, + tensor.data_ptr(), + ) for name, tensor in output_tensors.items(): - io_binding.bind_output(name, tensor.device.type, 0, - numpy_type(tensor.dtype), tensor.shape, - tensor.data_ptr()) + io_binding.bind_output( + name, + tensor.device.type, + 0, + numpy_type(tensor.dtype), + tensor.shape, + tensor.data_ptr(), + ) return io_binding + def onnxruntime_inference_with_io_binding(session, all_inputs, output_names, test_setting): results = [] latency_list = [] - device = 'cuda' if test_setting.use_gpu else 'cpu' + device = "cuda" if test_setting.use_gpu else "cpu" for test_case_id, inputs in enumerate(all_inputs): result = session.run(output_names, inputs) results.append(result) @@ -171,6 +193,7 @@ def onnxruntime_inference_with_io_binding(session, all_inputs, output_names, tes return results, latency_list + def onnxruntime_inference(session, all_inputs, output_names): if len(all_inputs) > 0: # Use a random input as warm up. @@ -186,19 +209,25 @@ def onnxruntime_inference(session, all_inputs, output_names): latency_list.append(latency) return results, latency_list + def to_string(model_path, session, test_setting): sess_options = session.get_session_options() option = "model={},".format(os.path.basename(model_path)) - option += "graph_optimization_level={},intra_op_num_threads={},".format(sess_options.graph_optimization_level, - sess_options.intra_op_num_threads).replace( - 'GraphOptimizationLevel.ORT_', '') + option += "graph_optimization_level={},intra_op_num_threads={},".format( + sess_options.graph_optimization_level, sess_options.intra_op_num_threads + ).replace("GraphOptimizationLevel.ORT_", "") option += f"batch_size={test_setting.batch_size},sequence_length={test_setting.sequence_length},test_cases={test_setting.test_cases},test_times={test_setting.test_times},use_gpu={test_setting.use_gpu}" return option def run_one_test(model_setting, test_setting, perf_results, all_inputs, intra_op_num_threads): - session = create_session(model_setting.model_path, test_setting.use_gpu, test_setting.provider, intra_op_num_threads, - model_setting.opt_level) + session = create_session( + model_setting.model_path, + test_setting.use_gpu, + test_setting.provider, + intra_op_num_threads, + model_setting.opt_level, + ) output_names = [output.name for output in session.get_outputs()] key = to_string(model_setting.model_path, session, test_setting) @@ -211,7 +240,9 @@ def run_one_test(model_setting, test_setting, perf_results, all_inputs, intra_op all_latency_list = [] if test_setting.use_io_binding: for i in range(test_setting.test_times): - results, latency_list = onnxruntime_inference_with_io_binding(session, all_inputs, output_names, test_setting) + results, latency_list = onnxruntime_inference_with_io_binding( + session, all_inputs, output_names, test_setting + ) all_latency_list.extend(latency_list) else: for i in range(test_setting.test_times): @@ -229,23 +260,45 @@ def run_one_test(model_setting, test_setting, perf_results, all_inputs, intra_op latency_99 = np.percentile(latency_ms, 99) throughput = test_setting.batch_size * (1000.0 / average_latency) - perf_results[key] = (average_latency, latency_50, latency_75, latency_90, latency_95, latency_99, throughput) + perf_results[key] = ( + average_latency, + latency_50, + latency_75, + latency_90, + latency_95, + latency_99, + throughput, + ) - print("Average latency = {} ms, Throughput = {} QPS".format(format(average_latency, '.2f'), - format(throughput, '.2f'))) + print( + "Average latency = {} ms, Throughput = {} QPS".format(format(average_latency, ".2f"), format(throughput, ".2f")) + ) def launch_test(model_setting, test_setting, perf_results, all_inputs, intra_op_num_threads): - process = multiprocessing.Process(target=run_one_test, - args=(model_setting, test_setting, perf_results, all_inputs, - intra_op_num_threads)) + process = multiprocessing.Process( + target=run_one_test, + args=( + model_setting, + test_setting, + perf_results, + all_inputs, + intra_op_num_threads, + ), + ) process.start() process.join() def run_perf_tests(model_setting, test_setting, perf_results, all_inputs): - if (test_setting.intra_op_num_threads is not None): - launch_test(model_setting, test_setting, perf_results, all_inputs, test_setting.intra_op_num_threads) + if test_setting.intra_op_num_threads is not None: + launch_test( + model_setting, + test_setting, + perf_results, + all_inputs, + test_setting.intra_op_num_threads, + ) return cpu_count = psutil.cpu_count(logical=False) @@ -262,91 +315,139 @@ def run_perf_tests(model_setting, test_setting, perf_results, all_inputs): def run_performance(model_setting, test_setting, perf_results): - input_ids, segment_ids, input_mask = get_bert_inputs(model_setting.model_path, model_setting.input_ids_name, - model_setting.segment_ids_name, model_setting.input_mask_name) + input_ids, segment_ids, input_mask = get_bert_inputs( + model_setting.model_path, + model_setting.input_ids_name, + model_setting.segment_ids_name, + model_setting.input_mask_name, + ) # Do not generate random mask for performance test. print( f"Generating {test_setting.test_cases} samples for batch_size={test_setting.batch_size} sequence_length={test_setting.sequence_length}" ) - all_inputs = generate_test_data(test_setting.batch_size, - test_setting.sequence_length, - test_setting.test_cases, - test_setting.seed, - test_setting.verbose, - input_ids, - segment_ids, - input_mask, - random_mask_length=False) + all_inputs = generate_test_data( + test_setting.batch_size, + test_setting.sequence_length, + test_setting.test_cases, + test_setting.seed, + test_setting.verbose, + input_ids, + segment_ids, + input_mask, + random_mask_length=False, + ) run_perf_tests(model_setting, test_setting, perf_results, all_inputs) def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--model', required=True, type=str, help="bert onnx model path") + parser.add_argument("--model", required=True, type=str, help="bert onnx model path") - parser.add_argument('-b', - '--batch_size', - required=True, - type=int, - nargs="+", - help="batch size of input. Allow one or multiple values in the range of [1, 128].") + parser.add_argument( + "-b", + "--batch_size", + required=True, + type=int, + nargs="+", + help="batch size of input. Allow one or multiple values in the range of [1, 128].", + ) - parser.add_argument('-s', '--sequence_length', required=True, type=int, help="maximum sequence length of input") + parser.add_argument( + "-s", + "--sequence_length", + required=True, + type=int, + help="maximum sequence length of input", + ) - parser.add_argument('--samples', required=False, type=int, default=10, help="number of samples to be generated") + parser.add_argument( + "--samples", + required=False, + type=int, + default=10, + help="number of samples to be generated", + ) - parser.add_argument('-t', - '--test_times', - required=False, - type=int, - default=0, - help="number of times to run per sample. By default, the value is 1000 / samples") + parser.add_argument( + "-t", + "--test_times", + required=False, + type=int, + default=0, + help="number of times to run per sample. By default, the value is 1000 / samples", + ) parser.add_argument( - '--opt_level', + "--opt_level", required=False, type=int, choices=[0, 1, 2, 99], default=99, - help="onnxruntime optimization level: 0 - disable all, 1 - basic, 2 - extended, 99 - enable all.") + help="onnxruntime optimization level: 0 - disable all, 1 - basic, 2 - extended, 99 - enable all.", + ) - parser.add_argument('--seed', - required=False, - type=int, - default=3, - help="random seed. Use the same seed to make sure test data is same in multiple tests.") + parser.add_argument( + "--seed", + required=False, + type=int, + default=3, + help="random seed. Use the same seed to make sure test data is same in multiple tests.", + ) - parser.add_argument('--verbose', required=False, action='store_true', help="print verbose information") + parser.add_argument( + "--verbose", + required=False, + action="store_true", + help="print verbose information", + ) parser.set_defaults(verbose=False) - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU") + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU") parser.set_defaults(use_gpu=False) - parser.add_argument('--use_io_binding', required=False, action='store_true', help="use io_binding") + parser.add_argument("--use_io_binding", required=False, action="store_true", help="use io_binding") parser.set_defaults(use_io_binding=False) - parser.add_argument("--provider", - required=False, - type=str, - default=None, - help="Execution provider to use") - - parser.add_argument('-n', - '--intra_op_num_threads', - required=False, - type=int, - default=None, - help=">=0, set intra_op_num_threads") - - parser.add_argument('--input_ids_name', required=False, type=str, default=None, help="input name for input ids") - parser.add_argument('--segment_ids_name', required=False, type=str, default=None, help="input name for segment ids") - parser.add_argument('--input_mask_name', - required=False, - type=str, - default=None, - help="input name for attention mask") + parser.add_argument( + "--provider", + required=False, + type=str, + default=None, + help="Execution provider to use", + ) + + parser.add_argument( + "-n", + "--intra_op_num_threads", + required=False, + type=int, + default=None, + help=">=0, set intra_op_num_threads", + ) + + parser.add_argument( + "--input_ids_name", + required=False, + type=str, + default=None, + help="input name for input ids", + ) + parser.add_argument( + "--segment_ids_name", + required=False, + type=str, + default=None, + help="input name for segment ids", + ) + parser.add_argument( + "--input_mask_name", + required=False, + type=str, + default=None, + help="input name for attention mask", + ) args = parser.parse_args() return args @@ -365,12 +466,27 @@ def main(): if not min(batch_size_set) >= 1 and max(batch_size_set) <= 128: raise Exception("batch_size not in range [1, 128]") - model_setting = ModelSetting(args.model, args.input_ids_name, args.segment_ids_name, args.input_mask_name, - args.opt_level) + model_setting = ModelSetting( + args.model, + args.input_ids_name, + args.segment_ids_name, + args.input_mask_name, + args.opt_level, + ) for batch_size in batch_size_set: - test_setting = TestSetting(batch_size, args.sequence_length, args.samples, args.test_times, args.use_gpu, args.use_io_binding, - args.provider, args.intra_op_num_threads, args.seed, args.verbose) + test_setting = TestSetting( + batch_size, + args.sequence_length, + args.samples, + args.test_times, + args.use_gpu, + args.use_io_binding, + args.provider, + args.intra_op_num_threads, + args.seed, + args.verbose, + ) print("test setting", test_setting) run_performance(model_setting, test_setting, perf_results) @@ -380,25 +496,33 @@ def main(): summary_file = os.path.join( Path(args.model).parent, - "perf_results_{}_B{}_S{}_{}.txt".format('GPU' if args.use_gpu else 'CPU', - "-".join([str(x) for x in sorted(list(batch_size_set))]), - args.sequence_length, - datetime.now().strftime("%Y%m%d-%H%M%S"))) - with open(summary_file, 'w+', newline='') as tsv_file: - tsv_writer = csv.writer(tsv_file, delimiter='\t', lineterminator='\n') + "perf_results_{}_B{}_S{}_{}.txt".format( + "GPU" if args.use_gpu else "CPU", + "-".join([str(x) for x in sorted(list(batch_size_set))]), + args.sequence_length, + datetime.now().strftime("%Y%m%d-%H%M%S"), + ), + ) + with open(summary_file, "w+", newline="") as tsv_file: + tsv_writer = csv.writer(tsv_file, delimiter="\t", lineterminator="\n") headers = None for (key, perf_result) in sorted_results: - params = key.split(',') + params = key.split(",") if headers is None: headers = [ - "Latency(ms)", "Latency_P50", "Latency_P75", "Latency_P90", "Latency_P95", "Latency_P99", - "Throughput(QPS)" + "Latency(ms)", + "Latency_P50", + "Latency_P75", + "Latency_P90", + "Latency_P95", + "Latency_P99", + "Throughput(QPS)", ] - headers.extend([x.split('=')[0] for x in params]) + headers.extend([x.split("=")[0] for x in params]) tsv_writer.writerow(headers) - values = [format(x, '.2f') for x in perf_result] - values.extend([x.split('=')[1] for x in params]) + values = [format(x, ".2f") for x in perf_result] + values.extend([x.split("=")[1] for x in params]) tsv_writer.writerow(values) print("Test summary is saved to", summary_file) diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index 940cb3306872c..dc818ae87c40c 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -1,24 +1,26 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # It is a tool to generate test data for a bert model. # The test data can be used by onnxruntime_perf_test tool to evaluate the inference latency. -import sys import argparse -import numpy as np import os import random +import sys from pathlib import Path -from typing import List, Dict, Tuple, Union +from typing import Dict, List, Tuple, Union + +import numpy as np from onnx import ModelProto, TensorProto, numpy_helper from onnx_model import OnnxModel -def fake_input_ids_data(input_ids: TensorProto, batch_size: int, sequence_length: int, - dictionary_size: int) -> np.ndarray: +def fake_input_ids_data( + input_ids: TensorProto, batch_size: int, sequence_length: int, dictionary_size: int +) -> np.ndarray: """Create input tensor based on the graph input of input_ids Args: @@ -30,7 +32,11 @@ def fake_input_ids_data(input_ids: TensorProto, batch_size: int, sequence_length Returns: np.ndarray: the input tensor created """ - assert input_ids.type.tensor_type.elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64] + assert input_ids.type.tensor_type.elem_type in [ + TensorProto.FLOAT, + TensorProto.INT32, + TensorProto.INT64, + ] data = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=np.int32) @@ -43,7 +49,7 @@ def fake_input_ids_data(input_ids: TensorProto, batch_size: int, sequence_length def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_length: int) -> np.ndarray: - """Create input tensor based on the graph input of segment_ids + """Create input tensor based on the graph input of segment_ids Args: segment_ids (TensorProto): graph input of the token_type_ids input tensor @@ -53,7 +59,11 @@ def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_le Returns: np.ndarray: the input tensor created """ - assert segment_ids.type.tensor_type.elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64] + assert segment_ids.type.tensor_type.elem_type in [ + TensorProto.FLOAT, + TensorProto.INT32, + TensorProto.INT64, + ] data = np.zeros((batch_size, sequence_length), dtype=np.int32) @@ -65,8 +75,12 @@ def fake_segment_ids_data(segment_ids: TensorProto, batch_size: int, sequence_le return data -def fake_input_mask_data(input_mask: TensorProto, batch_size: int, sequence_length: int, - random_mask_length: bool) -> np.ndarray: +def fake_input_mask_data( + input_mask: TensorProto, + batch_size: int, + sequence_length: int, + random_mask_length: bool, +) -> np.ndarray: """Create input tensor based on the graph input of segment_ids. Args: @@ -79,13 +93,17 @@ def fake_input_mask_data(input_mask: TensorProto, batch_size: int, sequence_leng np.ndarray: the input tensor created """ - assert input_mask.type.tensor_type.elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64] + assert input_mask.type.tensor_type.elem_type in [ + TensorProto.FLOAT, + TensorProto.INT32, + TensorProto.INT64, + ] if random_mask_length: actual_seq_len = random.randint(int(sequence_length * 2 / 3), sequence_length) data = np.zeros((batch_size, sequence_length), dtype=np.int32) temp = np.ones((batch_size, actual_seq_len), dtype=np.int32) - data[:temp.shape[0], :temp.shape[1]] = temp + data[: temp.shape[0], : temp.shape[1]] = temp else: data = np.ones((batch_size, sequence_length), dtype=np.int32) @@ -117,14 +135,23 @@ def output_test_data(dir: str, inputs: np.ndarray): index = 0 for name, data in inputs.items(): tensor = numpy_helper.from_array(data, name) - with open(os.path.join(dir, 'input_{}.pb'.format(index)), 'wb') as f: + with open(os.path.join(dir, "input_{}.pb".format(index)), "wb") as f: f.write(tensor.SerializeToString()) index += 1 -def fake_test_data(batch_size: int, sequence_length: int, test_cases: int, dictionary_size: int, verbose: bool, - random_seed: int, input_ids: TensorProto, segment_ids: TensorProto, input_mask: TensorProto, - random_mask_length: bool): +def fake_test_data( + batch_size: int, + sequence_length: int, + test_cases: int, + dictionary_size: int, + verbose: bool, + random_seed: int, + input_ids: TensorProto, + segment_ids: TensorProto, + input_mask: TensorProto, + random_mask_length: bool, +): """Create given number of input data for testing Args: @@ -164,9 +191,17 @@ def fake_test_data(batch_size: int, sequence_length: int, test_cases: int, dicti return all_inputs -def generate_test_data(batch_size: int, sequence_length: int, test_cases: int, seed: int, verbose: bool, - input_ids: TensorProto, segment_ids: TensorProto, input_mask: TensorProto, - random_mask_length: bool): +def generate_test_data( + batch_size: int, + sequence_length: int, + test_cases: int, + seed: int, + verbose: bool, + input_ids: TensorProto, + segment_ids: TensorProto, + input_mask: TensorProto, + random_mask_length: bool, +): """Create given number of minput data for testing Args: @@ -184,8 +219,18 @@ def generate_test_data(batch_size: int, sequence_length: int, test_cases: int, s List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value """ dictionary_size = 10000 - all_inputs = fake_test_data(batch_size, sequence_length, test_cases, dictionary_size, verbose, seed, input_ids, - segment_ids, input_mask, random_mask_length) + all_inputs = fake_test_data( + batch_size, + sequence_length, + test_cases, + dictionary_size, + verbose, + seed, + input_ids, + segment_ids, + input_mask, + random_mask_length, + ) if len(all_inputs) != test_cases: print("Failed to create test data for test.") return all_inputs @@ -199,16 +244,17 @@ def get_graph_input_from_embed_node(onnx_model, embed_node, input_index): graph_input = onnx_model.find_graph_input(input) if graph_input is None: parent_node = onnx_model.get_parent(embed_node, input_index) - if parent_node is not None and parent_node.op_type == 'Cast': + if parent_node is not None and parent_node.op_type == "Cast": graph_input = onnx_model.find_graph_input(parent_node.input[0]) return graph_input -def find_bert_inputs(onnx_model: OnnxModel, - input_ids_name: str = None, - segment_ids_name: str = None, - input_mask_name: str = None - ) -> Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: +def find_bert_inputs( + onnx_model: OnnxModel, + input_ids_name: str = None, + segment_ids_name: str = None, + input_mask_name: str = None, +) -> Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: """Find graph inputs for BERT model. First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming. @@ -254,7 +300,7 @@ def find_bert_inputs(onnx_model: OnnxModel, if len(graph_inputs) != 3: raise ValueError("Expect the graph to have 3 inputs. Got {}".format(len(graph_inputs))) - embed_nodes = onnx_model.get_nodes_by_op_type('EmbedLayerNormalization') + embed_nodes = onnx_model.get_nodes_by_op_type("EmbedLayerNormalization") if len(embed_nodes) == 1: embed_node = embed_nodes[0] input_ids = get_graph_input_from_embed_node(onnx_model, embed_node, 0) @@ -279,7 +325,9 @@ def find_bert_inputs(onnx_model: OnnxModel, input_name_lower = input.name.lower() if "mask" in input_name_lower: # matches input with name like "attention_mask" or "input_mask" input_mask = input - elif "token" in input_name_lower or "segment" in input_name_lower: # matches input with name like "segment_ids" or "token_type_ids" + elif ( + "token" in input_name_lower or "segment" in input_name_lower + ): # matches input with name like "segment_ids" or "token_type_ids" segment_ids = input else: input_ids = input @@ -290,10 +338,12 @@ def find_bert_inputs(onnx_model: OnnxModel, raise ValueError("Fail to assign 3 inputs. You might try rename the graph inputs.") -def get_bert_inputs(onnx_file: str, - input_ids_name: str = None, - segment_ids_name: str = None, - input_mask_name: str = None): +def get_bert_inputs( + onnx_file: str, + input_ids_name: str = None, + segment_ids_name: str = None, + input_mask_name: str = None, +): """Find graph inputs for BERT model. First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming. @@ -317,54 +367,95 @@ def get_bert_inputs(onnx_file: str, def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--model', required=True, type=str, help="bert onnx model path.") - - parser.add_argument('--output_dir', - required=False, - type=str, - default=None, - help="output test data path. Default is current directory.") - - parser.add_argument('--batch_size', required=False, type=int, default=1, help="batch size of input") - - parser.add_argument('--sequence_length', - required=False, - type=int, - default=128, - help="maximum sequence length of input") - - parser.add_argument('--input_ids_name', required=False, type=str, default=None, help="input name for input ids") - parser.add_argument('--segment_ids_name', required=False, type=str, default=None, help="input name for segment ids") - parser.add_argument('--input_mask_name', - required=False, - type=str, - default=None, - help="input name for attention mask") - - parser.add_argument('--samples', required=False, type=int, default=1, help="number of test cases to be generated") - - parser.add_argument('--seed', required=False, type=int, default=3, help="random seed") - - parser.add_argument('--verbose', required=False, action='store_true', help="print verbose information") + parser.add_argument("--model", required=True, type=str, help="bert onnx model path.") + + parser.add_argument( + "--output_dir", + required=False, + type=str, + default=None, + help="output test data path. Default is current directory.", + ) + + parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input") + + parser.add_argument( + "--sequence_length", + required=False, + type=int, + default=128, + help="maximum sequence length of input", + ) + + parser.add_argument( + "--input_ids_name", + required=False, + type=str, + default=None, + help="input name for input ids", + ) + parser.add_argument( + "--segment_ids_name", + required=False, + type=str, + default=None, + help="input name for segment ids", + ) + parser.add_argument( + "--input_mask_name", + required=False, + type=str, + default=None, + help="input name for attention mask", + ) + + parser.add_argument( + "--samples", + required=False, + type=int, + default=1, + help="number of test cases to be generated", + ) + + parser.add_argument("--seed", required=False, type=int, default=3, help="random seed") + + parser.add_argument( + "--verbose", + required=False, + action="store_true", + help="print verbose information", + ) parser.set_defaults(verbose=False) - parser.add_argument('--only_input_tensors', - required=False, - action='store_true', - help="only save input tensors and no output tensors") + parser.add_argument( + "--only_input_tensors", + required=False, + action="store_true", + help="only save input tensors and no output tensors", + ) parser.set_defaults(only_input_tensors=False) args = parser.parse_args() return args -def create_and_save_test_data(model: str, output_dir: str, batch_size: int, sequence_length: int, test_cases: int, - seed: int, verbose: bool, input_ids_name: str, segment_ids_name: str, - input_mask_name: str, only_input_tensors: bool): +def create_and_save_test_data( + model: str, + output_dir: str, + batch_size: int, + sequence_length: int, + test_cases: int, + seed: int, + verbose: bool, + input_ids_name: str, + segment_ids_name: str, + input_mask_name: str, + only_input_tensors: bool, +): """Create test data for a model, and save test data to a directory. Args: - model (str): path of ONNX bert model + model (str): path of ONNX bert model output_dir (str): output directory batch_size (int): batch size sequence_length (int): sequence length @@ -378,33 +469,36 @@ def create_and_save_test_data(model: str, output_dir: str, batch_size: int, sequ """ input_ids, segment_ids, input_mask = get_bert_inputs(model, input_ids_name, segment_ids_name, input_mask_name) - all_inputs = generate_test_data(batch_size, - sequence_length, - test_cases, - seed, - verbose, - input_ids, - segment_ids, - input_mask, - random_mask_length=False) + all_inputs = generate_test_data( + batch_size, + sequence_length, + test_cases, + seed, + verbose, + input_ids, + segment_ids, + input_mask, + random_mask_length=False, + ) for i, inputs in enumerate(all_inputs): - dir = os.path.join(output_dir, 'test_data_set_' + str(i)) + dir = os.path.join(output_dir, "test_data_set_" + str(i)) output_test_data(dir, inputs) if only_input_tensors: return import onnxruntime + sess = onnxruntime.InferenceSession(model) output_names = [output.name for output in sess.get_outputs()] for i, inputs in enumerate(all_inputs): - dir = os.path.join(output_dir, 'test_data_set_' + str(i)) + dir = os.path.join(output_dir, "test_data_set_" + str(i)) result = sess.run(output_names, inputs) for i, output_name in enumerate(output_names): tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_names[i]) - with open(os.path.join(dir, 'output_{}.pb'.format(i)), 'wb') as f: + with open(os.path.join(dir, "output_{}.pb".format(i)), "wb") as f: f.write(tensor_result.SerializeToString()) @@ -424,9 +518,19 @@ def main(): else: print("Directory existed. test data files will be overwritten.") - create_and_save_test_data(args.model, output_dir, args.batch_size, args.sequence_length, args.samples, args.seed, - args.verbose, args.input_ids_name, args.segment_ids_name, args.input_mask_name, - args.only_input_tensors) + create_and_save_test_data( + args.model, + output_dir, + args.batch_size, + args.sequence_length, + args.samples, + args.seed, + args.verbose, + args.input_ids_name, + args.segment_ids_name, + args.input_mask_name, + args.only_input_tensors, + ) print("Test data is saved to directory:", output_dir) diff --git a/onnxruntime/python/tools/transformers/compare_bert_results.py b/onnxruntime/python/tools/transformers/compare_bert_results.py index 5837581893b00..337b96b89d510 100644 --- a/onnxruntime/python/tools/transformers/compare_bert_results.py +++ b/onnxruntime/python/tools/transformers/compare_bert_results.py @@ -1,27 +1,28 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # It is a tool to compare the inference results of the original model and optimized model. -import sys import argparse -import numpy as np +import csv import os import random -from pathlib import Path import statistics +import sys +import timeit +from datetime import datetime +from pathlib import Path + +import numpy as np import onnx import onnx.utils import psutil -import csv -import timeit -from datetime import datetime +from bert_perf_test import create_session, onnxruntime_inference +from bert_test_data import generate_test_data, get_bert_inputs, output_test_data from onnx import ModelProto, TensorProto, numpy_helper from onnx_model import OnnxModel -from bert_test_data import get_bert_inputs, generate_test_data, output_test_data -from bert_perf_test import create_session, onnxruntime_inference def run_model(model_path, all_inputs, use_gpu, disable_optimization): @@ -64,51 +65,75 @@ def compare(baseline_results, treatment_results, verbose, rtol=1e-3, atol=1e-4): print("rel_diff={} abs_diff={}".format(rel_diff, abs_diff)) if diff_count == 0: - print("100% passed for {} random inputs given thresholds (rtol={}, atol={}).".format( - len(baseline_results), rtol, atol)) + print( + "100% passed for {} random inputs given thresholds (rtol={}, atol={}).".format( + len(baseline_results), rtol, atol + ) + ) else: - print("WARNING: {} out of {} results NOT passed for thresholds (rtol={}, atol={}).".format( - diff_count, len(baseline_results), rtol, atol)) + print( + "WARNING: {} out of {} results NOT passed for thresholds (rtol={}, atol={}).".format( + diff_count, len(baseline_results), rtol, atol + ) + ) print("maximum absolute difference={}".format(max_abs_diff)) print("maximum relative difference={}".format(max_rel_diff)) -def run_test(baseline_model, optimized_model, output_dir, batch_size, sequence_length, use_gpu, test_cases, seed, - verbose, rtol, atol, input_ids_name, segment_ids_name, input_mask_name): +def run_test( + baseline_model, + optimized_model, + output_dir, + batch_size, + sequence_length, + use_gpu, + test_cases, + seed, + verbose, + rtol, + atol, + input_ids_name, + segment_ids_name, + input_mask_name, +): # Try deduce input names from optimized model. - input_ids, segment_ids, input_mask = get_bert_inputs(optimized_model, input_ids_name, segment_ids_name, - input_mask_name) + input_ids, segment_ids, input_mask = get_bert_inputs( + optimized_model, input_ids_name, segment_ids_name, input_mask_name + ) # Use random mask length for accuracy test. It might introduce slight inflation in latency reported in this script. - all_inputs = generate_test_data(batch_size, - sequence_length, - test_cases, - seed, - verbose, - input_ids, - segment_ids, - input_mask, - random_mask_length=True) - - baseline_results, baseline_latency, output_names = run_model(baseline_model, - all_inputs, - use_gpu, - disable_optimization=True) + all_inputs = generate_test_data( + batch_size, + sequence_length, + test_cases, + seed, + verbose, + input_ids, + segment_ids, + input_mask, + random_mask_length=True, + ) + + baseline_results, baseline_latency, output_names = run_model( + baseline_model, all_inputs, use_gpu, disable_optimization=True + ) if verbose: - print("baseline average latency (all optimizations disabled): {} ms".format( - statistics.mean(baseline_latency) * 1000)) + print( + "baseline average latency (all optimizations disabled): {} ms".format( + statistics.mean(baseline_latency) * 1000 + ) + ) if output_dir is not None: for i, inputs in enumerate(all_inputs): output_test_data(output_dir, i, inputs) - treatment_results, treatment_latency, treatment_output_names = run_model(optimized_model, - all_inputs, - use_gpu, - disable_optimization=False) + treatment_results, treatment_latency, treatment_output_names = run_model( + optimized_model, all_inputs, use_gpu, disable_optimization=False + ) if verbose: print("treatment average latency: {} ms".format(statistics.mean(treatment_latency) * 1000)) @@ -118,41 +143,79 @@ def run_test(baseline_model, optimized_model, output_dir, batch_size, sequence_l def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--baseline_model', required=True, type=str, help="baseline onnx model path.") - - parser.add_argument('--optimized_model', - required=True, - type=str, - default=None, - help="path of the optimized model. It shall have same inputs as the baseline model.") - - parser.add_argument('--output_dir', - required=False, - type=str, - default=None, - help="output test data path. If not specified, test data will not be saved.") - - parser.add_argument('--batch_size', required=True, type=int, help="batch size of input") - - parser.add_argument('--sequence_length', required=True, type=int, help="maximum sequence length of input") - - parser.add_argument('--rtol', required=False, type=float, default=1e-3, help="relative tolerance") - - parser.add_argument('--atol', required=False, type=float, default=1e-4, help="absolute tolerance") - - parser.add_argument('--samples', required=False, type=int, default=100, help="number of test cases to be generated") - - parser.add_argument('--seed', required=False, type=int, default=3, help="random seed") - - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU") + parser.add_argument("--baseline_model", required=True, type=str, help="baseline onnx model path.") + + parser.add_argument( + "--optimized_model", + required=True, + type=str, + default=None, + help="path of the optimized model. It shall have same inputs as the baseline model.", + ) + + parser.add_argument( + "--output_dir", + required=False, + type=str, + default=None, + help="output test data path. If not specified, test data will not be saved.", + ) + + parser.add_argument("--batch_size", required=True, type=int, help="batch size of input") + + parser.add_argument( + "--sequence_length", + required=True, + type=int, + help="maximum sequence length of input", + ) + + parser.add_argument("--rtol", required=False, type=float, default=1e-3, help="relative tolerance") + + parser.add_argument("--atol", required=False, type=float, default=1e-4, help="absolute tolerance") + + parser.add_argument( + "--samples", + required=False, + type=int, + default=100, + help="number of test cases to be generated", + ) + + parser.add_argument("--seed", required=False, type=int, default=3, help="random seed") + + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU") parser.set_defaults(use_gpu=False) - parser.add_argument('--verbose', required=False, action='store_true', help="print verbose information") + parser.add_argument( + "--verbose", + required=False, + action="store_true", + help="print verbose information", + ) parser.set_defaults(verbose=False) - parser.add_argument('--input_ids', required=False, type=str, default=None, help="input name for input ids") - parser.add_argument('--segment_ids', required=False, type=str, default=None, help="input name for segment ids") - parser.add_argument('--input_mask', required=False, type=str, default=None, help="input name for attention mask") + parser.add_argument( + "--input_ids", + required=False, + type=str, + default=None, + help="input name for input ids", + ) + parser.add_argument( + "--segment_ids", + required=False, + type=str, + default=None, + help="input name for segment ids", + ) + parser.add_argument( + "--input_mask", + required=False, + type=str, + default=None, + help="input name for attention mask", + ) args = parser.parse_args() return args @@ -166,9 +229,22 @@ def main(): path = Path(args.output_dir) path.mkdir(parents=True, exist_ok=True) - run_test(args.baseline_model, args.optimized_model, args.output_dir, args.batch_size, args.sequence_length, - args.use_gpu, args.samples, args.seed, args.verbose, args.rtol, args.atol, args.input_ids, - args.segment_ids, args.input_mask) + run_test( + args.baseline_model, + args.optimized_model, + args.output_dir, + args.batch_size, + args.sequence_length, + args.use_gpu, + args.samples, + args.seed, + args.verbose, + args.rtol, + args.atol, + args.input_ids, + args.segment_ids, + args.input_mask, + ) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index 2e6f5746de214..ddfc43e3c8ebd 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- """ This converts GPT2 or T5 model to onnx with beam search operator. @@ -13,161 +13,203 @@ python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx """ +import argparse +import logging import os +import sys import time -import onnx -import logging -import argparse from pathlib import Path -from onnx import helper -import numpy as np from typing import List, Union + +import numpy as np +import onnx import torch -from packaging import version -from transformers import GPT2Config, T5Config from benchmark_helper import Precision +from onnx import helper from onnx import onnx_pb as onnx_proto +from packaging import version +from transformers import GPT2Config, T5Config -import sys -import os - -sys.path.append(os.path.join(os.path.dirname(__file__), 'models', 'gpt2')) -from gpt2_helper import PRETRAINED_GPT2_MODELS +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) from convert_to_onnx import main as convert_gpt2_to_onnx +from gpt2_helper import PRETRAINED_GPT2_MODELS config: Union[GPT2Config, T5Config] = None -logger = logging.getLogger('') +logger = logging.getLogger("") def parse_arguments(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('-m', - '--model_name_or_path', - required=True, - type=str, - help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS)) - - parser.add_argument('--model_type', - required=False, - type=str, - default="gpt2", - choices=["gpt2", "t5"], - help='Model type in the list: ' + ', '.join(["gpt2", "t5"])) - - parser.add_argument('--cache_dir', - required=False, - type=str, - default=os.path.join('.', 'cache_models'), - help='Directory to cache pre-trained models') - - parser.add_argument('--decoder_onnx', - required=True, - type=str, - help='Output directory for decoder onnx model, or model path ends with .onnx') - - parser.add_argument('--encoder_decoder_init_onnx', - required=False, - type=str, - default="", - help='path of ONNX model for encoder and decoder initialization. Required for t5 model type.') - - parser.add_argument('--output', - required=False, - type=str, - help='Output directory for beam search model, or model path ends with .onnx') - - parser.add_argument("-p", - "--precision", - required=False, - type=Precision, - default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16], - help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision") - - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference") + parser.add_argument( + "-m", + "--model_name_or_path", + required=True, + type=str, + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS), + ) + + parser.add_argument( + "--model_type", + required=False, + type=str, + default="gpt2", + choices=["gpt2", "t5"], + help="Model type in the list: " + ", ".join(["gpt2", "t5"]), + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=os.path.join(".", "cache_models"), + help="Directory to cache pre-trained models", + ) + + parser.add_argument( + "--decoder_onnx", + required=True, + type=str, + help="Output directory for decoder onnx model, or model path ends with .onnx", + ) + + parser.add_argument( + "--encoder_decoder_init_onnx", + required=False, + type=str, + default="", + help="path of ONNX model for encoder and decoder initialization. Required for t5 model type.", + ) + + parser.add_argument( + "--output", + required=False, + type=str, + help="Output directory for beam search model, or model path ends with .onnx", + ) + + parser.add_argument( + "-p", + "--precision", + required=False, + type=Precision, + default=Precision.FLOAT32, + choices=[Precision.FLOAT32, Precision.FLOAT16], + help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision", + ) + + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") parser.set_defaults(use_gpu=False) - parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') + parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") parser.set_defaults(use_external_data_format=False) - parser.add_argument('--disable_parity', required=False, action='store_true', help="do not run parity test") + parser.add_argument( + "--disable_parity", + required=False, + action="store_true", + help="do not run parity test", + ) parser.set_defaults(disable_parity=False) - parser.add_argument('--torch_performance', required=False, action='store_true', help="test PyTorch performance") + parser.add_argument( + "--torch_performance", + required=False, + action="store_true", + help="test PyTorch performance", + ) parser.set_defaults(torch_performance=False) - parser.add_argument('--total_runs', - required=False, - type=int, - default=1, - help='Number of times of inference for latency measurement') + parser.add_argument( + "--total_runs", + required=False, + type=int, + default=1, + help="Number of times of inference for latency measurement", + ) beam_search_group = parser.add_argument_group("beam search options") - beam_search_group.add_argument('--output_sequences_scores', - required=False, - action='store_true', - help="output sequences scores") + beam_search_group.add_argument( + "--output_sequences_scores", + required=False, + action="store_true", + help="output sequences scores", + ) beam_search_group.set_defaults(output_sequences_scores=False) - beam_search_group.add_argument('--output_token_scores', - required=False, - action='store_true', - help="output token scores") + beam_search_group.add_argument( + "--output_token_scores", + required=False, + action="store_true", + help="output token scores", + ) beam_search_group.set_defaults(output_token_scores=False) - beam_search_group.add_argument('--early_stopping', required=False, action='store_true') + beam_search_group.add_argument("--early_stopping", required=False, action="store_true") beam_search_group.set_defaults(early_stopping=False) - beam_search_group.add_argument('--min_length', type=int, required=False, default=1, help='Min sequence length') + beam_search_group.add_argument("--min_length", type=int, required=False, default=1, help="Min sequence length") - beam_search_group.add_argument('--max_length', type=int, required=False, default=50, help='Max sequence length') + beam_search_group.add_argument("--max_length", type=int, required=False, default=50, help="Max sequence length") - beam_search_group.add_argument('--no_repeat_ngram_size', - type=int, - required=False, - default=0, - help='No repeat ngram size') + beam_search_group.add_argument( + "--no_repeat_ngram_size", + type=int, + required=False, + default=0, + help="No repeat ngram size", + ) - beam_search_group.add_argument('--num_beams', type=int, required=False, default=4, help='Beam size') + beam_search_group.add_argument("--num_beams", type=int, required=False, default=4, help="Beam size") - beam_search_group.add_argument('--num_return_sequences', - type=int, - required=False, - default=1, - help='Number of return sequence <= num_beams') + beam_search_group.add_argument( + "--num_return_sequences", + type=int, + required=False, + default=1, + help="Number of return sequence <= num_beams", + ) - beam_search_group.add_argument('--temperature', - type=float, - required=False, - default=1, - help='Softmax temperature for output logits.') + beam_search_group.add_argument( + "--temperature", + type=float, + required=False, + default=1, + help="Softmax temperature for output logits.", + ) - beam_search_group.add_argument('--length_penalty', - type=float, - required=False, - default=1, - help='Positive. >1 to penalize and <1 to encorage short sentence.') + beam_search_group.add_argument( + "--length_penalty", + type=float, + required=False, + default=1, + help="Positive. >1 to penalize and <1 to encorage short sentence.", + ) - beam_search_group.add_argument('--repetition_penalty', - type=float, - required=False, - default=1, - help='Positive. >1 to penalize and <1 to encorage.') + beam_search_group.add_argument( + "--repetition_penalty", + type=float, + required=False, + default=1, + help="Positive. >1 to penalize and <1 to encorage.", + ) - beam_search_group.add_argument('--vocab_size', - type=int, - required=False, - default=-1, - help="Vocab_size of the underlying model") + beam_search_group.add_argument( + "--vocab_size", + type=int, + required=False, + default=-1, + help="Vocab_size of the underlying model", + ) beam_search_group.add_argument( - '--prefix_vocab_mask', + "--prefix_vocab_mask", required=False, - action='store_true', - help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete") + action="store_true", + help="This vocab mask applies only to first iteration, enable if last word in query might need auto complete", + ) beam_search_group.set_defaults(prefix_vocab_mask=False) args = parser.parse_args(argv) @@ -180,39 +222,40 @@ def gpt2_to_onnx(args): print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...") arguments = [ - '--model_name_or_path', + "--model_name_or_path", model_name, - '--output', + "--output", args.decoder_onnx, - '--optimize_onnx', - '--precision', - 'fp32' if args.precision == Precision.FLOAT32 else 'fp16', - '--test_runs', - '1', - '--test_cases', - '10', - '--use_int32_inputs' # BeamSearch requires to use int32 for input_ids, postion_ids and attention_mask + "--optimize_onnx", + "--precision", + "fp32" if args.precision == Precision.FLOAT32 else "fp16", + "--test_runs", + "1", + "--test_cases", + "10", + "--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, postion_ids and attention_mask ] if args.use_gpu: - arguments.append('--use_gpu') + arguments.append("--use_gpu") if args.use_external_data_format: - arguments.append('--use_external_data_format') + arguments.append("--use_external_data_format") if args.precision == Precision.FLOAT16: assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu" # TODO: Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision') # Need change cuda kernel to support a combination of fp32 logits and fp16 past state. # Currently logits and past state shall be same data type. - arguments.extend(['--op_block_list', 'Add', 'LayerNormalization', 'FastGelu']) + arguments.extend(["--op_block_list", "Add", "LayerNormalization", "FastGelu"]) convert_gpt2_to_onnx(arguments) def shape_inference(decoder_onnx_path): - if version.parse(onnx.__version__) >= version.parse('1.11.0'): + if version.parse(onnx.__version__) >= version.parse("1.11.0"): logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.") # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False) if out: # TODO: Use external format if input has extra data. @@ -222,12 +265,15 @@ def shape_inference(decoder_onnx_path): def create_ort_session(model_path, use_gpu): - from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel, get_available_providers + from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions + from onnxruntime import __version__ as ort_version + from onnxruntime import get_available_providers + sess_options = SessionOptions() sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL - execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider'] + execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"] if use_gpu: - if 'CUDAExecutionProvider' not in get_available_providers(): + if "CUDAExecutionProvider" not in get_available_providers(): raise RuntimeError("CUDAExecutionProvider is not avaiable for --use_gpu!") else: print("use CUDAExecutionProvider") @@ -237,12 +283,12 @@ def create_ort_session(model_path, use_gpu): def verify_gpt2_subgraph(graph, precision): - is_float16 = (Precision.FLOAT16 == precision) + is_float16 = Precision.FLOAT16 == precision input_count = len(graph.input) layer_count = input_count - 3 - expected_inputs = ['input_ids', 'position_ids', 'attention_mask'] + [f"past_{i}" for i in range(layer_count)] + expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)] if len(graph.input) != len(expected_inputs): raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") @@ -260,7 +306,7 @@ def verify_gpt2_subgraph(graph, precision): ) print("Verifying GPT-2 graph inputs: name and data type are good.") - expected_outputs = ['logits'] + [f"present_{i}" for i in range(layer_count)] + expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)] if len(graph.output) != len(expected_outputs): raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") @@ -327,8 +373,15 @@ def convert_model(args): verify_t5_decoder_subgraph(model.graph, args.precision) inputs = [ - "input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty", - "repetition_penalty", "vocab_mask" + "input_ids", + "max_length", + "min_length", + "num_beams", + "num_return_sequences", + "temperature", + "length_penalty", + "repetition_penalty", + "vocab_mask", ] if args.prefix_vocab_mask: inputs.append("prefix_vocab_mask") @@ -341,16 +394,23 @@ def convert_model(args): assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") - node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name=f'BeamSearch_{args.model_type}') + node = helper.make_node( + "BeamSearch", + inputs=inputs, + outputs=outputs, + name=f"BeamSearch_{args.model_type}", + ) node.domain = "com.microsoft" - node.attribute.extend([ - helper.make_attribute("eos_token_id", eos_token_id), - helper.make_attribute("pad_token_id", pad_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), - helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), - helper.make_attribute("decoder", model.graph), - ]) + node.attribute.extend( + [ + helper.make_attribute("eos_token_id", eos_token_id), + helper.make_attribute("pad_token_id", pad_token_id), + helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), + helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + helper.make_attribute("decoder", model.graph), + ] + ) if args.model_type == "t5": if enable_shape_inference: @@ -359,42 +419,59 @@ def convert_model(args): init_model = onnx.load(args.encoder_decoder_init_onnx) init_model.graph.name = f"{args.model_type} encoder decoder init subgraph" verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision) - node.attribute.extend([ - helper.make_attribute("encoder_decoder_init", init_model.graph), - ]) + node.attribute.extend( + [ + helper.make_attribute("encoder_decoder_init", init_model.graph), + ] + ) from onnx import TensorProto # graph inputs - input_ids = helper.make_tensor_value_info('input_ids', TensorProto.INT32, ['batch_size', 'sequence_length']) - max_length = helper.make_tensor_value_info('max_length', TensorProto.INT32, [1]) - min_length = helper.make_tensor_value_info('min_length', TensorProto.INT32, [1]) - num_beams = helper.make_tensor_value_info('num_beams', TensorProto.INT32, [1]) - num_return_sequences = helper.make_tensor_value_info('num_return_sequences', TensorProto.INT32, [1]) - temperature = helper.make_tensor_value_info('temperature', TensorProto.FLOAT, [1]) - length_penalty = helper.make_tensor_value_info('length_penalty', TensorProto.FLOAT, [1]) - repetition_penalty = helper.make_tensor_value_info('repetition_penalty', TensorProto.FLOAT, [1]) - vocab_mask = helper.make_tensor_value_info('vocab_mask', TensorProto.INT32, [vocab_size]) + input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"]) + max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1]) + min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1]) + num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1]) + num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) + temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) + length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) + repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) graph_inputs = [ - input_ids, max_length, min_length, num_beams, num_return_sequences, temperature, length_penalty, - repetition_penalty, vocab_mask + input_ids, + max_length, + min_length, + num_beams, + num_return_sequences, + temperature, + length_penalty, + repetition_penalty, + vocab_mask, ] if args.prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info('prefix_vocab_mask', TensorProto.INT32, - ['batch_size', vocab_size]) + prefix_vocab_mask = helper.make_tensor_value_info( + "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size] + ) graph_inputs.append(prefix_vocab_mask) # graph outputs - sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32, - ['batch_size', 'num_return_sequences', 'max_length']) - - sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT, - ['batch_size', 'num_return_sequences']) - - scores = helper.make_tensor_value_info('scores', TensorProto.FLOAT, - ['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size]) + sequences = helper.make_tensor_value_info( + "sequences", + TensorProto.INT32, + ["batch_size", "num_return_sequences", "max_length"], + ) + + sequences_scores = helper.make_tensor_value_info( + "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + ) + + scores = helper.make_tensor_value_info( + "scores", + TensorProto.FLOAT, + ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], + ) initializers = [] @@ -406,10 +483,20 @@ def convert_model(args): if args.output_token_scores: graph_outputs.append(scores) - new_graph = helper.make_graph([node], f'{args.model_type}-beam-search', graph_inputs, graph_outputs, initializers) + new_graph = helper.make_graph( + [node], + f"{args.model_type}-beam-search", + graph_inputs, + graph_outputs, + initializers, + ) # Create the model - new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import) + new_model = helper.make_model( + new_graph, + producer_name="onnxruntime.transformers", + opset_imports=model.opset_import, + ) onnx.save(new_model, args.output) @@ -431,25 +518,28 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, torch_latency = [] for _ in range(args.total_runs): start = time.time() - _ = model.generate(input_ids=input_ids, - attention_mask=attention_mask, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - temperature=args.temperature, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty, - bad_words_ids=bad_words_ids, - return_dict_in_generate=True, - output_scores=args.output_sequences_scores or args.output_token_scores) + _ = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + temperature=args.temperature, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids, + return_dict_in_generate=True, + output_scores=args.output_sequences_scores or args.output_token_scores, + ) torch_latency.append(time.time() - start) batch_size = input_ids.shape[0] from benchmark_helper import get_latency_result + return get_latency_result(torch_latency, batch_size) @@ -469,21 +559,27 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") return True - from transformers import GPT2Tokenizer, GPT2LMHeadModel + from transformers import GPT2LMHeadModel, GPT2Tokenizer tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) tokenizer.padding_side = "left" tokenizer.pad_token = tokenizer.eos_token - model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path, - cache_dir=args.cache_dir, - pad_token_id=tokenizer.eos_token_id) + model = GPT2LMHeadModel.from_pretrained( + args.model_name_or_path, + cache_dir=args.cache_dir, + pad_token_id=tokenizer.eos_token_id, + ) # Use different length sentences to test batching if sentences is None: - sentences = ["The product is released", "I enjoy walking in the park", "Test best way to invest"] + sentences = [ + "The product is released", + "I enjoy walking in the park", + "Test best way to invest", + ] - inputs = tokenizer(sentences, return_tensors='pt', padding=True) + inputs = tokenizer(sentences, return_tensors="pt", padding=True) input_ids = inputs["input_ids"] attention_mask = inputs["attention_mask"] @@ -503,24 +599,26 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): torch_decoded_sequences = [] if not args.disable_parity: - print('-' * 50) + print("-" * 50) print("Test PyTorch model and beam search with huggingface transformers...") - beam_outputs = model.generate(input_ids=input_ids, - attention_mask=attention_mask, - max_length=args.max_length, - min_length=args.min_length, - num_beams=args.num_beams, - early_stopping=args.early_stopping, - no_repeat_ngram_size=args.no_repeat_ngram_size, - eos_token_id=eos_token_id, - pad_token_id=pad_token_id, - num_return_sequences=args.num_return_sequences, - temperature=args.temperature, - length_penalty=args.length_penalty, - repetition_penalty=args.repetition_penalty, - bad_words_ids=bad_words_ids, - return_dict_in_generate=True, - output_scores=args.output_sequences_scores or args.output_token_scores) + beam_outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + temperature=args.temperature, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids, + return_dict_in_generate=True, + output_scores=args.output_sequences_scores or args.output_token_scores, + ) print("input_ids", input_ids) print("huggingface transformers outputs:") print("sequences", beam_outputs.sequences) @@ -533,7 +631,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): torch_decoded_sequences.append(decoded_sequence) print("{}: {}".format(i, decoded_sequence)) - print('-' * 50) + print("-" * 50) print("Test ONNX model and bream search with onnxruntime...") ort_session = create_ort_session(args.output, args.use_gpu) @@ -552,15 +650,16 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): "temperature": np.array([args.temperature], dtype=np.float32), "length_penalty": np.array([args.length_penalty], dtype=np.float32), "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), - "vocab_mask": vocab_mask + "vocab_mask": vocab_mask, } test_data_dir = Path(args.output).parent.as_posix() print("test_data_dir", test_data_dir) from bert_test_data import output_test_data + all_inputs = [inputs] for i, inputs in enumerate(all_inputs): - dir = os.path.join(test_data_dir, 'test_data_set_' + str(i)) + dir = os.path.join(test_data_dir, "test_data_set_" + str(i)) output_test_data(dir, inputs) print("inputs", inputs) @@ -573,6 +672,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): latency.append(time.time() - start) batch_size = input_ids.shape[0] from benchmark_helper import get_latency_result + output = get_latency_result(latency, batch_size) print("ORT outputs:") @@ -604,13 +704,20 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): print(ort_decoded_sequences) print("-" * 50) # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. - is_same = (torch_decoded_sequences == ort_decoded_sequences) + is_same = torch_decoded_sequences == ort_decoded_sequences print("Torch and ORT result is ", "same" if is_same else "different") output["parity"] = is_same if args.torch_performance: - torch_latency_output = test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, - pad_token_id, bad_words_ids) + torch_latency_output = test_torch_performance( + args, + model, + input_ids, + attention_mask, + eos_token_id, + pad_token_id, + bad_words_ids, + ) print("Torch Latency", torch_latency_output) print("ORT", output) @@ -630,5 +737,5 @@ def main(argv=None, sentences=None): return test_model(args, use_vocab_mask=True, sentences=sentences) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/transformers/convert_tf_models_to_pytorch.py b/onnxruntime/python/tools/transformers/convert_tf_models_to_pytorch.py index f4ccede13f0c9..a035790b50954 100644 --- a/onnxruntime/python/tools/transformers/convert_tf_models_to_pytorch.py +++ b/onnxruntime/python/tools/transformers/convert_tf_models_to_pytorch.py @@ -1,24 +1,56 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import glob import os + import requests TFMODELS = { - "bert-base-uncased": - ("bert", "BertConfig", "", "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip"), - "bert-base-cased": - ("bert", "BertConfig", "", "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip"), - "bert-large-uncased": - ("bert", "BertConfig", "", "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip"), - "albert-base": ("albert", "AlbertConfig", "", "https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz"), - "albert-large": - ("albert", "AlbertConfig", "", "https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz"), - "gpt-2-117M": ("gpt2", "GPT2Config", "GPT2Model", "https://storage.googleapis.com/gpt-2/models/117M"), - "gpt-2-124M": ("gpt2", "GPT2Config", "GPT2Model", "https://storage.googleapis.com/gpt-2/models/124M") + "bert-base-uncased": ( + "bert", + "BertConfig", + "", + "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip", + ), + "bert-base-cased": ( + "bert", + "BertConfig", + "", + "https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip", + ), + "bert-large-uncased": ( + "bert", + "BertConfig", + "", + "https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-24_H-1024_A-16.zip", + ), + "albert-base": ( + "albert", + "AlbertConfig", + "", + "https://storage.googleapis.com/albert_models/albert_base_v1.tar.gz", + ), + "albert-large": ( + "albert", + "AlbertConfig", + "", + "https://storage.googleapis.com/albert_models/albert_large_v1.tar.gz", + ), + "gpt-2-117M": ( + "gpt2", + "GPT2Config", + "GPT2Model", + "https://storage.googleapis.com/gpt-2/models/117M", + ), + "gpt-2-124M": ( + "gpt2", + "GPT2Config", + "GPT2Model", + "https://storage.googleapis.com/gpt-2/models/124M", + ), } @@ -26,7 +58,7 @@ def download_compressed_file(tf_ckpt_url, ckpt_dir): r = requests.get(tf_ckpt_url) compressed_file_name = tf_ckpt_url.split("/")[-1] compressed_file_dir = os.path.join(ckpt_dir, compressed_file_name) - with open(compressed_file_dir, 'wb') as f: + with open(compressed_file_dir, "wb") as f: f.write(r.content) return compressed_file_dir @@ -40,13 +72,14 @@ def get_ckpt_prefix_path(ckpt_dir): if os.path.isfile(sub_folder_dir): sub_folder_dir = ckpt_dir unique_file_name = str(glob.glob(sub_folder_dir + "/*data-00000-of-00001")) - prefix = (unique_file_name.rpartition('.')[0]).split("/")[-1] + prefix = (unique_file_name.rpartition(".")[0]).split("/")[-1] return os.path.join(sub_folder_dir, prefix) def download_tf_checkpoint(model_name, tf_models_dir="tf_models"): import pathlib + base_dir = os.path.join(pathlib.Path(__file__).parent.absolute(), tf_models_dir) ckpt_dir = os.path.join(base_dir, model_name) @@ -56,32 +89,40 @@ def download_tf_checkpoint(model_name, tf_models_dir="tf_models"): tf_ckpt_url = TFMODELS[model_name][3] import re - if (re.search('.zip$', tf_ckpt_url) != None): + + if re.search(".zip$", tf_ckpt_url) != None: zip_dir = download_compressed_file(tf_ckpt_url, ckpt_dir) # unzip file import zipfile - with zipfile.ZipFile(zip_dir, 'r') as zip_ref: + + with zipfile.ZipFile(zip_dir, "r") as zip_ref: zip_ref.extractall(ckpt_dir) os.remove(zip_dir) return get_ckpt_prefix_path(ckpt_dir) - elif (re.search('.tar.gz$', tf_ckpt_url) != None): + elif re.search(".tar.gz$", tf_ckpt_url) != None: tar_dir = download_compressed_file(tf_ckpt_url, ckpt_dir) # untar file import tarfile - with tarfile.open(tar_dir, 'r') as tar_ref: + + with tarfile.open(tar_dir, "r") as tar_ref: tar_ref.extractall(ckpt_dir) os.remove(tar_dir) return get_ckpt_prefix_path(ckpt_dir) else: - for filename in ['checkpoint', 'model.ckpt.data-00000-of-00001', 'model.ckpt.index', 'model.ckpt.meta']: + for filename in [ + "checkpoint", + "model.ckpt.data-00000-of-00001", + "model.ckpt.index", + "model.ckpt.meta", + ]: r = requests.get(tf_ckpt_url + "/" + filename) - with open(os.path.join(ckpt_dir, filename), 'wb') as f: + with open(os.path.join(ckpt_dir, filename), "wb") as f: f.write(r.content) return get_ckpt_prefix_path(ckpt_dir) @@ -92,12 +133,13 @@ def init_pytorch_model(model_name, tf_checkpoint_path): config_module = __import__("transformers", fromlist=[config_name]) model_config = getattr(config_module, config_name) - parent_path = tf_checkpoint_path.rpartition('/')[0] + parent_path = tf_checkpoint_path.rpartition("/")[0] config_path = glob.glob(parent_path + "/*config.json") config = model_config() if len(config_path) == 0 else model_config.from_json_file(str(config_path[0])) if TFMODELS[model_name][2] == "": from transformers import AutoModelForPreTraining + init_model = AutoModelForPreTraining.from_config(config) else: model_categroy_name = TFMODELS[model_name][2] @@ -118,11 +160,15 @@ def convert_tf_checkpoint_to_pytorch(model_name, config, init_model, tf_checkpoi if TFMODELS[model_name][0] != "bert": raise NotImplementedError("Only support tf2 ckeckpoint for Bert model") from transformers import convert_bert_original_tf2_checkpoint_to_pytorch + load_tf_weight_func = convert_bert_original_tf2_checkpoint_to_pytorch.load_tf2_weights_in_bert # Expect transformers team will unify the order of signature in the future - model = load_tf_weight_func(init_model, config, tf_checkpoint_path) if is_tf2 is False else load_tf_weight_func( - init_model, tf_checkpoint_path, config) + model = ( + load_tf_weight_func(init_model, config, tf_checkpoint_path) + if is_tf2 is False + else load_tf_weight_func(init_model, tf_checkpoint_path, config) + ) model.eval() return model @@ -140,11 +186,13 @@ def tf2pt_pipeline(model_name, is_tf2=False): def tf2pt_pipeline_test(): # For test on linux only import logging + import torch - logger = logging.getLogger('') + + logger = logging.getLogger("") for model_name in TFMODELS.keys(): config, model = tf2pt_pipeline(model_name) - assert (config.model_type is TFMODELS[model_name][0]) + assert config.model_type is TFMODELS[model_name][0] input = torch.randint(low=0, high=config.vocab_size - 1, size=(4, 128), dtype=torch.long) try: @@ -153,5 +201,5 @@ def tf2pt_pipeline_test(): logger.exception(e) -if __name__ == '__main__': +if __name__ == "__main__": tf2pt_pipeline_test() diff --git a/onnxruntime/python/tools/transformers/float16.py b/onnxruntime/python/tools/transformers/float16.py index b5b26c0b046a7..634548de0d0b1 100644 --- a/onnxruntime/python/tools/transformers/float16.py +++ b/onnxruntime/python/tools/transformers/float16.py @@ -1,48 +1,48 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # This file is modified from https://github.com/microsoft/onnxconverter-common/blob/master/onnxconverter_common/float16.py # Modifications: keep_io_types can be list of names; convert initializers if needed to preserve precision; add force_fp16_initializers option. import itertools +import logging +from typing import Dict, List + import numpy as np import onnx from onnx import helper, numpy_helper from onnx import onnx_pb as onnx_proto -from typing import List, Dict - -import logging logger = logging.getLogger(__name__) def _npfloat16_to_int(np_list): - ''' + """ Convert numpy float16 to python int. :param np_list: numpy float16 list :return int_list: python int list - ''' - return [int(bin(_.view('H'))[2:].zfill(16), 2) for _ in np_list] + """ + return [int(bin(_.view("H"))[2:].zfill(16), 2) for _ in np_list] def convert_np_to_float16(np_array, min_positive_val=5.96e-08, max_finite_val=65504.0): - ''' + """ Convert float32 numpy array to float16 without changing sign or finiteness. Positive values less than min_positive_val are mapped to min_positive_val. Positive finite values greater than max_finite_val are mapped to max_finite_val. Similar for negative values. NaN, 0, inf, and -inf are unchanged. - ''' + """ def between(a, b, c): return np.logical_and(a < b, b < c) np_array = np.where(between(0, np_array, min_positive_val), min_positive_val, np_array) np_array = np.where(between(-min_positive_val, np_array, 0), -min_positive_val, np_array) - np_array = np.where(between(max_finite_val, np_array, float('inf')), max_finite_val, np_array) - np_array = np.where(between(float('-inf'), np_array, -max_finite_val), -max_finite_val, np_array) + np_array = np.where(between(max_finite_val, np_array, float("inf")), max_finite_val, np_array) + np_array = np.where(between(float("-inf"), np_array, -max_finite_val), -max_finite_val, np_array) return np.float16(np_array) @@ -62,7 +62,7 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit """ if not isinstance(tensor, onnx_proto.TensorProto): - raise ValueError('Expected input type is an ONNX TensorProto but got %s' % type(tensor)) + raise ValueError("Expected input type is an ONNX TensorProto but got %s" % type(tensor)) if tensor.data_type == onnx_proto.TensorProto.FLOAT: tensor.data_type = onnx_proto.TensorProto.FLOAT16 @@ -75,7 +75,7 @@ def convert_tensor_float_to_float16(tensor, min_positive_val=5.96e-08, max_finit # convert raw_data (bytes type) if tensor.raw_data: # convert n.raw_data to float - float32_list = np.fromstring(tensor.raw_data, dtype='float32') + float32_list = np.fromstring(tensor.raw_data, dtype="float32") # convert float to float16 float16_list = convert_np_to_float16(float32_list, min_positive_val, max_finite_val) # convert float16 to bytes and write back to raw_data @@ -89,10 +89,33 @@ def make_value_info_from_tensor(tensor): DEFAULT_OP_BLOCK_LIST = [ - 'ArrayFeatureExtractor', 'Binarizer', 'CastMap', 'CategoryMapper', 'DictVectorizer', 'FeatureVectorizer', 'Imputer', - 'LabelEncoder', 'LinearClassifier', 'LinearRegressor', 'Normalizer', 'OneHotEncoder', 'SVMClassifier', - 'SVMRegressor', 'Scaler', 'TreeEnsembleClassifier', 'TreeEnsembleRegressor', 'ZipMap', 'NonMaxSuppression', 'TopK', - 'RoiAlign', 'Resize', 'Range', 'CumSum', 'Min', 'Max', 'Upsample' + "ArrayFeatureExtractor", + "Binarizer", + "CastMap", + "CategoryMapper", + "DictVectorizer", + "FeatureVectorizer", + "Imputer", + "LabelEncoder", + "LinearClassifier", + "LinearRegressor", + "Normalizer", + "OneHotEncoder", + "SVMClassifier", + "SVMRegressor", + "Scaler", + "TreeEnsembleClassifier", + "TreeEnsembleRegressor", + "ZipMap", + "NonMaxSuppression", + "TopK", + "RoiAlign", + "Resize", + "Range", + "CumSum", + "Min", + "Max", + "Upsample", ] @@ -111,14 +134,16 @@ def add_node(self, node: onnx_proto.NodeProto, is_node_blocked): self.fp16_nodes.append(node) -def convert_float_to_float16(model, - min_positive_val=5.96e-08, - max_finite_val=65504.0, - keep_io_types=False, - disable_shape_infer=False, - op_block_list=None, - node_block_list=None, - force_fp16_initializers=False): +def convert_float_to_float16( + model, + min_positive_val=5.96e-08, + max_finite_val=65504.0, + keep_io_types=False, + disable_shape_infer=False, + op_block_list=None, + node_block_list=None, + force_fp16_initializers=False, +): """Convert model tensor float type in the ONNX ModelProto input to tensor float16. Args: @@ -139,19 +164,22 @@ def convert_float_to_float16(model, Returns: ModelProto: converted model. """ - assert min_positive_val >= 5.96e-08, "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05" + assert ( + min_positive_val >= 5.96e-08 + ), "invalid min_positive_val. smallest positive float16 value: subnormal 5.96e-08, and normalized 6.104e-05" assert max_finite_val <= float(np.finfo(np.float16).max), "invalid max_finite_val. largest float16 value: 65504" func_infer_shape = None - if not disable_shape_infer and onnx.__version__ >= '1.2': + if not disable_shape_infer and onnx.__version__ >= "1.2": try: from onnx.shape_inference import infer_shapes + func_infer_shape = infer_shapes finally: pass if not isinstance(model, onnx_proto.ModelProto): - raise ValueError('Expected model type is an ONNX ModelProto but got %s' % type(model)) + raise ValueError("Expected model type is an ONNX ModelProto but got %s" % type(model)) # create blocklists if op_block_list is None: @@ -188,34 +216,34 @@ def convert_float_to_float16(model, for i, n in enumerate(model.graph.input): if n.name in fp32_inputs: - output_name = 'graph_input_cast_' + str(i) + output_name = "graph_input_cast_" + str(i) name_mapping[n.name] = output_name graph_io_to_skip.add(n.name) - node_name = 'graph_input_cast' + str(i) + node_name = "graph_input_cast" + str(i) new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(n) new_value_info.name = output_name new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 # add Cast node (from tensor(float) to tensor(float16) after graph input - new_node = [helper.make_node('Cast', [n.name], [output_name], to=10, name=node_name)] + new_node = [helper.make_node("Cast", [n.name], [output_name], to=10, name=node_name)] model.graph.node.extend(new_node) value_info_list.append(new_value_info) io_casts.add(node_name) for i, n in enumerate(model.graph.output): if n.name in fp32_outputs: - input_name = 'graph_output_cast_' + str(i) + input_name = "graph_output_cast_" + str(i) name_mapping[n.name] = input_name graph_io_to_skip.add(n.name) - node_name = 'graph_output_cast' + str(i) + node_name = "graph_output_cast" + str(i) # add Cast node (from tensor(float16) to tensor(float) before graph output new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(n) new_value_info.name = input_name new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 - new_node = [helper.make_node('Cast', [input_name], [n.name], to=1, name=node_name)] + new_node = [helper.make_node("Cast", [input_name], [n.name], to=1, name=node_name)] model.graph.node.extend(new_node) value_info_list.append(new_value_info) io_casts.add(node_name) @@ -254,9 +282,9 @@ def convert_float_to_float16(model, if is_node_blocked: node_list.append(n) else: - if n.op_type == 'Cast': + if n.op_type == "Cast": for attr in n.attribute: - if attr.name == 'to' and attr.i == 1: + if attr.name == "to" and attr.i == 1: attr.i = 10 break for attr in n.attribute: @@ -280,12 +308,12 @@ def convert_float_to_float16(model, if n.name not in graph_io_to_skip: n.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 value_info_list.append(n) - if n.type.HasField('sequence_type'): + if n.type.HasField("sequence_type"): if n.type.sequence_type.elem_type.tensor_type.elem_type == onnx_proto.TensorProto.FLOAT: if n.name not in graph_io_to_skip: n.type.sequence_type.elem_type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT16 value_info_list.append(n) - + queue = next_level for key, value in fp32_initializers.items(): @@ -296,7 +324,9 @@ def convert_float_to_float16(model, if value.fp32_nodes and not force_fp16_initializers: logger.info( "initializer is used by both fp32 and fp16 nodes. Consider add these nodes to block list:{}".format( - value.fp16_nodes)) + value.fp16_nodes + ) + ) # process the nodes in block list that doesn't support tensor(float16) for node in node_list: @@ -310,12 +340,12 @@ def convert_float_to_float16(model, # create new value_info for current node's new input name new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(value_info) - output_name = node.name + '_input_cast_' + str(i) + output_name = node.name + "_input_cast_" + str(i) new_value_info.name = output_name new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT # add Cast node (from tensor(float16) to tensor(float) before current node - node_name = node.name + '_input_cast' + str(i) - new_node = [helper.make_node('Cast', [input], [output_name], to=1, name=node_name)] + node_name = node.name + "_input_cast" + str(i) + new_node = [helper.make_node("Cast", [input], [output_name], to=1, name=node_name)] model.graph.node.extend(new_node) # change current node's input name node.input[i] = output_name @@ -329,12 +359,12 @@ def convert_float_to_float16(model, # create new value_info for current node's new output new_value_info = model.graph.value_info.add() new_value_info.CopyFrom(value_info) - input_name = node.name + '_output_cast_' + str(i) + input_name = node.name + "_output_cast_" + str(i) new_value_info.name = input_name new_value_info.type.tensor_type.elem_type = onnx_proto.TensorProto.FLOAT # add Cast node (from tensor(float) to tensor(float16) after current node - node_name = node.name + '_output_cast' + str(i) - new_node = [helper.make_node('Cast', [input_name], [output], to=10, name=node_name)] + node_name = node.name + "_output_cast" + str(i) + new_node = [helper.make_node("Cast", [input_name], [output], to=10, name=node_name)] model.graph.node.extend(new_node) # change current node's input name node.output[i] = input_name @@ -345,15 +375,15 @@ def convert_float_to_float16(model, def float_to_float16_max_diff(tensor, min_positive_val=5.96e-08, max_finite_val=65504.0): """Measure the maximum absolute difference after converting a float tensor to float16.""" if not isinstance(tensor, onnx_proto.TensorProto): - raise ValueError('Expected input type is an ONNX TensorProto but got %s' % type(tensor)) + raise ValueError("Expected input type is an ONNX TensorProto but got %s" % type(tensor)) if tensor.data_type != onnx_proto.TensorProto.FLOAT: - raise ValueError('Expected tensor data type is float.') + raise ValueError("Expected tensor data type is float.") if tensor.float_data: float32_data = np.array(tensor.float_data) if tensor.raw_data: - float32_data = np.fromstring(tensor.raw_data, dtype='float32') + float32_data = np.fromstring(tensor.raw_data, dtype="float32") float16_data = convert_np_to_float16(float32_data, min_positive_val, max_finite_val) return np.amax(np.abs(float32_data - np.float32(float16_data))) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index ecb32b7feebcc..936e83d8bb2f7 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -1,27 +1,29 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- +from enum import Enum +from logging import getLogger from os import name from sys import path -import numpy as np -from logging import getLogger -from enum import Enum from typing import Tuple, Union -from onnx import helper, numpy_helper, TensorProto, NodeProto -from onnx_model import OnnxModel + +import numpy as np from fusion_base import Fusion -from fusion_utils import FusionUtils, NumpyHelper from fusion_options import AttentionMaskFormat +from fusion_utils import FusionUtils, NumpyHelper +from onnx import NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel from shape_infer_helper import SymbolicShapeInferenceHelper, get_shape_from_type_proto logger = getLogger(__name__) -class AttentionMask(): +class AttentionMask: """ Fuse Attention subgraph into one Attention node. """ + def __init__(self, model: OnnxModel): self.model = model # A lookup table with mask input as key, and mask index output as value @@ -66,11 +68,13 @@ def process_mask(self, input: str) -> str: return input_name # Add a mask processing node to convert attention mask to mask index (1D) - output_name = self.model.create_node_name('mask_index') - mask_index_node = helper.make_node('ReduceSum', - inputs=[input_name], - outputs=[output_name], - name=self.model.create_node_name('ReduceSum', 'MaskReduceSum')) + output_name = self.model.create_node_name("mask_index") + mask_index_node = helper.make_node( + "ReduceSum", + inputs=[input_name], + outputs=[output_name], + name=self.model.create_node_name("ReduceSum", "MaskReduceSum"), + ) mask_index_node.attribute.extend([helper.make_attribute("axes", [1]), helper.make_attribute("keepdims", 0)]) self.model.add_node(mask_index_node) @@ -82,7 +86,14 @@ class FusionAttention(Fusion): """ Fuse Attention subgraph into one Attention node. """ - def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask): + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): super().__init__(model, "Attention", ["SkipLayerNormalization", "LayerNormalization"]) self.hidden_size = hidden_size self.num_heads = num_heads @@ -93,7 +104,7 @@ def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention self.hidden_size_warning = True def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]: - """ Detect num_heads and hidden_size from a reshape node. + """Detect num_heads and hidden_size from a reshape node. Args: reshape_q (NodeProto): reshape node for Q @@ -125,7 +136,8 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] if self.hidden_size > 0 and hidden_size != self.hidden_size: if self.hidden_size_warning: logger.warning( - f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value.") + f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value." + ) self.hidden_size_warning = False # Do not show the warning more than once return num_heads, hidden_size @@ -148,10 +160,22 @@ def get_add_qk_str(self, add_qk: NodeProto): return add_qk.input[1] - def create_attention_node(self, mask_index: str, q_matmul: NodeProto, k_matmul: NodeProto, v_matmul: NodeProto, - q_add: NodeProto, k_add: NodeProto, v_add: NodeProto, num_heads: int, hidden_size: int, - input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]: - """ Create an Attention node. + def create_attention_node( + self, + mask_index: str, + q_matmul: NodeProto, + k_matmul: NodeProto, + v_matmul: NodeProto, + q_add: NodeProto, + k_add: NodeProto, + v_add: NodeProto, + num_heads: int, + hidden_size: int, + input: str, + output: str, + add_qk_str: str, + ) -> Union[NodeProto, None]: + """Create an Attention node. Args: mask_index (str): mask input @@ -244,27 +268,35 @@ def create_attention_node(self, mask_index: str, q_matmul: NodeProto, k_matmul: qkv_bias = np.stack((qb, kb, vb), axis=0) qkv_bias_dim = 3 * q_bias_shape - attention_node_name = self.model.create_node_name('Attention') + attention_node_name = self.model.create_node_name("Attention") - weight = helper.make_tensor(name=attention_node_name + '_qkv_weight', - data_type=TensorProto.FLOAT, - dims=[qw_in_size, qkv_weight_dim], - vals=qkv_weight.flatten().tolist()) + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[qw_in_size, qkv_weight_dim], + vals=qkv_weight.flatten().tolist(), + ) # Sometimes weights and bias are stored in fp16 if q_weight.data_type == 10: weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) self.model.add_initializer(weight, self.this_graph_name) - bias = helper.make_tensor(name=attention_node_name + '_qkv_bias', - data_type=TensorProto.FLOAT, - dims=[qkv_bias_dim], - vals=qkv_bias.flatten().tolist()) + bias = helper.make_tensor( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[qkv_bias_dim], + vals=qkv_bias.flatten().tolist(), + ) if q_bias.data_type == 10: bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) self.model.add_initializer(bias, self.this_graph_name) - attention_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias'] + attention_inputs = [ + input, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] if mask_index is not None: attention_inputs.append(mask_index) else: @@ -274,16 +306,19 @@ def create_attention_node(self, mask_index: str, q_matmul: NodeProto, k_matmul: attention_inputs.append("") attention_inputs.append(add_qk_str) - attention_node = helper.make_node('Attention', - inputs=attention_inputs, - outputs=[output], - name=attention_node_name) + attention_node = helper.make_node( + "Attention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) if is_qkv_diff_dims: attention_node.attribute.extend( - [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]) + [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])] + ) return attention_node @@ -291,23 +326,27 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern start_node = normalize_node - if normalize_node.op_type == 'LayerNormalization': - add_before_layernorm = self.model.match_parent(normalize_node, 'Add', 0) + if normalize_node.op_type == "LayerNormalization": + add_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) if add_before_layernorm is not None: start_node = add_before_layernorm else: return # SkipLayerNormalization has two inputs, and one of them is the root input for attention. - qkv_nodes = self.model.match_parent_path(start_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], - [None, None, 0, 0, 0]) + qkv_nodes = self.model.match_parent_path( + start_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [None, None, 0, 0, 0], + ) einsum_node = None if qkv_nodes is not None: (_, matmul_qkv, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes else: # Match Albert - qkv_nodes = self.model.match_parent_path(start_node, ['Add', 'Einsum', 'Transpose', 'MatMul'], - [1, None, 0, 0]) + qkv_nodes = self.model.match_parent_path( + start_node, ["Add", "Einsum", "Transpose", "MatMul"], [1, None, 0, 0] + ) if qkv_nodes is not None: (_, einsum_node, transpose_qkv, matmul_qkv) = qkv_nodes else: @@ -333,12 +372,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): | | +--------------------------------------------------------- """ - mul_before_layernorm = self.model.match_parent(start_node, 'Mul', 0) + mul_before_layernorm = self.model.match_parent(start_node, "Mul", 0) if mul_before_layernorm is not None: mul_children = input_name_to_nodes[mul_before_layernorm.output[0]] if mul_children is not None and len(mul_children) == 2: layernorm_node = mul_children[1] - if layernorm_node.op_type == 'LayerNormalization': + if layernorm_node.op_type == "LayerNormalization": root_input = layernorm_node.output[0] else: return @@ -346,7 +385,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input = mul_before_layernorm.output[0] else: return - elif normalize_node.op_type == 'LayerNormalization': + elif normalize_node.op_type == "LayerNormalization": children = input_name_to_nodes[root_input] for child in children: if child.op_type == "LayerNormalization": @@ -354,10 +393,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): children = input_name_to_nodes[root_input] children_types = [child.op_type for child in children] - if children_types.count('MatMul') != 3: + if children_types.count("MatMul") != 3: return - v_nodes = self.model.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, None]) + v_nodes = self.model.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return @@ -366,10 +405,10 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): is_distill = False is_distill_add = False qk_paths = { - "path1": (['Softmax', 'Add', 'Div', 'MatMul'], [0, 0, None, 0]), - "path2": (['Softmax', 'Add', 'Mul', 'MatMul'], [0, 0, None, 0]), - "path3": (['Softmax', 'Where', 'MatMul', 'Div'], [0, 0, 2, 0]), - "path4": (['Softmax', 'Add', 'Where', 'MatMul'], [0, 0, 0, 2]) + "path1": (["Softmax", "Add", "Div", "MatMul"], [0, 0, None, 0]), + "path2": (["Softmax", "Add", "Mul", "MatMul"], [0, 0, None, 0]), + "path3": (["Softmax", "Where", "MatMul", "Div"], [0, 0, 2, 0]), + "path4": (["Softmax", "Add", "Where", "MatMul"], [0, 0, 0, 2]), } qk_nodes = None @@ -397,10 +436,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): else: (_, add_qk, _, matmul_qk) = qk_nodes - q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, None]) + q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, None]) if q_nodes is None: - q_nodes = self.model.match_parent_path(matmul_qk, ['Div', 'Transpose', 'Reshape', 'Add', 'MatMul'], - [0, 0, 0, 0, None]) + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Div", "Transpose", "Reshape", "Add", "MatMul"], + [0, 0, 0, 0, None], + ) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return @@ -408,10 +450,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): add_q = q_nodes[-2] matmul_q = q_nodes[-1] - k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, None]) + k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, None]) if k_nodes is None: - k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Transpose', 'Reshape', 'Add', 'MatMul'], - [1, 0, 0, 0, None]) + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, None], + ) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return @@ -422,15 +467,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): mask_nodes = None add_qk_str = None if is_distill: - _, mask_nodes, _ = self.model.match_parent_paths(where_qk, - [(['Expand', 'Reshape', 'Equal'], [0, 0, 0]), - (['Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0]), - (['Cast', 'Expand', 'Reshape', 'Equal'], [0, 0, 0, 0])], - output_name_to_node) + _, mask_nodes, _ = self.model.match_parent_paths( + where_qk, + [ + (["Expand", "Reshape", "Equal"], [0, 0, 0]), + (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]), + (["Cast", "Expand", "Reshape", "Equal"], [0, 0, 0, 0]), + ], + output_name_to_node, + ) elif is_distill_add: _, mask_nodes, _ = self.model.match_parent_paths( - where_qk, [(['Cast', 'Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0, 0]), - (['Equal', 'Unsqueeze', 'Unsqueeze'], [0, 0, 0])], output_name_to_node) + where_qk, + [ + (["Cast", "Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0, 0]), + (["Equal", "Unsqueeze", "Unsqueeze"], [0, 0, 0]), + ], + output_name_to_node, + ) if add_qk is not None: add_qk_str = self.get_add_qk_str(add_qk) if add_qk_str is None: @@ -438,8 +492,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return else: _, mask_nodes, _ = self.model.match_parent_paths( - add_qk, [(['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0, 0]), - (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0])], output_name_to_node) + add_qk, + [ + ( + ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [None, 0, 1, 0, 0], + ), + (["Mul", "Sub", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0]), + ], + output_name_to_node, + ) if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return @@ -452,9 +514,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q) # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately - new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, - q_num_heads, q_hidden_size, root_input, attention_last_node.output[0], - add_qk_str) + new_node = self.create_attention_node( + mask_index, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + q_num_heads, + q_hidden_size, + root_input, + attention_last_node.output[0], + add_qk_str, + ) if new_node is None: return @@ -464,16 +537,23 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): if einsum_node is not None: unique_index = einsum_node.input[0] new_edge = "edge_modified_" + unique_index - shape_tensor = helper.make_tensor(name="shape_modified_tensor" + unique_index, - data_type=TensorProto.INT64, - dims=[4], - vals=np.int64([0, 0, q_num_heads, - int(q_hidden_size / q_num_heads)]).tobytes(), - raw=True) + shape_tensor = helper.make_tensor( + name="shape_modified_tensor" + unique_index, + data_type=TensorProto.INT64, + dims=[4], + vals=np.int64([0, 0, q_num_heads, int(q_hidden_size / q_num_heads)]).tobytes(), + raw=True, + ) self.model.add_initializer(shape_tensor, self.this_graph_name) self.model.add_node( - helper.make_node("Reshape", [attention_last_node.output[0], shape_tensor.name], [new_edge], - "reshape_modified_" + unique_index), self.this_graph_name) + helper.make_node( + "Reshape", + [attention_last_node.output[0], shape_tensor.name], + [new_edge], + "reshape_modified_" + unique_index, + ), + self.this_graph_name, + ) einsum_node.input[0] = new_edge self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv]) @@ -483,5 +563,5 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend(v_nodes) # Use prune graph to remove mask nodes since they are shared by all attention nodes. - #self.nodes_to_remove.extend(mask_nodes) + # self.nodes_to_remove.extend(mask_nodes) self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index 0469e90f48a59..31b3399e27768 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -1,21 +1,24 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from logging import getLogger -from onnx_model import OnnxModel -from typing import Union, List +from typing import List, Union + from onnx import GraphProto +from onnx_model import OnnxModel logger = getLogger(__name__) class Fusion: - def __init__(self, - model: OnnxModel, - fused_op_type: str, - search_op_types: Union[str, List[str]], - description: str = None): + def __init__( + self, + model: OnnxModel, + fused_op_type: str, + search_op_types: Union[str, List[str]], + description: str = None, + ): self.search_op_types: List[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types self.fused_op_type: str = fused_op_type self.description: str = f"{fused_op_type}({description})" if description else fused_op_type diff --git a/onnxruntime/python/tools/transformers/fusion_biasgelu.py b/onnxruntime/python/tools/transformers/fusion_biasgelu.py index cfdf00b22fd0d..7fdb1b7d86c52 100644 --- a/onnxruntime/python/tools/transformers/fusion_biasgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_biasgelu.py @@ -1,13 +1,14 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from logging import getLogger -from onnx import helper -from onnx_model import OnnxModel + from fusion_base import Fusion from fusion_utils import NumpyHelper +from onnx import helper +from onnx_model import OnnxModel logger = getLogger(__name__) @@ -15,18 +16,18 @@ class FusionBiasGelu(Fusion): def __init__(self, model: OnnxModel, is_fastgelu): if is_fastgelu: - super().__init__(model, 'FastGelu', 'FastGelu', 'add bias') + super().__init__(model, "FastGelu", "FastGelu", "add bias") else: - super().__init__(model, 'BiasGelu', 'Gelu') + super().__init__(model, "BiasGelu", "Gelu") def fuse(self, node, input_name_to_nodes, output_name_to_node): gelu_op_type = node.op_type - fuse_op_type = 'BiasGelu' if gelu_op_type == 'Gelu' else 'FastGelu' + fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu" if len(node.input) != 1: return - nodes = self.model.match_parent_path(node, ['Add', 'MatMul'], [0, None]) + nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None]) if nodes is None: return (add, matmul) = nodes @@ -47,16 +48,19 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return subgraph_nodes = [node, add] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [node.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node(fuse_op_type, - inputs=[matmul.output[0], add.input[bias_index]], - outputs=node.output, - name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_")) + fused_node = helper.make_node( + fuse_op_type, + inputs=[matmul.output[0], add.input[bias_index]], + outputs=node.output, + name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"), + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index 25cc133755847..867a151bef6ca 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -1,26 +1,32 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from typing import Dict, List, Tuple, Union from logging import getLogger -from onnx import helper, TensorProto, NodeProto -from onnx_model import OnnxModel +from typing import Dict, List, Tuple, Union + from fusion_base import Fusion from fusion_utils import FusionUtils +from onnx import NodeProto, TensorProto, helper +from onnx_model import OnnxModel logger = getLogger(__name__) class FusionEmbedLayerNoMask(Fusion): """ - Fuse embedding layer into one node (EmbedLayerNormalization). - It supports the following model types: BERT, DistilBert, ALBert. + Fuse embedding layer into one node (EmbedLayerNormalization). + It supports the following model types: BERT, DistilBert, ALBert. """ - def __init__(self, model: OnnxModel, description: str = 'no mask'): - super().__init__(model, "EmbedLayerNormalization", ["LayerNormalization", "SkipLayerNormalization"], - description) + + def __init__(self, model: OnnxModel, description: str = "no mask"): + super().__init__( + model, + "EmbedLayerNormalization", + ["LayerNormalization", "SkipLayerNormalization"], + description, + ) self.utils = FusionUtils(model) self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True) # The following will be reset in each fuse call of FusionEmbedLayerNormalization @@ -28,18 +34,22 @@ def __init__(self, model: OnnxModel, description: str = 'no mask'): self.embed_node = None def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]: - gather_0_path = self.model.match_parent_path(add, ['Gather'], [0]) + gather_0_path = self.model.match_parent_path(add, ["Gather"], [0]) if gather_0_path is None: return None - gather_1_path = self.model.match_parent_path(add, ['Gather'], [1]) + gather_1_path = self.model.match_parent_path(add, ["Gather"], [1]) if gather_1_path is None: return None return gather_0_path[0], gather_1_path[0] - def check_attention_subgraph(self, layernorm: NodeProto, input_name_to_nodes: Dict[str, List[NodeProto]], - is_distil_bert: bool) -> bool: + def check_attention_subgraph( + self, + layernorm: NodeProto, + input_name_to_nodes: Dict[str, List[NodeProto]], + is_distil_bert: bool, + ) -> bool: """Check that LayerNormalization has a child of Attention node or subgraph like Attention. Args: @@ -50,10 +60,9 @@ def check_attention_subgraph(self, layernorm: NodeProto, input_name_to_nodes: Di Returns: bool: whether there is Attention node or subgraph like Attention """ - self.attention = self.model.find_first_child_by_type(layernorm, - 'Attention', - input_name_to_nodes, - recursive=False) + self.attention = self.model.find_first_child_by_type( + layernorm, "Attention", input_name_to_nodes, recursive=False + ) if self.attention is None: # In case user disables attention fusion, check whether subgraph looks like Attention. if layernorm.output[0] not in input_name_to_nodes: @@ -63,8 +72,11 @@ def check_attention_subgraph(self, layernorm: NodeProto, input_name_to_nodes: Di # For Albert, there is MatMul+Add after embedding layer before attention. if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes: grandchildren = input_name_to_nodes[children[0].output[0]] - if len(grandchildren) == 1 and grandchildren[0].op_type == "Add" and grandchildren[0].output[ - 0] in input_name_to_nodes: + if ( + len(grandchildren) == 1 + and grandchildren[0].op_type == "Add" + and grandchildren[0].output[0] in input_name_to_nodes + ): nodes = input_name_to_nodes[grandchildren[0].output[0]] for node in nodes: if node.op_type == "Attention": @@ -77,14 +89,20 @@ def check_attention_subgraph(self, layernorm: NodeProto, input_name_to_nodes: Di # Two Shape nodes might be merged by ORT if is_distil_bert: # SkipLayerNormailization might exist when model has been optimized by ORT first. - if children_types != ['MatMul', 'MatMul', 'MatMul', 'Shape', 'SkipLayerNormalization'] and \ - children_types != ['Add', 'MatMul', 'MatMul', 'MatMul', 'Shape', 'Shape'] and \ - children_types != ['Add', 'MatMul', 'MatMul', 'MatMul', 'Shape']: + if ( + children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"] + and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"] + and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"] + ): logger.debug("No Attention like subgraph in children of LayerNormalization") return False else: - if children_types != ['Add', 'MatMul', 'MatMul', 'MatMul'] and \ - children_types != ['MatMul', 'MatMul', 'MatMul', 'SkipLayerNormalization']: + if children_types != ["Add", "MatMul", "MatMul", "MatMul",] and children_types != [ + "MatMul", + "MatMul", + "MatMul", + "SkipLayerNormalization", + ]: logger.debug("No Attention like subgraph in children of LayerNormalization") return False return True @@ -110,9 +128,13 @@ def match_position_embedding_distilbert(self, position_embedding_gather, input_i Gather """ # remove after tests pass - path1 = self.model.match_parent_path(position_embedding_gather, ['Expand', 'Shape'], [1, 1]) + path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1]) if path1 is None: - path1 = self.model.match_parent_path(position_embedding_gather, ['Expand', 'Where', 'Reshape', 'Shape'], [1, 1, 2, 0]) + path1 = self.model.match_parent_path( + position_embedding_gather, + ["Expand", "Where", "Reshape", "Shape"], + [1, 1, 2, 0], + ) if path1 is None: return False @@ -120,14 +142,21 @@ def match_position_embedding_distilbert(self, position_embedding_gather, input_i if shape.input[0] != input_ids: return False - _, path2, _ = self.model.match_parent_paths(expand, [(['Unsqueeze', 'Range', 'Cast', 'Gather', 'Shape'], [0, 0, 1, 0, 0]), \ - (['Unsqueeze', 'Range', 'Gather', 'Shape'], [0, 0, 1, 0])], output_name_to_node) + _, path2, _ = self.model.match_parent_paths( + expand, + [ + (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]), + (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]), + ], + output_name_to_node, + ) if path2 is None: return False range_node = path2[1] - if not (self.utils.check_node_input_value(range_node, 0, 0) - and self.utils.check_node_input_value(range_node, 2, 1)): + if not ( + self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1) + ): return False gather_node = path2[-2] @@ -141,19 +170,19 @@ def match_position_embedding_distilbert(self, position_embedding_gather, input_i return True def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node): - """ Match position embedding path from input_ids to Gather for Roberta. + """Match position embedding path from input_ids to Gather for Roberta. - Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id): + Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id): (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather | ^ V | - +------------------------------+ + +------------------------------+ Roberta new pattern from transformers v4.9: (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather | ^ V | - +-------------------------------------------+ + +-------------------------------------------+ start_node = position_embedding_gather start_index = 1 @@ -209,22 +238,30 @@ def match_position_embedding_bert(self, position_embedding_gather, input_ids, ou | LayerNormalization """ - path = self.model.match_parent_path(position_embedding_gather, ['Slice', 'Unsqueeze'], [1, 2], - output_name_to_node) + path = self.model.match_parent_path( + position_embedding_gather, + ["Slice", "Unsqueeze"], + [1, 2], + output_name_to_node, + ) if path is None: return False slice, unsqueeze = path slice_weight = self.model.get_constant_value(slice.input[0]) - if not (slice_weight is not None and len(slice_weight.shape) == 2 and slice_weight.shape[0] == 1 \ - and self.utils.check_node_input_value(slice, 1, [0]) \ - and self.utils.check_node_input_value(slice, 3, [1]) \ - and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))): + if not ( + slice_weight is not None + and len(slice_weight.shape) == 2 + and slice_weight.shape[0] == 1 + and self.utils.check_node_input_value(slice, 1, [0]) + and self.utils.check_node_input_value(slice, 3, [1]) + and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1])) + ): return False opset_version = self.model.get_opset_version() if opset_version < 13: - if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]): + if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]): return False else: if not self.utils.check_node_input_value(unsqueeze, 1, [0]): @@ -257,7 +294,7 @@ def match_position_embedding(self, position_embedding_gather, input_ids, output_ # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel # related: https://github.com/huggingface/transformers/issues/10736 - #if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node): + # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node): # return True if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node): @@ -266,8 +303,7 @@ def match_position_embedding(self, position_embedding_gather, input_ids, output_ return False def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather): - """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs. - """ + """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs.""" input_ids = word_embedding_gather.input[1] segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None position_ids = position_embedding_gather.input[1] @@ -276,17 +312,25 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids) position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids) assert input_ids_shape and position_ids_shape - if not (len(input_ids_shape) == 2 and len(position_ids_shape) == 2 - and input_ids_shape[1] == position_ids_shape[1]): + if not ( + len(input_ids_shape) == 2 + and len(position_ids_shape) == 2 + and input_ids_shape[1] == position_ids_shape[1] + ): logger.info( - "Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {} vs {}" - .format(input_ids_shape, position_ids_shape)) + "Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {} vs {}".format( + input_ids_shape, position_ids_shape + ) + ) return False if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids): logger.info( - "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}". - format(input_ids_shape, self.shape_infer_helper.get_edge_shape(segment_ids))) + "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format( + input_ids_shape, + self.shape_infer_helper.get_edge_shape(segment_ids), + ) + ) return False word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0]) @@ -295,15 +339,21 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit return False position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0]) - if position_embedding_table is None or len(position_embedding_table.shape) != 2 or ( - word_embedding_table.shape[1] != position_embedding_table.shape[1]): + if ( + position_embedding_table is None + or len(position_embedding_table.shape) != 2 + or (word_embedding_table.shape[1] != position_embedding_table.shape[1]) + ): logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected") return False if segment_ids: segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0]) - if segment_embedding_table is None or len(segment_embedding_table.shape) != 2 or ( - word_embedding_table.shape[1] != segment_embedding_table.shape[1]): + if ( + segment_embedding_table is None + or len(segment_embedding_table.shape) != 2 + or (word_embedding_table.shape[1] != segment_embedding_table.shape[1]) + ): logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected") return False @@ -350,9 +400,16 @@ def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]: return int32_output, input_cast_node - def create_fused_node(self, input_ids: str, layernorm: NodeProto, word_embedding_gather: NodeProto, - position_embedding_gather: NodeProto, segment_embedding_gather: Union[None, NodeProto], - position_ids: str = None, embedding_sum_output = False): + def create_fused_node( + self, + input_ids: str, + layernorm: NodeProto, + word_embedding_gather: NodeProto, + position_embedding_gather: NodeProto, + segment_embedding_gather: Union[None, NodeProto], + position_ids: str = None, + embedding_sum_output=False, + ): """Create an EmbedLayerNormalization node. Note that segment embedding is optional. Args: @@ -368,7 +425,7 @@ def create_fused_node(self, input_ids: str, layernorm: NodeProto, word_embedding nodes_to_add = [] input_ids, _ = self.cast_to_int32(input_ids) - node_name = self.model.create_node_name('EmbedLayerNormalization') + node_name = self.model.create_node_name("EmbedLayerNormalization") if layernorm.op_type == "LayerNormalization": gamma = layernorm.input[1] @@ -382,17 +439,28 @@ def create_fused_node(self, input_ids: str, layernorm: NodeProto, word_embedding segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1]) embed_node_inputs = [ - input_ids, segment_ids, word_embedding_gather.input[0], position_embedding_gather.input[0], - segment_embedding_gather.input[0], gamma, beta + input_ids, + segment_ids, + word_embedding_gather.input[0], + position_embedding_gather.input[0], + segment_embedding_gather.input[0], + gamma, + beta, ] else: # no segment embedding embed_node_inputs = [ - input_ids, '', word_embedding_gather.input[0], position_embedding_gather.input[0], '', gamma, beta + input_ids, + "", + word_embedding_gather.input[0], + position_embedding_gather.input[0], + "", + gamma, + beta, ] if position_ids is not None: - #Adding an empty input for mask before position_ids - embed_node_inputs.append('') + # Adding an empty input for mask before position_ids + embed_node_inputs.append("") position_ids, _ = self.cast_to_int32(position_ids) embed_node_inputs.append(position_ids) @@ -400,22 +468,24 @@ def create_fused_node(self, input_ids: str, layernorm: NodeProto, word_embedding if embedding_sum_output: embed_node_outputs.append(node_name + "_embedding_sum") - embed_node = helper.make_node('EmbedLayerNormalization', - embed_node_inputs, - outputs=embed_node_outputs, - name=node_name) + embed_node = helper.make_node( + "EmbedLayerNormalization", + embed_node_inputs, + outputs=embed_node_outputs, + name=node_name, + ) embed_node.domain = "com.microsoft" # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization. for att in layernorm.attribute: - if att.name == 'epsilon': + if att.name == "epsilon": embed_node.attribute.extend([att]) # Set default value to 1e-12 if no attribute is found. # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later. if len(embed_node.attribute) == 0: - embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)]) + embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)]) # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add. nodes_to_add.append(embed_node) @@ -446,7 +516,7 @@ def is_embedding_sum_needed(self, add_before_layer_norm): return len(nodes) > 1 def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node): - #graph checks + # graph checks # gpt2 has no segment embedding, subgraph pattern is like # input_ids position_ids # | | @@ -484,8 +554,15 @@ def fuse_gpt2(self, layernorm, add_before_layernorm, input_name_to_nodes, output optional_embedding_sum_output = True # make the fused node - embed_node = self.create_fused_node(input_ids, layernorm, word_embedding_gather, position_embedding_gather, - None, position_ids, optional_embedding_sum_output) + embed_node = self.create_fused_node( + input_ids, + layernorm, + word_embedding_gather, + position_embedding_gather, + None, + position_ids, + optional_embedding_sum_output, + ) # direct the output to another add too self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0]) @@ -529,8 +606,9 @@ def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, if not self.check_embedding(word_embedding_gather, None, position_embedding_gather): return False - embed_node = self.create_fused_node(input_ids, layernorm, word_embedding_gather, position_embedding_gather, - None) + embed_node = self.create_fused_node( + input_ids, layernorm, word_embedding_gather, position_embedding_gather, None + ) self.finish_fusion(layernorm, embed_node) return True @@ -543,7 +621,7 @@ def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes """ - add_2_gather = self.model.match_parent_path(add_before_layernorm, ['Add'], [0]) + add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0]) if add_2_gather is None: return False @@ -558,7 +636,7 @@ def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False): return False - position_embedding_path = self.model.match_parent_path(add_before_layernorm, ['Gather'], [1]) + position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1]) if position_embedding_path is None: return False @@ -574,14 +652,19 @@ def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather): return False - embed_node = self.create_fused_node(input_ids, layernorm, word_embedding_gather, position_embedding_gather, - segment_embedding_gather) + embed_node = self.create_fused_node( + input_ids, + layernorm, + word_embedding_gather, + position_embedding_gather, + segment_embedding_gather, + ) self.finish_fusion(layernorm, embed_node) return True def fuse(self, node, input_name_to_nodes, output_name_to_node): if node.op_type == "LayerNormalization": - first_add_path = self.model.match_parent_path(node, ['Add'], [0]) + first_add_path = self.model.match_parent_path(node, ["Add"], [0]) if first_add_path is None: return add_before_layernorm = first_add_path[0] diff --git a/onnxruntime/python/tools/transformers/fusion_fastgelu.py b/onnxruntime/python/tools/transformers/fusion_fastgelu.py index a9f9d3f7e6879..a9f46585faad7 100644 --- a/onnxruntime/python/tools/transformers/fusion_fastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_fastgelu.py @@ -1,12 +1,13 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- -from typing import Dict, Optional +# -------------------------------------------------------------------------- from logging import getLogger +from typing import Dict, Optional + +from fusion_base import Fusion from onnx import helper from onnx_model import OnnxModel -from fusion_base import Fusion logger = getLogger(__name__) @@ -40,7 +41,7 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona if tanh_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[tanh_node.output[0]] - if len(children) != 1 or children[0].op_type != 'Add': + if len(children) != 1 or children[0].op_type != "Add": return add_after_tanh = children[0] @@ -50,11 +51,11 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona if add_after_tanh.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[add_after_tanh.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_after_tanh = children[0] - mul_half = self.model.match_parent(mul_after_tanh, 'Mul', None, output_name_to_node) + mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node) if mul_half is None: return @@ -64,10 +65,10 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona root_input = mul_half.input[0 if i == 1 else 1] - #root_node could be None when root_input is graph input + # root_node could be None when root_input is graph input root_node = self.model.get_parent(mul_half, 0 if i == 1 else 1, output_name_to_node) - mul_before_tanh = self.model.match_parent(tanh_node, 'Mul', 0, output_name_to_node) + mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) if mul_before_tanh is None: return @@ -75,15 +76,17 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona if i < 0: return - add_before_tanh = self.model.match_parent(mul_before_tanh, 'Add', 0 if i == 1 else 1, output_name_to_node) + add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node) if add_before_tanh is None: return - mul_after_pow = self.model.match_parent(add_before_tanh, - 'Mul', - None, - output_name_to_node, - exclude=[root_node] if root_node else []) + mul_after_pow = self.model.match_parent( + add_before_tanh, + "Mul", + None, + output_name_to_node, + exclude=[root_node] if root_node else [], + ) if mul_after_pow is None: return @@ -91,7 +94,7 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona if i < 0: return - pow = self.model.match_parent(mul_after_pow, 'Pow', 0 if i == 1 else 1, output_name_to_node) + pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node) if pow is None: return @@ -102,17 +105,30 @@ def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optiona return subgraph_nodes = [ - mul_after_tanh, mul_half, add_after_tanh, tanh_node, mul_before_tanh, add_before_tanh, mul_after_pow, pow + mul_after_tanh, + mul_half, + add_after_tanh, + tanh_node, + mul_before_tanh, + add_before_tanh, + mul_after_pow, + pow, ] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [mul_after_tanh.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [mul_after_tanh.output[0]], + input_name_to_nodes, + output_name_to_node, + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node('FastGelu', - inputs=[root_input], - outputs=mul_after_tanh.output, - name=self.model.create_node_name('FastGelu')) + fused_node = helper.make_node( + "FastGelu", + inputs=[root_input], + outputs=mul_after_tanh.output, + name=self.model.create_node_name("FastGelu"), + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name @@ -134,7 +150,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if tanh_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[tanh_node.output[0]] - if len(children) != 1 or children[0].op_type != 'Add': + if len(children) != 1 or children[0].op_type != "Add": return add_after_tanh = children[0] @@ -144,7 +160,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if add_after_tanh.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[add_after_tanh.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_half = children[0] @@ -155,17 +171,19 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if mul_half.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[mul_half.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_after_mul_half = children[0] - root_node = self.model.get_parent(mul_after_mul_half, - 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1, - output_name_to_node) + root_node = self.model.get_parent( + mul_after_mul_half, + 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1, + output_name_to_node, + ) if root_node is None: return - mul_before_tanh = self.model.match_parent(tanh_node, 'Mul', 0, output_name_to_node) + mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) if mul_before_tanh is None: return @@ -173,11 +191,11 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if i < 0: return - add_before_tanh = self.model.match_parent(mul_before_tanh, 'Add', 0 if i == 1 else 1, output_name_to_node) + add_before_tanh = self.model.match_parent(mul_before_tanh, "Add", 0 if i == 1 else 1, output_name_to_node) if add_before_tanh is None: return - mul_after_pow = self.model.match_parent(add_before_tanh, 'Mul', None, output_name_to_node, exclude=[root_node]) + mul_after_pow = self.model.match_parent(add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node]) if mul_after_pow is None: return @@ -185,7 +203,7 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if i < 0: return - pow = self.model.match_parent(mul_after_pow, 'Pow', 0 if i == 1 else 1, output_name_to_node) + pow = self.model.match_parent(mul_after_pow, "Pow", 0 if i == 1 else 1, output_name_to_node) if pow is None: return @@ -196,18 +214,30 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict return subgraph_nodes = [ - mul_after_mul_half, mul_half, add_after_tanh, tanh_node, mul_before_tanh, add_before_tanh, mul_after_pow, - pow + mul_after_mul_half, + mul_half, + add_after_tanh, + tanh_node, + mul_before_tanh, + add_before_tanh, + mul_after_pow, + pow, ] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [mul_after_mul_half.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [mul_after_mul_half.output[0]], + input_name_to_nodes, + output_name_to_node, + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node('FastGelu', - inputs=[root_node.output[0]], - outputs=mul_after_mul_half.output, - name=self.model.create_node_name('FastGelu')) + fused_node = helper.make_node( + "FastGelu", + inputs=[root_node.output[0]], + outputs=mul_after_mul_half.output, + name=self.model.create_node_name("FastGelu"), + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name @@ -215,25 +245,25 @@ def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]: """ - OpenAI's gelu implementation, also used in Megatron: - Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))) - - Fuse subgraph into a FastGelu node: - +------------ Mul (B=0.79788456) -------------------+ - | | - +-------------------------------+ | - | | | - | v v - [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul--> - | ^ - | | - +-----------> Mul (B=0.5) --------------------------------------------------------+ - """ + OpenAI's gelu implementation, also used in Megatron: + Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x))) + + Fuse subgraph into a FastGelu node: + +------------ Mul (B=0.79788456) -------------------+ + | | + +-------------------------------+ | + | | | + | v v + [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul--> + | ^ + | | + +-----------> Mul (B=0.5) --------------------------------------------------------+ + """ if tanh_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[tanh_node.output[0]] - if len(children) != 1 or children[0].op_type != 'Add': + if len(children) != 1 or children[0].op_type != "Add": return add_after_tanh = children[0] @@ -243,11 +273,11 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if add_after_tanh.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[add_after_tanh.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_last = children[0] - mul_half = self.model.match_parent(mul_last, 'Mul', None, output_name_to_node) + mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node) if mul_half is None: return @@ -257,18 +287,18 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict root_input = mul_half.input[0 if i == 1 else 1] - mul_before_tanh = self.model.match_parent(tanh_node, 'Mul', 0, output_name_to_node) + mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node) if mul_before_tanh is None: return - add_1 = self.model.match_parent(mul_before_tanh, 'Add', None, output_name_to_node) + add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node) if add_1 is None: return j = self.model.find_constant_input(add_1, 1.0) if j < 0: return - mul_7978 = self.model.match_parent(mul_before_tanh, 'Mul', None, output_name_to_node) + mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node) if mul_7978 is None: return k = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001) @@ -277,7 +307,7 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict if mul_7978.input[0 if k == 1 else 1] != root_input: return - mul_before_add_1 = self.model.match_parent(add_1, 'Mul', 0 if j == 1 else 1, output_name_to_node) + mul_before_add_1 = self.model.match_parent(add_1, "Mul", 0 if j == 1 else 1, output_name_to_node) if mul_before_add_1 is None: return @@ -288,7 +318,7 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict else: return - mul_0447 = self.model.match_parent(mul_before_add_1, 'Mul', another, output_name_to_node) + mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", another, output_name_to_node) if mul_0447 is None: return m = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001) @@ -299,17 +329,31 @@ def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict return subgraph_nodes = [ - mul_0447, mul_before_add_1, add_1, mul_before_tanh, tanh_node, add_after_tanh, mul_7978, mul_half, mul_last + mul_0447, + mul_before_add_1, + add_1, + mul_before_tanh, + tanh_node, + add_after_tanh, + mul_7978, + mul_half, + mul_last, ] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [mul_last.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [mul_last.output[0]], + input_name_to_nodes, + output_name_to_node, + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node('FastGelu', - inputs=[root_input], - outputs=mul_last.output, - name=self.model.create_node_name('FastGelu')) + fused_node = helper.make_node( + "FastGelu", + inputs=[root_input], + outputs=mul_last.output, + name=self.model.create_node_name("FastGelu"), + ) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_gelu.py b/onnxruntime/python/tools/transformers/fusion_gelu.py index 0a7d657dd3f25..8626ca1482104 100644 --- a/onnxruntime/python/tools/transformers/fusion_gelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gelu.py @@ -1,12 +1,13 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- -from typing import Dict, Optional +# -------------------------------------------------------------------------- from logging import getLogger +from typing import Dict, Optional + +from fusion_base import Fusion from onnx import helper from onnx_model import OnnxModel -from fusion_base import Fusion logger = getLogger(__name__) @@ -45,7 +46,7 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if erf_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[erf_node.output[0]] - if len(children) != 1 or children[0].op_type != 'Add': + if len(children) != 1 or children[0].op_type != "Add": return add_after_erf = children[0] @@ -55,11 +56,11 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if add_after_erf.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[add_after_erf.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_after_erf = children[0] - div = self.model.match_parent(erf_node, 'Div', 0, output_name_to_node) + div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node) if div is None: return @@ -71,14 +72,14 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0 if subgraph_input == mul_after_erf.input[another]: # pattern 2 children = input_name_to_nodes[mul_after_erf.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_half = children[0] if not self.model.has_constant_input(mul_half, 0.5): return subgraph_output = mul_half.output[0] else: # pattern 1 - mul_half = self.model.match_parent(mul_after_erf, 'Mul', another, output_name_to_node) + mul_half = self.model.match_parent(mul_after_erf, "Mul", another, output_name_to_node) if mul_half is None: return @@ -91,12 +92,13 @@ def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) subgraph_output = mul_after_erf.output[0] subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node('Gelu', inputs=[subgraph_input], outputs=[subgraph_output]) + fused_node = helper.make_node("Gelu", inputs=[subgraph_input], outputs=[subgraph_output]) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name @@ -117,7 +119,7 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if erf_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[erf_node.output[0]] - if len(children) != 1 or children[0].op_type != 'Add': + if len(children) != 1 or children[0].op_type != "Add": return add_after_erf = children[0] @@ -127,7 +129,7 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if add_after_erf.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[add_after_erf.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_after_erf = children[0] @@ -137,17 +139,17 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if mul_after_erf.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[mul_after_erf.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul = children[0] - div = self.model.match_parent(erf_node, 'Div', 0, output_name_to_node) + div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node) if div is None: return sqrt_node = None if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1: - sqrt_node = self.model.match_parent(div, 'Sqrt', 1, output_name_to_node) + sqrt_node = self.model.match_parent(div, "Sqrt", 1, output_name_to_node) if sqrt_node is None: return if not self.model.has_constant_input(sqrt_node, 2.0): @@ -164,12 +166,13 @@ def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if sqrt_node: subgraph_nodes.append(sqrt_node) - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node('Gelu', inputs=[root_node.output[0]], outputs=[mul.output[0]]) + fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[mul.output[0]]) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name @@ -191,7 +194,7 @@ def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if erf_node.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[erf_node.output[0]] - if len(children) != 1 or children[0].op_type != 'Add': + if len(children) != 1 or children[0].op_type != "Add": return add_after_erf = children[0] @@ -201,14 +204,14 @@ def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if add_after_erf.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[add_after_erf.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return mul_half = children[0] if not self.model.has_constant_input(mul_half, 0.5): return - first_mul = self.model.match_parent(erf_node, 'Mul', 0, output_name_to_node) + first_mul = self.model.match_parent(erf_node, "Mul", 0, output_name_to_node) if first_mul is None: return @@ -223,7 +226,7 @@ def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) if mul_half.output[0] not in input_name_to_nodes: return children = input_name_to_nodes[mul_half.output[0]] - if len(children) != 1 or children[0].op_type != 'Mul': + if len(children) != 1 or children[0].op_type != "Mul": return last_mul = children[0] @@ -231,12 +234,16 @@ def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) return subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [last_mul.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + [last_mul.output[0]], + input_name_to_nodes, + output_name_to_node, + ): return self.nodes_to_remove.extend(subgraph_nodes) - fused_node = helper.make_node('Gelu', inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) + fused_node = helper.make_node("Gelu", inputs=[root_node.output[0]], outputs=[last_mul.output[0]]) fused_node.domain = "com.microsoft" self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_gelu_approximation.py b/onnxruntime/python/tools/transformers/fusion_gelu_approximation.py index 253afabc6ad23..ba231e9e05ea4 100644 --- a/onnxruntime/python/tools/transformers/fusion_gelu_approximation.py +++ b/onnxruntime/python/tools/transformers/fusion_gelu_approximation.py @@ -1,23 +1,26 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from logging import getLogger + +from fusion_base import Fusion from onnx import helper from onnx_model import OnnxModel -from fusion_base import Fusion class FusionGeluApproximation(Fusion): def __init__(self, model: OnnxModel): - super().__init__(model, 'FastGelu', ['Gelu', 'BiasGelu'], 'GeluApproximation') + super().__init__(model, "FastGelu", ["Gelu", "BiasGelu"], "GeluApproximation") def fuse(self, node, input_name_to_nodes, output_name_to_node): - new_node = helper.make_node("FastGelu", - inputs=node.input, - outputs=node.output, - name=self.model.create_node_name("FastGelu", node.op_type + "_Approximation")) + new_node = helper.make_node( + "FastGelu", + inputs=node.input, + outputs=node.output, + name=self.model.create_node_name("FastGelu", node.op_type + "_Approximation"), + ) new_node.domain = "com.microsoft" self.nodes_to_remove.append(node) self.nodes_to_add.append(new_node) diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py index 0dea0e385fc90..281b758859626 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention.py @@ -1,20 +1,21 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- -import numpy as np +# -------------------------------------------------------------------------- from logging import getLogger -from onnx import helper, numpy_helper, TensorProto -from onnx_model import OnnxModel + +import numpy as np from fusion_base import Fusion from fusion_utils import FusionUtils +from onnx import TensorProto, helper, numpy_helper +from onnx_model import OnnxModel logger = getLogger(__name__) class FusionGptAttentionPastBase(Fusion): - """Base class for GPT Attention Fusion with past state - """ + """Base class for GPT Attention Fusion with past state""" + def __init__(self, model: OnnxModel, num_heads: int): super().__init__(model, "Attention", "LayerNormalization", "with past") self.num_heads = num_heads @@ -41,7 +42,7 @@ def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node): # | # {present} gather = self.model.get_parent(concat_v, 0, output_name_to_node) - if gather.op_type != 'Gather': + if gather.op_type != "Gather": logger.debug("match_past_pattern_1: expect Gather for past") return None @@ -51,10 +52,10 @@ def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node): past = gather.input[0] parent = self.model.get_parent(concat_k, 0, output_name_to_node) - if parent.op_type == 'Gather': + if parent.op_type == "Gather": gather_past_k = parent else: - past_k_nodes = self.model.match_parent_path(concat_k, ['Transpose', 'Gather'], [0, 0]) + past_k_nodes = self.model.match_parent_path(concat_k, ["Transpose", "Gather"], [0, 0]) if past_k_nodes is None: logger.debug("match_past_pattern_1: failed match Transpose and Gather") return None @@ -93,7 +94,7 @@ def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node): # {present} # squeeze = self.model.get_parent(concat_v, 0, output_name_to_node) - if squeeze.op_type != 'Squeeze': + if squeeze.op_type != "Squeeze": logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v") return None @@ -104,11 +105,11 @@ def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node): opset_version = self.model.get_opset_version() if opset_version < 13: - if not FusionUtils.check_node_attribute(squeeze, 'axes', [0]): + if not FusionUtils.check_node_attribute(squeeze, "axes", [0]): logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path") return None - if not FusionUtils.check_node_attribute(split, 'split', [1, 1]): + if not FusionUtils.check_node_attribute(split, "split", [1, 1]): logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path") return None else: @@ -120,12 +121,12 @@ def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node): logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path") return None - if not FusionUtils.check_node_attribute(split, 'axis', 0, default_value=0): + if not FusionUtils.check_node_attribute(split, "axis", 0, default_value=0): logger.debug("match_past_pattern_2: attribute axis of Split are not expected in past path") return None past = split.input[0] - past_k_nodes = self.model.match_parent_path(concat_k, ['Squeeze', 'Split'], [0, 0]) + past_k_nodes = self.model.match_parent_path(concat_k, ["Squeeze", "Split"], [0, 0]) if past_k_nodes is None: logger.debug("match_past_pattern_2: failed to match past_k_nodes path") return None @@ -138,17 +139,15 @@ def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node): return past def match_present(self, concat_v, input_name_to_nodes): - unsqueeze_present_v = self.model.find_first_child_by_type(concat_v, - 'Unsqueeze', - input_name_to_nodes, - recursive=False) + unsqueeze_present_v = self.model.find_first_child_by_type( + concat_v, "Unsqueeze", input_name_to_nodes, recursive=False + ) if not unsqueeze_present_v: logger.info("expect unsqueeze for present") return None - concat_present = self.model.find_first_child_by_type(unsqueeze_present_v, - 'Concat', - input_name_to_nodes, - recursive=False) + concat_present = self.model.find_first_child_by_type( + unsqueeze_present_v, "Concat", input_name_to_nodes, recursive=False + ) if not concat_present: logger.info("expect concat for present") return None @@ -172,31 +171,50 @@ class FusionGptAttention(FusionGptAttentionPastBase): """ Fuse GPT-2 Attention with past state subgraph into one Attention node. """ + def __init__(self, model: OnnxModel, num_heads: int): super().__init__(model, num_heads) - def create_attention_node(self, fc_weight, fc_bias, gemm_qkv, past, present, input, output, mask, - is_unidirectional): - attention_node_name = self.model.create_node_name('GptAttention') - attention_node = helper.make_node('Attention', - inputs=[input, fc_weight, fc_bias, mask, past], - outputs=[attention_node_name + "_output", present], - name=attention_node_name) + def create_attention_node( + self, + fc_weight, + fc_bias, + gemm_qkv, + past, + present, + input, + output, + mask, + is_unidirectional, + ): + attention_node_name = self.model.create_node_name("GptAttention") + attention_node = helper.make_node( + "Attention", + inputs=[input, fc_weight, fc_bias, mask, past], + outputs=[attention_node_name + "_output", present], + name=attention_node_name, + ) attention_node.domain = "com.microsoft" - attention_node.attribute.extend([ - helper.make_attribute("num_heads", self.num_heads), - helper.make_attribute("unidirectional", 1 if is_unidirectional else 0) - ]) - - matmul_node = helper.make_node('MatMul', - inputs=[attention_node_name + "_output", gemm_qkv.input[1]], - outputs=[attention_node_name + "_matmul_output"], - name=attention_node_name + "_matmul") - - add_node = helper.make_node('Add', - inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]], - outputs=[output], - name=attention_node_name + "_add") + attention_node.attribute.extend( + [ + helper.make_attribute("num_heads", self.num_heads), + helper.make_attribute("unidirectional", 1 if is_unidirectional else 0), + ] + ) + + matmul_node = helper.make_node( + "MatMul", + inputs=[attention_node_name + "_output", gemm_qkv.input[1]], + outputs=[attention_node_name + "_matmul_output"], + name=attention_node_name + "_matmul", + ) + + add_node = helper.make_node( + "Add", + inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]], + outputs=[output], + name=attention_node_name + "_add", + ) self.nodes_to_add.extend([attention_node, matmul_node, add_node]) self.node_name_to_graph_name[attention_node.name] = self.this_graph_name self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name @@ -208,28 +226,44 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return_indice = [] qkv_nodes = self.model.match_parent_path( normalize_node, - ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'], - [0, None, 0, 0, 0, 0, 0], + ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], + [0, None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, - return_indice=return_indice - ) # yapf: disable + return_indice=return_indice, + ) # yapf: disable if qkv_nodes is None: return - (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv, matmul_qkv) = qkv_nodes + ( + add_qkv, + reshape_qkv, + gemm_qkv, + reshape_1, + reshape_2, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes another_input = add_qkv.input[1 - return_indice[0]] - v_nodes = self.model.match_parent_path(matmul_qkv, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0]) + v_nodes = self.model.match_parent_path(matmul_qkv, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0]) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return (concat_v, transpose_v, reshape_v, split_fc) = v_nodes - fc_nodes = self.model.match_parent_path(split_fc, ['Reshape', 'Gemm', 'Reshape', 'LayerNormalization'], - [0, 0, 0, 0], output_name_to_node) + fc_nodes = self.model.match_parent_path( + split_fc, + ["Reshape", "Gemm", "Reshape", "LayerNormalization"], + [0, 0, 0, 0], + output_name_to_node, + ) if fc_nodes is None: - fc_nodes = self.model.match_parent_path(split_fc, ['Add', 'MatMul', 'LayerNormalization'], [0, None, 0], - output_name_to_node) + fc_nodes = self.model.match_parent_path( + split_fc, + ["Add", "MatMul", "LayerNormalization"], + [0, None, 0], + output_name_to_node, + ) if fc_nodes is None: logger.debug("fuse_attention: failed to match fc path") return @@ -250,13 +284,25 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): slice_mask = None input_mask_nodes = None concat_k_to_match = None - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0]) + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0]) if qk_nodes is not None: (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes mask_nodes = self.model.match_parent_path( sub_qk, - ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'], - [1, 0, 1, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable + [ + "Mul", + "Sub", + "Slice", + "Slice", + "Unsqueeze", + "Sub", + "Squeeze", + "Slice", + "Shape", + "Div", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 0], + ) # yapf: disable if mask_nodes is None: logger.debug("fuse_attention: failed to match unidirectional mask path") return @@ -269,8 +315,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): else: # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0. i, qk_nodes, _ = self.model.match_parent_paths( - matmul_qkv, [(['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]), - (['Softmax', 'Add', 'Where', 'Div', 'MatMul'], [0, 0, None, 1, 0])], output_name_to_node) + matmul_qkv, + [ + (["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]), + (["Softmax", "Add", "Where", "Div", "MatMul"], [0, 0, None, 1, 0]), + ], + output_name_to_node, + ) if qk_nodes is None: logger.debug("fuse_attention: failed to match qk nodes") return @@ -284,20 +335,40 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): _, input_mask_nodes, _ = self.model.match_parent_paths( add_qk, [ - (['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0, 0]), - (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze', 'Reshape'], [None, 0, 1, 0, 0]), - (['Mul', 'Sub', 'Unsqueeze', 'Unsqueeze'], [None, 0, 1, 0]), # useless cast and reshape are removed. + ( + ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze", "Reshape"], + [None, 0, 1, 0, 0, 0], + ), + ( + ["Mul", "Sub", "Unsqueeze", "Unsqueeze", "Reshape"], + [None, 0, 1, 0, 0], + ), + ( + ["Mul", "Sub", "Unsqueeze", "Unsqueeze"], + [None, 0, 1, 0], + ), # useless cast and reshape are removed. ], - output_name_to_node) # yapf: disable + output_name_to_node, + ) # yapf: disable if input_mask_nodes is None: logger.debug("fuse_attention: failed to match input attention mask path") return mask_nodes = self.model.match_parent_path( where_qk, - ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape'], - [ 0, 0, 0, 1, 0, 0, 0, 0], - output_name_to_node) # yapf: disable + [ + "Cast", + "Slice", + "Slice", + "Unsqueeze", + "Sub", + "Squeeze", + "Slice", + "Shape", + ], + [0, 0, 0, 1, 0, 0, 0, 0], + output_name_to_node, + ) # yapf: disable if mask_nodes is None: # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep. logger.debug("fuse_attention: failed to match mask path") @@ -318,8 +389,9 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Validate that the mask data is either lower triangular (unidirectional) or all ones mask_data = numpy_helper.to_array(self.model.get_initializer(slice_mask.input[0])) - if not (len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) - and mask_data.shape[2] == mask_data.shape[3]): + if not ( + len(mask_data.shape) == 4 and mask_data.shape[:2] == (1, 1) and mask_data.shape[2] == mask_data.shape[3] + ): logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW") return if np.allclose(mask_data, np.ones_like(mask_data)): @@ -328,7 +400,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones") return - q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0]) + q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0]) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return @@ -337,11 +409,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: skip since split_fc != split_q") return - k_nodes = self.model.match_parent_path(matmul_qk, ['Concat', 'Transpose', 'Reshape', 'Split'], [1, 1, 0, 0]) + k_nodes = self.model.match_parent_path(matmul_qk, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0]) if k_nodes is None: # This pattern is from pytorch 1.7.1 and transformers 4.6.1 - k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Concat', 'Transpose', 'Reshape', 'Split'], - [1, 0, 1, 0, 0]) + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "Transpose", "Reshape", "Split"], + [1, 0, 1, 0, 0], + ) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return @@ -357,14 +432,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: skip since concat_k != concat_k_to_match") return - attention_mask_input_name = '' + attention_mask_input_name = "" if input_mask_nodes is not None: input_name = input_mask_nodes[-1].input[0] attention_mask_input_name = self.cast_attention_mask(input_name) # Match past and present paths - past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or \ - self.match_past_pattern_2(concat_k, concat_v, output_name_to_node) + past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or self.match_past_pattern_2( + concat_k, concat_v, output_name_to_node + ) if past is None: logger.info("fuse_attention: failed to match past path") return @@ -380,8 +456,17 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.info("expect present to be graph output") return - self.create_attention_node(fc_weight, fc_bias, gemm_qkv, past, present, layernorm_before_attention.output[0], - reshape_qkv.output[0], attention_mask_input_name, is_unidirectional) + self.create_attention_node( + fc_weight, + fc_bias, + gemm_qkv, + past, + present, + layernorm_before_attention.output[0], + reshape_qkv.output[0], + attention_mask_input_name, + is_unidirectional, + ) # we rely on prune_graph() to clean old subgraph nodes: # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv] diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py index 5418ccf513c77..21def1eb4e253 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_megatron.py @@ -1,14 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- -import numpy as np +# -------------------------------------------------------------------------- from logging import getLogger -from onnx import helper, numpy_helper, TensorProto -from onnx_model import OnnxModel + +import numpy as np from fusion_base import Fusion -from fusion_utils import FusionUtils from fusion_gpt_attention import FusionGptAttentionPastBase +from fusion_utils import FusionUtils +from onnx import TensorProto, helper, numpy_helper +from onnx_model import OnnxModel logger = getLogger(__name__) @@ -21,24 +22,43 @@ class FusionGptAttentionMegatron(FusionGptAttentionPastBase): """ Fuse GPT-2 Attention with past state subgraph from Megatron into one Attention node. """ + def __init__(self, model: OnnxModel, num_heads: int): super().__init__(model, num_heads) - def fuse_attention_node(self, matmul_before_split, add_before_split, past, present, input, reshape_qkv, mask): - attention_node_name = self.model.create_node_name('GptAttention') + def fuse_attention_node( + self, + matmul_before_split, + add_before_split, + past, + present, + input, + reshape_qkv, + mask, + ): + attention_node_name = self.model.create_node_name("GptAttention") int32_mask = self.cast_attention_mask(mask) output = reshape_qkv.output[0] i = 1 if (add_before_split.input[0] == matmul_before_split.output[0]) else 0 attention_node = helper.make_node( - 'Attention', - inputs=[input, matmul_before_split.input[1], add_before_split.input[i], int32_mask, past], + "Attention", + inputs=[ + input, + matmul_before_split.input[1], + add_before_split.input[i], + int32_mask, + past, + ], outputs=[output, present], - name=attention_node_name) + name=attention_node_name, + ) attention_node.domain = "com.microsoft" - attention_node.attribute.extend([ - helper.make_attribute("num_heads", self.num_heads), - helper.make_attribute("unidirectional", 0) # unidirectional shall not be ON for 4D attention mask - ]) + attention_node.attribute.extend( + [ + helper.make_attribute("num_heads", self.num_heads), + helper.make_attribute("unidirectional", 0), # unidirectional shall not be ON for 4D attention mask + ] + ) nodes_to_add = [attention_node] self.nodes_to_add.extend(nodes_to_add) @@ -53,9 +73,8 @@ def fuse_attention_node(self, matmul_before_split, add_before_split, past, prese def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention): mask_nodes = self.model.match_parent_path( - sub_qk, - ['Mul', 'Sub', 'Slice', 'Slice'], - [1, 0, 1, 0]) # yapf: disable + sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0] + ) # yapf: disable if mask_nodes is None: logger.debug("fuse_attention: failed to match unidirectional mask path") return None @@ -97,27 +116,34 @@ def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention): logger.debug("fuse_attention failed: slice_mask input 4 (steps) is not constant [1]") return None - last_slice_path = self.model.match_parent_path(last_slice_mask, ['Unsqueeze', 'Gather', 'Shape', 'MatMul'], - [2, 0, 0, 0]) + last_slice_path = self.model.match_parent_path( + last_slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0] + ) if last_slice_path is None or last_slice_path[-1] != matmul_qk: logger.debug("fuse_attention: failed to match last slice path") return None - first_slice_path = self.model.match_parent_path(slice_mask, ['Unsqueeze', 'Gather', 'Shape', 'MatMul'], - [2, 0, 0, 0]) + first_slice_path = self.model.match_parent_path( + slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0] + ) if first_slice_path is None or first_slice_path[-1] != matmul_qk: logger.debug("fuse_attention: failed to match first slice path") return None - first_slice_sub = self.model.match_parent_path(slice_mask, ['Unsqueeze', 'Sub', 'Gather', 'Shape', 'MatMul'], - [1, 0, 0, 0, 0]) + first_slice_sub = self.model.match_parent_path( + slice_mask, + ["Unsqueeze", "Sub", "Gather", "Shape", "MatMul"], + [1, 0, 0, 0, 0], + ) if first_slice_sub is None or first_slice_sub[-1] != matmul_qk: logger.debug("fuse_attention: failed to match last slice sub path") return None - first_slice_sub_1 = self.model.match_parent_path(slice_mask, - ['Unsqueeze', 'Sub', 'Gather', 'Shape', 'LayerNormalization'], - [1, 0, 1, 0, 0]) + first_slice_sub_1 = self.model.match_parent_path( + slice_mask, + ["Unsqueeze", "Sub", "Gather", "Shape", "LayerNormalization"], + [1, 0, 1, 0, 0], + ) if first_slice_sub_1 is None or first_slice_sub_1[-1] != layernorm_before_attention: logger.debug("fuse_attention: failed to match last slice sub path 1") return None @@ -130,30 +156,53 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): qkv_nodes = self.model.match_parent_path( normalize_node, - ['Add', 'Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], - [ 0, 1, None, 0, 0, 0], + ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, 1, None, 0, 0, 0], output_name_to_node=output_name_to_node, - ) # yapf: disable + ) # yapf: disable if qkv_nodes is None: return - (add_skip, add_after_attention, matmul_after_attention, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + ( + add_skip, + add_after_attention, + matmul_after_attention, + reshape_qkv, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes skip_input = add_skip.input[0] v_nodes = self.model.match_parent_path( matmul_qkv, - ['Concat', 'Transpose', 'Reshape', 'Split', 'Add', 'MatMul', 'LayerNormalization'], - [1, 1, 0, 0, 0, None, 0]) # yapf: disable + [ + "Concat", + "Transpose", + "Reshape", + "Split", + "Add", + "MatMul", + "LayerNormalization", + ], + [1, 1, 0, 0, 0, None, 0], + ) # yapf: disable if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return - (concat_v, transpose_v, reshape_v, split_v, add_before_split, matmul_before_split, - layernorm_before_attention) = v_nodes + ( + concat_v, + transpose_v, + reshape_v, + split_v, + add_before_split, + matmul_before_split, + layernorm_before_attention, + ) = v_nodes if skip_input != layernorm_before_attention.input[0]: logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]") return - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'MatMul'], [0, 0, 0, 0]) + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "MatMul"], [0, 0, 0, 0]) if qk_nodes is None: logger.debug("fuse_attention: failed to match qk path") return None @@ -164,7 +213,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_mask = self.match_mask(sub_qk, mul_qk, matmul_qk, layernorm_before_attention) - q_nodes = self.model.match_parent_path(matmul_qk, ['Div', 'Transpose', 'Reshape', 'Split'], [0, 0, 0, 0]) + q_nodes = self.model.match_parent_path(matmul_qk, ["Div", "Transpose", "Reshape", "Split"], [0, 0, 0, 0]) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return @@ -173,9 +222,11 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: skip since split_v != split_q") return - k_nodes = self.model.match_parent_path(matmul_qk, - ['Div', 'Transpose', 'Concat', 'Transpose', 'Reshape', 'Split'], - [1, 0, 0, 1, 0, 0]) + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Div", "Transpose", "Concat", "Transpose", "Reshape", "Split"], + [1, 0, 0, 1, 0, 0], + ) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return @@ -185,8 +236,14 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return i, value = self.model.get_constant_input(reshape_k) - if not (isinstance(value, np.ndarray) and list(value.shape) == [4] and value[0] == 0 and value[1] == 0 - and value[2] > 0 and value[3] > 0): + if not ( + isinstance(value, np.ndarray) + and list(value.shape) == [4] + and value[0] == 0 + and value[1] == 0 + and value[2] > 0 + and value[3] > 0 + ): logger.debug("fuse_attention: reshape constant input is not [0, 0, N, H]") return @@ -224,5 +281,12 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.info("fuse_attention: expect present to be graph output") return - self.fuse_attention_node(matmul_before_split, add_before_split, past, present, - layernorm_before_attention.output[0], reshape_qkv, attention_mask) + self.fuse_attention_node( + matmul_before_split, + add_before_split, + past, + present, + layernorm_before_attention.output[0], + reshape_qkv, + attention_mask, + ) diff --git a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py index 2c54cfbd811d4..c4e622da2f767 100644 --- a/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py +++ b/onnxruntime/python/tools/transformers/fusion_gpt_attention_no_past.py @@ -1,13 +1,14 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- -import numpy as np +# -------------------------------------------------------------------------- from logging import getLogger -from onnx import helper, numpy_helper, TensorProto -from onnx_model import OnnxModel + +import numpy as np from fusion_base import Fusion from fusion_utils import FusionUtils +from onnx import TensorProto, helper, numpy_helper +from onnx_model import OnnxModel logger = getLogger(__name__) @@ -17,31 +18,41 @@ class FusionGptAttentionNoPast(Fusion): Fuse GPT-2 Attention without past state into one Attention node. This does not support attention_mask graph input right now. """ + def __init__(self, model: OnnxModel, num_heads: int): super().__init__(model, "Attention", "LayerNormalization", "without past") # TODO: detect num_heads from graph like FusionAttention self.num_heads = num_heads def create_attention_node(self, gemm, gemm_qkv, input, output): - attention_node_name = self.model.create_node_name('Attention') - attention_node = helper.make_node('Attention', - inputs=[input, gemm.input[1], gemm.input[2]], - outputs=[attention_node_name + "_output"], - name=attention_node_name) + attention_node_name = self.model.create_node_name("Attention") + attention_node = helper.make_node( + "Attention", + inputs=[input, gemm.input[1], gemm.input[2]], + outputs=[attention_node_name + "_output"], + name=attention_node_name, + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend( - [helper.make_attribute("num_heads", self.num_heads), - helper.make_attribute("unidirectional", 1)]) - - matmul_node = helper.make_node('MatMul', - inputs=[attention_node_name + "_output", gemm_qkv.input[1]], - outputs=[attention_node_name + "_matmul_output"], - name=attention_node_name + "_matmul") - - add_node = helper.make_node('Add', - inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]], - outputs=[output], - name=attention_node_name + "_add") + [ + helper.make_attribute("num_heads", self.num_heads), + helper.make_attribute("unidirectional", 1), + ] + ) + + matmul_node = helper.make_node( + "MatMul", + inputs=[attention_node_name + "_output", gemm_qkv.input[1]], + outputs=[attention_node_name + "_matmul_output"], + name=attention_node_name + "_matmul", + ) + + add_node = helper.make_node( + "Add", + inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]], + outputs=[output], + name=attention_node_name + "_add", + ) self.nodes_to_add.extend([attention_node, matmul_node, add_node]) self.node_name_to_graph_name[attention_node.name] = self.this_graph_name @@ -52,29 +63,45 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return_indice = [] qkv_nodes = self.model.match_parent_path( normalize_node, - ['Add', 'Reshape', 'Gemm', 'Reshape', 'Reshape', 'Transpose', 'MatMul'], + ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"], [0, None, 0, 0, 0, 0, 0], output_name_to_node=output_name_to_node, - return_indice=return_indice - ) # yapf: disable + return_indice=return_indice, + ) # yapf: disable if qkv_nodes is None: return - (add_qkv, reshape_qkv, gemm_qkv, reshape_1, reshape_2, transpose_qkv, matmul_qkv) = qkv_nodes + ( + add_qkv, + reshape_qkv, + gemm_qkv, + reshape_1, + reshape_2, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes another_input = add_qkv.input[1 - return_indice[0]] v_nodes = self.model.match_parent_path( matmul_qkv, - ['Transpose', 'Reshape', 'Split', 'Reshape', 'Gemm', 'Reshape'], - [1, 0, 0, 0, 0, 0]) # yapf: disable + ["Transpose", "Reshape", "Split", "Reshape", "Gemm", "Reshape"], + [1, 0, 0, 0, 0, 0], + ) # yapf: disable if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return - (transpose_v, reshape_v, split_v, reshape_after_gemm, gemm, reshape_before_gemm) = v_nodes + ( + transpose_v, + reshape_v, + split_v, + reshape_after_gemm, + gemm, + reshape_before_gemm, + ) = v_nodes layernorm_before_attention = self.model.get_parent(reshape_before_gemm, 0, output_name_to_node) - if layernorm_before_attention is None or layernorm_before_attention.op_type != 'LayerNormalization': - if layernorm_before_attention.op_type != 'Add': + if layernorm_before_attention is None or layernorm_before_attention.op_type != "LayerNormalization": + if layernorm_before_attention.op_type != "Add": logger.debug(f"failed to get layernorm before gemm. Got {layernorm_before_attention.op_type}") return @@ -84,13 +111,25 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("Add and LayerNormalization shall have one same input") return - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'Mul', 'Div', 'MatMul'], [0, 0, 0, 0, 0]) + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0]) if qk_nodes is not None: (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes mask_nodes = self.model.match_parent_path( sub_qk, - ['Mul', 'Sub', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'], - [1, 0, 1, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable + [ + "Mul", + "Sub", + "Slice", + "Slice", + "Unsqueeze", + "Sub", + "Squeeze", + "Slice", + "Shape", + "Div", + ], + [1, 0, 1, 0, 1, 0, 0, 0, 0, 0], + ) # yapf: disable if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return @@ -101,13 +140,24 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return else: # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0. - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Where', 'Div', 'MatMul'], [0, 0, 1, 0]) + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]) if qk_nodes is not None: (softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes mask_nodes = self.model.match_parent_path( where_qk, - ['Cast', 'Slice', 'Slice', 'Unsqueeze', 'Sub', 'Squeeze', 'Slice', 'Shape', 'Div'], - [ 0, 0, 0, 1, 0, 0, 0, 0, 0]) # yapf: disable + [ + "Cast", + "Slice", + "Slice", + "Unsqueeze", + "Sub", + "Squeeze", + "Slice", + "Shape", + "Div", + ], + [0, 0, 0, 1, 0, 0, 0, 0, 0], + ) # yapf: disable if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return @@ -118,16 +168,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): return else: # match openai-gpt - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Mul', 'Div', 'MatMul'], - [0, 0, 0, 0, 0]) + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "Mul", "Div", "MatMul"], + [0, 0, 0, 0, 0], + ) if qk_nodes is None: logger.debug("fuse_attention: failed to match qk path") return (softmax_qk, add_qk, mul_qk, div_qk, matmul_qk) = qk_nodes mask_nodes = self.model.match_parent_path( mul_qk, - ['Slice', 'Slice', 'Unsqueeze', 'Squeeze', 'Slice', 'Shape', 'Div'], - [ 1, 0, 2, 0, 0, 0, 0]) # yapf: disable + ["Slice", "Slice", "Unsqueeze", "Squeeze", "Slice", "Shape", "Div"], + [1, 0, 2, 0, 0, 0, 0], + ) # yapf: disable if mask_nodes is None: logger.debug("fuse_attention: failed to match mask path") return @@ -137,7 +191,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: skip since div_qk != div_mask") return - q_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [0, 0, 0]) + q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0]) if q_nodes is None: logger.debug("fuse_attention: failed to match q path") return @@ -146,7 +200,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: skip since split_v != split_q") return - k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Split'], [1, 0, 0]) + k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [1, 0, 0]) if k_nodes is None: logger.debug("fuse_attention: failed to match k path") return diff --git a/onnxruntime/python/tools/transformers/fusion_layernorm.py b/onnxruntime/python/tools/transformers/fusion_layernorm.py index e1bc657068ac5..893d3283691be 100644 --- a/onnxruntime/python/tools/transformers/fusion_layernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_layernorm.py @@ -1,12 +1,13 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- -from typing import Dict +# -------------------------------------------------------------------------- from logging import getLogger +from typing import Dict + +from fusion_base import Fusion from onnx import helper from onnx_model import OnnxModel -from fusion_base import Fusion logger = getLogger(__name__) @@ -43,24 +44,32 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): root_input = node.input[0] - if children[0].op_type != 'Sub' or children[0].input[0] != root_input: + if children[0].op_type != "Sub" or children[0].input[0] != root_input: return if len(children) == 2: - if children[1].op_type != 'Sub' or children[1].input[0] != root_input: + if children[1].op_type != "Sub" or children[1].input[0] != root_input: return div_node = None for child in children: - div_node = self.model.find_first_child_by_type(child, 'Div', input_name_to_nodes, recursive=False) + div_node = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False) if div_node is not None: break if div_node is None: return path_id, parent_nodes, _ = self.model.match_parent_paths( - div_node, [(['Sqrt', 'Add', 'ReduceMean', 'Pow', 'Sub'], [1, 0, 0, 0, 0]), - (['Sqrt', 'Add', 'ReduceMean', 'Pow', 'Cast', 'Sub'], [1, 0, 0, 0, 0, 0])], output_name_to_node) + div_node, + [ + (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]), + ( + ["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], + [1, 0, 0, 0, 0, 0], + ), + ], + output_name_to_node, + ) if path_id < 0: return @@ -70,7 +79,7 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): second_add_node = parent_nodes[1] i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0E-4: + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: logger.warning(f"epsilon value is not expeced: {add_weight}") return @@ -79,11 +88,11 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return mul_node = input_name_to_nodes[div_node.output[0]][0] - if mul_node.op_type != 'Mul': + if mul_node.op_type != "Mul": return last_add_node = input_name_to_nodes[mul_node.output[0]][0] - if last_add_node.op_type != 'Add': + if last_add_node.op_type != "Add": return subgraph_nodes = [node] @@ -91,8 +100,12 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): subgraph_nodes.extend(parent_nodes[:-1]) subgraph_nodes.extend([last_add_node, mul_node, div_node]) - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, last_add_node.output, input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + last_add_node.output, + input_name_to_nodes, + output_name_to_node, + ): logger.debug(f"It is not safe to fuse LayerNormalization node. Skip") return @@ -106,11 +119,12 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): self.nodes_to_remove.extend(subgraph_nodes) - normalize_node = helper.make_node('LayerNormalization', - inputs=[node.input[0], weight_input, bias_input], - outputs=[last_add_node.output[0]], - name=self.model.create_node_name("LayerNormalization", - name_prefix="LayerNorm")) + normalize_node = helper.make_node( + "LayerNormalization", + inputs=[node.input[0], weight_input, bias_input], + outputs=[last_add_node.output[0]], + name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), + ) normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) self.nodes_to_add.append(normalize_node) self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name @@ -122,28 +136,58 @@ def __init__(self, model: OnnxModel): def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): """ - Layer Norm from Tensorflow model(using keras2onnx or tf2onnx): - +------------------------------------+ - | | - | | - (Cast_1) | - | | - | v (B) (B) (A) - Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add - | | | ^ ^ - | | | | | - | +--------------------------------------------------(Cast_2)-------------------------------|-------+ | - | v | - +---------------------------------------------------------------------------------------------------------------> Mul--------------------+ + Layer Norm from Tensorflow model(using keras2onnx or tf2onnx): + +------------------------------------+ + | | + | | + (Cast_1) | + | | + | v (B) (B) (A) + Add --> (Cast_1) --> ReduceMean --> Sub --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocol --> Mul --> Mul --> Sub --> Add + | | | ^ ^ + | | | | | + | +--------------------------------------------------(Cast_2)-------------------------------|-------+ | + | v | + +---------------------------------------------------------------------------------------------------------------> Mul--------------------+ """ return_indice = [] _, parent_nodes, return_indice = self.model.match_parent_paths( node, - [(['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'], - [ 1, 1, None, 0, 0, 0, None, 0, 0, None]), - (['Sub', 'Mul', 'Mul', 'Reciprocal', 'Sqrt', 'Add', 'Cast', 'ReduceMean', 'Mul', 'Sub', 'ReduceMean'], - [ 1, 1, None, 0, 0, 0, 0, None, 0, 0, None])], - output_name_to_node) # yapf: disable + [ + ( + [ + "Sub", + "Mul", + "Mul", + "Reciprocal", + "Sqrt", + "Add", + "ReduceMean", + "Mul", + "Sub", + "ReduceMean", + ], + [1, 1, None, 0, 0, 0, None, 0, 0, None], + ), + ( + [ + "Sub", + "Mul", + "Mul", + "Reciprocal", + "Sqrt", + "Add", + "Cast", + "ReduceMean", + "Mul", + "Sub", + "ReduceMean", + ], + [1, 1, None, 0, 0, 0, 0, None, 0, 0, None], + ), + ], + output_name_to_node, + ) # yapf: disable if parent_nodes is None: return @@ -153,38 +197,50 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): logger.debug("return indice is exepected in [0, 1], but got {return_indice}") return - sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0 = parent_nodes[:6] + ( + sub_node_0, + mul_node_0, + mul_node_1, + reciprocol_node, + sqrt_node, + add_node_0, + ) = parent_nodes[:6] reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:] cast_node_3 = None if len(parent_nodes) == 11: cast_node_3 = parent_nodes[6] - assert (cast_node_3.op_type == 'Cast') + assert cast_node_3.op_type == "Cast" - mul_node_3 = self.model.match_parent(node, 'Mul', 0, output_name_to_node) + mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node) if mul_node_3 is None: logger.debug("mul_node_3 not found") return node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node) - root_node = node_before_reduce if cast_node_3 is None else self.model.get_parent( - node_before_reduce, 0, output_name_to_node) + root_node = ( + node_before_reduce + if cast_node_3 is None + else self.model.get_parent(node_before_reduce, 0, output_name_to_node) + ) if root_node is None: logger.debug("root node is none") return i, epsilon = self.model.get_constant_input(add_node_0) - if epsilon is None or epsilon <= 0 or (epsilon > 1.0E-5 and cast_node_3 is None): + if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None): logger.debug("epsilon is not matched") return - if cast_node_3 is None and (reduce_mean_node_1.input[0] not in mul_node_3.input - or reduce_mean_node_1.input[0] not in sub_node_1.input): + if cast_node_3 is None and ( + reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input + ): logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node") return - if cast_node_3 is not None and (node_before_reduce.input[0] not in mul_node_3.input - or reduce_mean_node_1.input[0] not in sub_node_1.input): + if cast_node_3 is not None and ( + node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input + ): logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node") return @@ -193,19 +249,33 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): return subgraph_nodes = [ - node, sub_node_0, mul_node_0, mul_node_1, reciprocol_node, sqrt_node, add_node_0, reduce_mean_node_0, - mul_node_2, sub_node_1, reduce_mean_node_1, mul_node_3 + node, + sub_node_0, + mul_node_0, + mul_node_1, + reciprocol_node, + sqrt_node, + add_node_0, + reduce_mean_node_0, + mul_node_2, + sub_node_1, + reduce_mean_node_1, + mul_node_3, ] if cast_node_3 is not None: - cast_node_2 = self.model.match_parent(mul_node_0, 'Cast', 0, output_name_to_node) + cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node) if cast_node_2 is None: logger.debug("cast_node_2 not found") return subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3]) - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, self.model.input_name_to_nodes(), - self.model.output_name_to_node()): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, + node.output, + self.model.input_name_to_nodes(), + self.model.output_name_to_node(), + ): logger.debug("not safe to fuse layer normalization") return @@ -214,11 +284,13 @@ def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): weight_input = mul_node_1.input[1] bias_input = sub_node_0.input[0] - #TODO: add epsilon attribute - fused_node = helper.make_node('LayerNormalization', - inputs=[mul_node_3.input[0], weight_input, bias_input], - outputs=[node.output[0]], - name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")) + # TODO: add epsilon attribute + fused_node = helper.make_node( + "LayerNormalization", + inputs=[mul_node_3.input[0], weight_input, bias_input], + outputs=[node.output[0]], + name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"), + ) fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))]) self.nodes_to_add.append(fused_node) self.node_name_to_graph_name[fused_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index c29485d18c4c3..952588f0736cf 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from argparse import ArgumentParser @@ -13,8 +13,8 @@ class AttentionMaskFormat: class FusionOptions: - """ Options of fusion in graph optimization - """ + """Options of fusion in graph optimization""" + def __init__(self, model_type): self.enable_gelu = True self.enable_layer_norm = True @@ -26,7 +26,7 @@ def __init__(self, model_type): self.enable_gelu_approximation = False self.attention_mask_format = AttentionMaskFormat.AttentionMask - if model_type == 'gpt2': + if model_type == "gpt2": self.enable_skip_layer_norm = False def use_raw_attention_mask(self, use_raw_mask=True): @@ -65,56 +65,82 @@ def parse(args): @staticmethod def add_arguments(parser: ArgumentParser): - parser.add_argument('--disable_attention', required=False, action='store_true', help="disable Attention fusion") + parser.add_argument( + "--disable_attention", + required=False, + action="store_true", + help="disable Attention fusion", + ) parser.set_defaults(disable_attention=False) - parser.add_argument('--disable_skip_layer_norm', - required=False, - action='store_true', - help="disable SkipLayerNormalization fusion") + parser.add_argument( + "--disable_skip_layer_norm", + required=False, + action="store_true", + help="disable SkipLayerNormalization fusion", + ) parser.set_defaults(disable_skip_layer_norm=False) - parser.add_argument('--disable_embed_layer_norm', - required=False, - action='store_true', - help="disable EmbedLayerNormalization fusion") + parser.add_argument( + "--disable_embed_layer_norm", + required=False, + action="store_true", + help="disable EmbedLayerNormalization fusion", + ) parser.set_defaults(disable_embed_layer_norm=False) - parser.add_argument('--disable_bias_skip_layer_norm', - required=False, - action='store_true', - help="disable Add Bias and SkipLayerNormalization fusion") + parser.add_argument( + "--disable_bias_skip_layer_norm", + required=False, + action="store_true", + help="disable Add Bias and SkipLayerNormalization fusion", + ) parser.set_defaults(disable_bias_skip_layer_norm=False) - parser.add_argument('--disable_bias_gelu', - required=False, - action='store_true', - help="disable Add Bias and Gelu/FastGelu fusion") + parser.add_argument( + "--disable_bias_gelu", + required=False, + action="store_true", + help="disable Add Bias and Gelu/FastGelu fusion", + ) parser.set_defaults(disable_bias_gelu=False) - parser.add_argument('--disable_layer_norm', - required=False, - action='store_true', - help="disable LayerNormalization fusion") + parser.add_argument( + "--disable_layer_norm", + required=False, + action="store_true", + help="disable LayerNormalization fusion", + ) parser.set_defaults(disable_layer_norm=False) - parser.add_argument('--disable_gelu', required=False, action='store_true', help="disable Gelu fusion") + parser.add_argument( + "--disable_gelu", + required=False, + action="store_true", + help="disable Gelu fusion", + ) parser.set_defaults(disable_gelu=False) - parser.add_argument('--enable_gelu_approximation', - required=False, - action='store_true', - help="enable Gelu/BiasGelu to FastGelu conversion") + parser.add_argument( + "--enable_gelu_approximation", + required=False, + action="store_true", + help="enable Gelu/BiasGelu to FastGelu conversion", + ) parser.set_defaults(enable_gelu_approximation=False) - parser.add_argument('--use_mask_index', - required=False, - action='store_true', - help="use mask index instead of raw attention mask in attention operator") + parser.add_argument( + "--use_mask_index", + required=False, + action="store_true", + help="use mask index instead of raw attention mask in attention operator", + ) parser.set_defaults(use_mask_index=False) - parser.add_argument('--no_attention_mask', - required=False, - action='store_true', - help="no attention mask. Only works for model_type=bert") + parser.add_argument( + "--no_attention_mask", + required=False, + action="store_true", + help="no attention mask. Only works for model_type=bert", + ) parser.set_defaults(no_attention_mask=False) diff --git a/onnxruntime/python/tools/transformers/fusion_reshape.py b/onnxruntime/python/tools/transformers/fusion_reshape.py index 3d336d5ce2996..0439c6958b651 100644 --- a/onnxruntime/python/tools/transformers/fusion_reshape.py +++ b/onnxruntime/python/tools/transformers/fusion_reshape.py @@ -1,12 +1,13 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from fusion_base import Fusion from logging import getLogger + import numpy as np -from onnx import helper, numpy_helper, TensorProto +from fusion_base import Fusion +from onnx import TensorProto, helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -18,17 +19,21 @@ def __init__(self, model: OnnxModel): def replace_reshape_node(self, shape, reshape_node, concat_node): shape_value = np.asarray(shape, dtype=np.int64) - constant_shape_name = self.model.create_node_name('Constant', 'constant_shape') - new_node = helper.make_node('Constant', - inputs=[], - outputs=[constant_shape_name], - value=helper.make_tensor(name='const_tensor', - data_type=TensorProto.INT64, - dims=shape_value.shape, - vals=bytes(shape_value), - raw=True)) + constant_shape_name = self.model.create_node_name("Constant", "constant_shape") + new_node = helper.make_node( + "Constant", + inputs=[], + outputs=[constant_shape_name], + value=helper.make_tensor( + name="const_tensor", + data_type=TensorProto.INT64, + dims=shape_value.shape, + vals=bytes(shape_value), + raw=True, + ), + ) reshape_node.input[1] = constant_shape_name - reshape_node.name = self.model.create_node_name('Reshape', 'Reshape_Fuse') + reshape_node.name = self.model.create_node_name("Reshape", "Reshape_Fuse") self.nodes_to_remove.extend([concat_node]) self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name @@ -38,18 +43,26 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): return concat_node = output_name_to_node[reshape_node.input[1]] - if concat_node.op_type != 'Concat' or len(concat_node.input) < 3 or len(concat_node.input) > 4: + if concat_node.op_type != "Concat" or len(concat_node.input) < 3 or len(concat_node.input) > 4: return - path0 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0], - output_name_to_node) + path0 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Gather", "Shape"], + [0, 0, 0], + output_name_to_node, + ) if path0 is None: return (unsqueeze_0, gather_0, shape_0) = path0 - path1 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0], - output_name_to_node) + path1 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Gather", "Shape"], + [1, 0, 0], + output_name_to_node, + ) if path1 is None: return (unsqueeze_1, gather_1, shape_1) = path1 @@ -70,27 +83,41 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): path3 = [] shape_nodes = [shape_0, shape_1] if len(concat_node.input) == 3 and self.model.get_initializer(concat_node.input[2]) is None: - path2 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Mul', 'Gather', 'Shape'], [2, 0, 0, 0], - output_name_to_node) + path2 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Mul", "Gather", "Shape"], + [2, 0, 0, 0], + output_name_to_node, + ) if path2 is None: path2 = self.model.match_parent_path( - concat_node, ['Unsqueeze', 'Mul', 'Squeeze', 'Slice', 'Shape'], [2, 0, 0, 0, 0], - output_name_to_node) # GPT2 exported by PyTorch 1.4 with opset_version=11 + concat_node, + ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"], + [2, 0, 0, 0, 0], + output_name_to_node, + ) # GPT2 exported by PyTorch 1.4 with opset_version=11 if path2 is None: return - path3 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Mul', 'Gather', 'Shape'], [2, 0, 1, 0], - output_name_to_node) + path3 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Mul", "Gather", "Shape"], + [2, 0, 1, 0], + output_name_to_node, + ) if path3 is None: path3 = self.model.match_parent_path( - concat_node, ['Unsqueeze', 'Mul', 'Squeeze', 'Slice', 'Shape'], [2, 0, 1, 0, 0], - output_name_to_node) # GPT2 exported by PyTorch 1.4 with opset_version=11 + concat_node, + ["Unsqueeze", "Mul", "Squeeze", "Slice", "Shape"], + [2, 0, 1, 0, 0], + output_name_to_node, + ) # GPT2 exported by PyTorch 1.4 with opset_version=11 if path3 is None: return shape_nodes.extend([path2[-1], path3[-1]]) shape.append(-1) - elif (len(concat_node.input) > 2): + elif len(concat_node.input) > 2: concat_2 = self.model.get_initializer(concat_node.input[2]) if concat_2 is None: return @@ -104,17 +131,24 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): if -1 in shape: return - path2 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Div', 'Gather', 'Shape'], [3, 0, 0, 0], - output_name_to_node) + path2 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Div", "Gather", "Shape"], + [3, 0, 0, 0], + output_name_to_node, + ) if path2 is None: path2 = self.model.match_parent_path( - concat_node, ['Unsqueeze', 'Div', 'Squeeze', 'Slice', 'Shape'], [3, 0, 0, 0, 0], - output_name_to_node) # GPT2 exported by PyTorch 1.4 with opset_version=11 + concat_node, + ["Unsqueeze", "Div", "Squeeze", "Slice", "Shape"], + [3, 0, 0, 0, 0], + output_name_to_node, + ) # GPT2 exported by PyTorch 1.4 with opset_version=11 if path2 is None: return shape_nodes.extend([path2[-1]]) shape.append(-1) - elif (len(concat_node.input) > 3): + elif len(concat_node.input) > 3: concat_3 = self.model.get_initializer(concat_node.input[3]) if concat_3 is None: return diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index ea2b278bf94fc..8abcfaa90678f 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -1,14 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from fusion_base import Fusion from logging import getLogger -from onnx import TensorProto, NodeProto -from onnx_model import OnnxModel +from typing import Dict, List, Union + +from fusion_base import Fusion from fusion_utils import FusionUtils -from typing import Union, Dict, List +from onnx import NodeProto, TensorProto +from onnx_model import OnnxModel logger = getLogger(__name__) @@ -21,7 +22,7 @@ def __init__(self, model: OnnxModel): self.shape_infer_done = False def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]: - if tensor_proto.type.tensor_type.HasField('shape'): + if tensor_proto.type.tensor_type.HasField("shape"): return len(tensor_proto.type.tensor_type.shape.dim) else: return None @@ -40,8 +41,12 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: return None - def fuse(self, concat_node: NodeProto, input_name_to_nodes: Dict[str, List[NodeProto]], - output_name_to_node: Dict[str, NodeProto]): + def fuse( + self, + concat_node: NodeProto, + input_name_to_nodes: Dict[str, List[NodeProto]], + output_name_to_node: Dict[str, NodeProto], + ): """ Smplify subgraph like @@ -64,8 +69,12 @@ def fuse(self, concat_node: NodeProto, input_name_to_nodes: Dict[str, List[NodeP root = None shape_output = None for i in range(inputs): - path = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [i, 0, 0], - output_name_to_node) + path = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Gather", "Shape"], + [i, 0, 0], + output_name_to_node, + ) if path is None: return @@ -79,18 +88,19 @@ def fuse(self, concat_node: NodeProto, input_name_to_nodes: Dict[str, List[NodeP elif shape.input[0] != root: return - if not FusionUtils.check_node_attribute(unsqueeze, 'axis', 0, default_value=0): + if not FusionUtils.check_node_attribute(unsqueeze, "axis", 0, default_value=0): return if opset_version < 13: - if not FusionUtils.check_node_attribute(unsqueeze, 'axes', [0]): + if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]): return else: if not self.utils.check_node_input_value(unsqueeze, 1, [0]): return value = self.model.get_constant_value(gather.input[1]) - from numpy import ndarray, array_equal + from numpy import array_equal, ndarray + if not (isinstance(value, ndarray) and value.size == 1 and value.item() == i): return diff --git a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py index 01cbe9ae83c0b..54037dd6013e6 100644 --- a/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py +++ b/onnxruntime/python/tools/transformers/fusion_skiplayernorm.py @@ -1,13 +1,14 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from logging import getLogger -from onnx import helper -from onnx_model import OnnxModel + from fusion_base import Fusion from fusion_utils import NumpyHelper +from onnx import helper +from onnx_model import OnnxModel logger = getLogger(__name__) @@ -17,6 +18,7 @@ class FusionSkipLayerNormalization(Fusion): Fuse Add + LayerNormalization into one node: SkipLayerNormalization Note: This fusion does not check the input shape of Add and LayerNormalization. """ + def __init__(self, model: OnnxModel): super().__init__(model, "SkipLayerNormalization", "LayerNormalization") # Update shape inference is needed since other fusions might add new edge which does not have shape info yet. @@ -41,7 +43,8 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): if self.shape_infer_helper is not None: if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]): logger.debug( - f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same") + f"skip skiplayernorm fusion since shape of inputs ({add.input[0]}, {add.input[1]}) are not same" + ) return else: # shape_infer_helper can not handle subgraphs. Current work around is to disable skiplayernorm fusion @@ -50,31 +53,35 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): "symbolic shape infer failed. it's safe to ignore this message if there is no issue with optimized model" ) - gather_path = self.model.match_parent_path(add, ['Gather'], [None]) + gather_path = self.model.match_parent_path(add, ["Gather"], [None]) if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None: - if self.model.match_parent_path(gather_path[0], ['ConstantOfShape'], [1]) is None: + if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None: return - if add is not None and add.op_type == 'Add' and self.model.is_safe_to_fuse_nodes( - [add, node], node.output, input_name_to_nodes, output_name_to_node): + if ( + add is not None + and add.op_type == "Add" + and self.model.is_safe_to_fuse_nodes([add, node], node.output, input_name_to_nodes, output_name_to_node) + ): self.nodes_to_remove.extend([add, node]) inputs = [add.input[0], add.input[1], node.input[1], node.input[2]] - normalize_node = helper.make_node("SkipLayerNormalization", - inputs=inputs, - outputs=[node.output[0]], - name=self.model.create_node_name("SkipLayerNormalization", - name_prefix="SkipLayerNorm")) + normalize_node = helper.make_node( + "SkipLayerNormalization", + inputs=inputs, + outputs=[node.output[0]], + name=self.model.create_node_name("SkipLayerNormalization", name_prefix="SkipLayerNorm"), + ) normalize_node.domain = "com.microsoft" # Pass attribute "epsilon" from layernorm node to SkipLayerNormalization for att in node.attribute: - if att.name == 'epsilon': + if att.name == "epsilon": normalize_node.attribute.extend([att]) # Set default epsilon if no epsilon exists from layernorm if len(normalize_node.attribute) == 0: - normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)]) + normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)]) self.nodes_to_add.append(normalize_node) self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name @@ -89,7 +96,7 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return return_indice = [] - nodes = self.model.match_parent_path(node, ['Add', 'MatMul'], [None, None], None, return_indice) + nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], None, return_indice) if nodes is None: return assert len(return_indice) == 2 @@ -116,30 +123,36 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): return subgraph_nodes = [node, add] - if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, [node.output[0]], input_name_to_nodes, - output_name_to_node): + if not self.model.is_safe_to_fuse_nodes( + subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node + ): logger.debug(f"Skip fusing SkipLayerNormalization with Bias since it is not safe") return self.nodes_to_remove.extend(subgraph_nodes) inputs = [ - node.input[1 - add_input_index], matmul.output[0], node.input[2], node.input[3], add.input[bias_index] + node.input[1 - add_input_index], + matmul.output[0], + node.input[2], + node.input[3], + add.input[bias_index], ] - new_node = helper.make_node("SkipLayerNormalization", - inputs=inputs, - outputs=node.output, - name=self.model.create_node_name("SkipLayerNormalization", - "SkipLayerNorm_AddBias_")) + new_node = helper.make_node( + "SkipLayerNormalization", + inputs=inputs, + outputs=node.output, + name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"), + ) new_node.domain = "com.microsoft" # Pass attribute "epsilon" from skiplayernorm node to skiplayernorm(add bias) for att in node.attribute: - if att.name == 'epsilon': + if att.name == "epsilon": new_node.attribute.extend([att]) # Set default epsilon if no epsilon exists from skiplayernorm if len(new_node.attribute) == 0: - new_node.attribute.extend([helper.make_attribute("epsilon", 1.0E-12)]) + new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)]) self.nodes_to_add.append(new_node) self.node_name_to_graph_name[new_node.name] = self.this_graph_name diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index 0d070f7768ff9..b682809565ee8 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -1,18 +1,18 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from logging import getLogger from typing import Tuple -from onnx import helper, numpy_helper, TensorProto -from numpy import ndarray, array_equal + +from numpy import array_equal, ndarray +from onnx import TensorProto, helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) class FusionUtils: - def __init__(self, model: OnnxModel): self.model: OnnxModel = model @@ -27,17 +27,17 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: return False, input_name def cast_input_to_int32(self, input_name: str): - cast_output = input_name + '_int32' + cast_output = input_name + "_int32" # Avoid consequent Cast nodes. inputs = [input_name] output_name_to_node = self.model.output_name_to_node() if input_name in output_name_to_node: parent_node = output_name_to_node[input_name] - if parent_node and parent_node.op_type == 'Cast': + if parent_node and parent_node.op_type == "Cast": inputs = [parent_node.input[0]] - cast_node = helper.make_node('Cast', inputs=inputs, outputs=[cast_output]) + cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) cast_node.attribute.extend([helper.make_attribute("to", int(TensorProto.INT32))]) self.model.add_node(cast_node) @@ -50,7 +50,7 @@ def remove_cast_int32(self, input_name: str): if node.op_type == "Cast": is_int32 = False for att in node.attribute: - if att.name == 'to' and att.i == int(TensorProto.INT32): + if att.name == "to" and att.i == int(TensorProto.INT32): is_int32 = True break if is_int32: @@ -78,7 +78,8 @@ def check_node_attribute(node, attribute_name: str, expected_value, default_valu if isinstance(expected_value, list): return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal( - expected_value, value, equal_nan=False) + expected_value, value, equal_nan=False + ) else: return value == expected_value @@ -99,7 +100,8 @@ def check_node_input_value(self, node, input_index: int, expected_value): if isinstance(expected_value, list): return (isinstance(value, ndarray) or isinstance(value, list)) and array_equal( - expected_value, value, equal_nan=False) + expected_value, value, equal_nan=False + ) else: return value == expected_value @@ -119,16 +121,16 @@ def get_dtype(self, shape_infer_helper, input_or_output_name: str) -> int: if shape_infer_helper: tensor_proto = shape_infer_helper.known_vi_[input_or_output_name] - if tensor_proto.type.tensor_type.HasField('elem_type'): + if tensor_proto.type.tensor_type.HasField("elem_type"): return tensor_proto.type.tensor_type.elem_type return None def remove_cascaded_cast_nodes(self): """Remove Cast node that are overrided by another Cast node like --> Cast --> Cast --> - Note that this shall be used carefully since it might introduce semantic change. - For example, float -> int -> float could get different value than the original float value. - So, it is recommended to used only in post-processing of mixed precision conversion. + Note that this shall be used carefully since it might introduce semantic change. + For example, float -> int -> float could get different value than the original float value. + So, it is recommended to used only in post-processing of mixed precision conversion. """ removed_count = 0 for node in self.model.nodes(): @@ -143,15 +145,14 @@ def remove_cascaded_cast_nodes(self): self.model.prune_graph() def remove_useless_cast_nodes(self): - """Remove cast nodes that are not needed: input and output has same data type. - """ + """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.model.infer_runtime_shape(update=True) if shape_infer is None: return nodes_to_remove = [] for node in self.model.nodes(): - if node.op_type == 'Cast': + if node.op_type == "Cast": input_dtype = self.get_dtype(shape_infer, node.input[0]) output_dtype = self.get_dtype(shape_infer, node.output[0]) if input_dtype and input_dtype == output_dtype: @@ -172,20 +173,20 @@ def remove_useless_cast_nodes(self): logger.info(f"Removed {len(nodes_to_remove)} Cast nodes with output type same as input") def remove_useless_reshape_nodes(self): - """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape - """ + """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape""" shape_infer = self.model.infer_runtime_shape(update=True) if shape_infer is None: return nodes_to_remove = [] for node in self.model.nodes(): - if node.op_type == 'Reshape': + if node.op_type == "Reshape": input_shape = shape_infer.get_edge_shape(node.input[0]) output_shape = shape_infer.get_edge_shape(node.output[0]) if input_shape and output_shape and input_shape == output_shape: logger.info( - f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}") + f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}" + ) nodes_to_remove.append(node) if nodes_to_remove: @@ -203,13 +204,16 @@ def remove_useless_reshape_nodes(self): class NumpyHelper: - @staticmethod def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray: # When weights are in external data format but not presented, we can still test the optimizer with two changes: # (1) set fill_zeros = True (2) change load_external_data=False in optimizer.py if fill_zeros: from onnx import mapping - return ndarray(shape=tensor.dims, dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type]) + + return ndarray( + shape=tensor.dims, + dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type], + ) return numpy_helper.to_array(tensor) diff --git a/onnxruntime/python/tools/transformers/huggingface_models.py b/onnxruntime/python/tools/transformers/huggingface_models.py index 642669156cbb8..cdf75efb1e62d 100644 --- a/onnxruntime/python/tools/transformers/huggingface_models.py +++ b/onnxruntime/python/tools/transformers/huggingface_models.py @@ -6,17 +6,35 @@ # Maps model class name to a tuple of model class MODEL_CLASSES = [ - 'AutoModel', 'AutoModelWithLMHead', 'AutoModelForSequenceClassification', 'AutoModelForQuestionAnswering', - 'AutoModelForCausalLM', + "AutoModel", + "AutoModelWithLMHead", + "AutoModelForSequenceClassification", + "AutoModelForQuestionAnswering", + "AutoModelForCausalLM", ] # List of pretrained models: https://huggingface.co/transformers/pretrained_models.html # Pretrained model name to a tuple of input names, opset_version, use_external_data_format, optimization model type MODELS = { # BERT - "bert-base-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - "bert-large-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), - "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), + "bert-base-uncased": ( + ["input_ids", "attention_mask", "token_type_ids"], + 12, + False, + "bert", + ), + "bert-large-uncased": ( + ["input_ids", "attention_mask", "token_type_ids"], + 12, + False, + "bert", + ), + "bert-base-cased": ( + ["input_ids", "attention_mask", "token_type_ids"], + 12, + False, + "bert", + ), # "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), # "bert-base-multilingual-uncased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), # "bert-base-multilingual-cased": (["input_ids", "attention_mask", "token_type_ids"], 12, False, "bert"), @@ -55,10 +73,14 @@ "roberta-large-mnli": (["input_ids", "attention_mask"], 12, False, "bert"), "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 11, False, "bert"), "distilroberta-base": (["input_ids", "attention_mask"], 12, False, "bert"), - # DistilBERT "distilbert-base-uncased": (["input_ids", "attention_mask"], 11, False, "bert"), - "distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 11, False, "bert"), + "distilbert-base-uncased-distilled-squad": ( + ["input_ids", "attention_mask"], + 11, + False, + "bert", + ), # CTRL "ctrl": (["input_ids"], 11, True, "bert"), # CamemBERT @@ -67,44 +89,43 @@ "albert-base-v1": (["input_ids"], 12, False, "bert"), "albert-large-v1": (["input_ids"], 12, False, "bert"), "albert-xlarge-v1": (["input_ids"], 12, True, "bert"), - #"albert-xxlarge-v1": (["input_ids"], 12, True, "bert"), + # "albert-xxlarge-v1": (["input_ids"], 12, True, "bert"), "albert-base-v2": (["input_ids"], 12, False, "bert"), "albert-large-v2": (["input_ids"], 12, False, "bert"), "albert-xlarge-v2": (["input_ids"], 12, True, "bert"), - #"albert-xxlarge-v2": (["input_ids"], 12, True, "bert"), + # "albert-xxlarge-v2": (["input_ids"], 12, True, "bert"), # T5 (use benchmark_t5.py instead) # "t5-small": (["input_ids", "decoder_input_ids"], 12, False, "bert"), # "t5-base": (["input_ids", "decoder_input_ids"], 12, False, "bert"), # "t5-large": (["input_ids", "decoder_input_ids"], 12, True, "bert"), # "t5-3b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), # "t5-11b": (["input_ids", "decoder_input_ids"], 12, True, "bert"), - #"valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"), + # "valhalla/t5-small-qa-qg-hl": (["input_ids"], 12, True, "bert"), # XLM-RoBERTa "xlm-roberta-base": (["input_ids"], 11, False, "bert"), "xlm-roberta-large": (["input_ids"], 11, True, "bert"), # FlauBERT "flaubert/flaubert_small_cased": (["input_ids"], 11, False, "bert"), - #"flaubert/flaubert_base_uncased": (["input_ids"], 11, False, "bert"), + # "flaubert/flaubert_base_uncased": (["input_ids"], 11, False, "bert"), "flaubert/flaubert_base_cased": (["input_ids"], 11, False, "bert"), - #"flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"), + # "flaubert/flaubert_large_cased": (["input_ids"], 11, False, "bert"), # Bart "facebook/bart-large": (["input_ids", "attention_mask"], 11, False, "bart"), "facebook/bart-base": (["input_ids", "attention_mask"], 11, False, "bart"), "facebook/bart-large-mnli": (["input_ids", "attention_mask"], 11, False, "bart"), "facebook/bart-large-cnn": (["input_ids", "attention_mask"], 11, False, "bart"), - # DialoGPT "microsoft/DialoGPT-small": (["input_ids"], 11, False, "gpt2"), "microsoft/DialoGPT-medium": (["input_ids"], 11, False, "gpt2"), - #"microsoft/DialoGPT-large": (["input_ids"], 11, True, "gpt2"), + # "microsoft/DialoGPT-large": (["input_ids"], 11, True, "gpt2"), # Reformer - #"google/reformer-enwik8": (["input_ids"], 11, False, "bert"), - #"google/reformer-crime-and-punishment": (["input_ids"], 11, False, "bert"), + # "google/reformer-enwik8": (["input_ids"], 11, False, "bert"), + # "google/reformer-crime-and-punishment": (["input_ids"], 11, False, "bert"), # MarianMT - #"Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), + # "Helsinki-NLP/opus-mt-ROMANCE-en": (["input_ids"], 12, False, "bert"), # Longformer (use benchmark_longformer.py instead) - #"allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), - #"allenai/longformer-large-4096": (["input_ids"], 12, False, "bert"), + # "allenai/longformer-base-4096": (["input_ids"], 12, False, "bert"), + # "allenai/longformer-large-4096": (["input_ids"], 12, False, "bert"), # MBart "facebook/mbart-large-cc25": (["input_ids"], 11, True, "bert"), "facebook/mbart-large-en-ro": (["input_ids"], 11, True, "bert"), @@ -129,7 +150,12 @@ "squeezebert/squeezebert-uncased": (["input_ids"], 11, False, "bert"), "squeezebert/squeezebert-mnli": (["input_ids"], 11, False, "bert"), "squeezebert/squeezebert-mnli-headless": (["input_ids"], 11, False, "bert"), - "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 11, False, "bert"), + "unc-nlp/lxmert-base-uncased": ( + ["input_ids", "visual_feats", "visual_pos"], + 11, + False, + "bert", + ), # "google/pegasus-xsum": (["input_ids"], 11, False, "bert"), # "google/pegasus-large": (["input_ids"], 11, False, "bert"), } diff --git a/onnxruntime/python/tools/transformers/io_binding_helper.py b/onnxruntime/python/tools/transformers/io_binding_helper.py index dbed946192928..31e04216a8352 100644 --- a/onnxruntime/python/tools/transformers/io_binding_helper.py +++ b/onnxruntime/python/tools/transformers/io_binding_helper.py @@ -1,14 +1,15 @@ +import logging +from typing import Dict, List, Union + import numpy import torch -import logging -from typing import List, Dict, Union + from onnxruntime import InferenceSession logger = logging.getLogger(__name__) class TypeHelper: - @staticmethod def get_input_type(ort_session: InferenceSession, name: str) -> str: for i, input in enumerate(ort_session.get_inputs()): @@ -94,11 +95,9 @@ def get_io_numpy_type_map(ort_session: InferenceSession) -> Dict[str, numpy.dtyp class IOBindingHelper: - @staticmethod def get_output_buffers(ort_session: InferenceSession, output_shapes, device): - """ Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape. - """ + """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape.""" output_buffers = {} for name, shape in output_shapes.items(): ort_type = TypeHelper.get_output_type(ort_session, name) @@ -107,16 +106,17 @@ def get_output_buffers(ort_session: InferenceSession, output_shapes, device): return output_buffers @staticmethod - def prepare_io_binding(ort_session, - input_ids: torch.Tensor, - position_ids: torch.Tensor, - attention_mask: torch.Tensor, - past: List[torch.Tensor], - output_buffers, - output_shapes, - name_to_np_type=None): - """ Returnas IO binding object for a session. - """ + def prepare_io_binding( + ort_session, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + attention_mask: torch.Tensor, + past: List[torch.Tensor], + output_buffers, + output_shapes, + name_to_np_type=None, + ): + """Returnas IO binding object for a session.""" if name_to_np_type is None: name_to_np_type = TypeHelper.get_io_numpy_type_map(ort_session) @@ -125,8 +125,14 @@ def prepare_io_binding(ort_session, # Bind inputs assert input_ids.is_contiguous() - io_binding.bind_input('input_ids', input_ids.device.type, 0, name_to_np_type['input_ids'], - list(input_ids.size()), input_ids.data_ptr()) + io_binding.bind_input( + "input_ids", + input_ids.device.type, + 0, + name_to_np_type["input_ids"], + list(input_ids.size()), + input_ids.data_ptr(), + ) if past is not None: for i, past_i in enumerate(past): @@ -138,39 +144,62 @@ def prepare_io_binding(ort_session, # Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter. data_ptr = input_ids.data_ptr() - io_binding.bind_input(f'past_{i}', past_i.device.type, 0, name_to_np_type[f'past_{i}'], - list(past_i.size()), data_ptr) + io_binding.bind_input( + f"past_{i}", + past_i.device.type, + 0, + name_to_np_type[f"past_{i}"], + list(past_i.size()), + data_ptr, + ) if attention_mask is not None: assert attention_mask.is_contiguous() - io_binding.bind_input('attention_mask', attention_mask.device.type, 0, name_to_np_type['attention_mask'], - list(attention_mask.size()), attention_mask.data_ptr()) + io_binding.bind_input( + "attention_mask", + attention_mask.device.type, + 0, + name_to_np_type["attention_mask"], + list(attention_mask.size()), + attention_mask.data_ptr(), + ) if position_ids is not None: assert position_ids.is_contiguous() - io_binding.bind_input('position_ids', position_ids.device.type, 0, name_to_np_type['position_ids'], - list(position_ids.size()), position_ids.data_ptr()) + io_binding.bind_input( + "position_ids", + position_ids.device.type, + 0, + name_to_np_type["position_ids"], + list(position_ids.size()), + position_ids.data_ptr(), + ) # Bind outputs for output in ort_session.get_outputs(): output_name = output.name output_buffer = output_buffers[output_name] logger.debug(f"{output_name} device type={output_buffer.device.type} shape={list(output_buffer.size())}") - io_binding.bind_output(output_name, output_buffer.device.type, 0, name_to_np_type[output_name], - output_shapes[output_name], output_buffer.data_ptr()) + io_binding.bind_output( + output_name, + output_buffer.device.type, + 0, + name_to_np_type[output_name], + output_shapes[output_name], + output_buffer.data_ptr(), + ) return io_binding @staticmethod def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True): - """ Copy results to cpu. Returns a list of numpy array. - """ + """Copy results to cpu. Returns a list of numpy array.""" ort_outputs = [] for output in ort_session.get_outputs(): output_name = output.name buffer = output_buffers[output_name] shape = output_shapes[output_name] - copy_tensor = buffer[0:numpy.prod(shape)].reshape(shape).clone().detach() + copy_tensor = buffer[0 : numpy.prod(shape)].reshape(shape).clone().detach() if return_numpy: ort_outputs.append(copy_tensor.cpu().numpy()) else: diff --git a/onnxruntime/python/tools/transformers/machine_info.py b/onnxruntime/python/tools/transformers/machine_info.py index 8af24644f9f42..e872e2a6c00c6 100644 --- a/onnxruntime/python/tools/transformers/machine_info.py +++ b/onnxruntime/python/tools/transformers/machine_info.py @@ -1,30 +1,43 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # It is used to dump machine information for Notebooks import argparse -import logging -from typing import List, Dict, Union, Tuple -import cpuinfo -import psutil import json -import sys +import logging import platform +import sys from os import environ -from py3nvml.py3nvml import nvmlInit, nvmlSystemGetDriverVersion, nvmlDeviceGetCount, nvmlDeviceGetHandleByIndex, \ - nvmlDeviceGetMemoryInfo, nvmlDeviceGetName, nvmlShutdown, NVMLError +from typing import Dict, List, Tuple, Union + +import cpuinfo +import psutil +from py3nvml.py3nvml import ( + NVMLError, + nvmlDeviceGetCount, + nvmlDeviceGetHandleByIndex, + nvmlDeviceGetMemoryInfo, + nvmlDeviceGetName, + nvmlInit, + nvmlShutdown, + nvmlSystemGetDriverVersion, +) + +class MachineInfo: + """Class encapsulating Machine Info logic.""" -class MachineInfo(): - """ Class encapsulating Machine Info logic. """ def __init__(self, silent=False, logger=None): self.silent = silent if logger is None: - logging.basicConfig(format="%(asctime)s - %(name)s - %(levelname)s: %(message)s", level=logging.INFO) + logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s: %(message)s", + level=logging.INFO, + ) self.logger = logging.getLogger(__name__) else: self.logger = logger @@ -50,7 +63,7 @@ def get_machine_info(self): "packages": self.get_related_packages(), "onnxruntime": self.get_onnxruntime_info(), "pytorch": self.get_pytorch_info(), - "tensorflow": self.get_tensorflow_info() + "tensorflow": self.get_tensorflow_info(), } return machine_info @@ -79,7 +92,7 @@ def get_cpu_info(self) -> Dict: "hz": self._try_get(cpu_info, ["hz_actual"]), "l2_cache": self._try_get(cpu_info, ["l2_cache_size"]), "flags": self._try_get(cpu_info, ["flags"]), - "processor": platform.uname().processor + "processor": platform.uname().processor, } def get_gpu_info_by_nvml(self) -> Dict: @@ -106,16 +119,28 @@ def get_gpu_info_by_nvml(self) -> Dict: result = {"driver_version": driver_version, "devices": gpu_info_list} - if 'CUDA_VISIBLE_DEVICES' in environ: - result["cuda_visible"] = environ['CUDA_VISIBLE_DEVICES'] + if "CUDA_VISIBLE_DEVICES" in environ: + result["cuda_visible"] = environ["CUDA_VISIBLE_DEVICES"] return result def get_related_packages(self) -> List[str]: import pkg_resources + installed_packages = pkg_resources.working_set related_packages = [ - 'onnxruntime-gpu', 'onnxruntime', 'ort-nightly-gpu', 'ort-nightly', 'onnx', 'transformers', 'protobuf', - 'sympy', 'torch', 'tensorflow', 'flatbuffers', 'numpy', 'onnxconverter-common' + "onnxruntime-gpu", + "onnxruntime", + "ort-nightly-gpu", + "ort-nightly", + "onnx", + "transformers", + "protobuf", + "sympy", + "torch", + "tensorflow", + "flatbuffers", + "numpy", + "onnxconverter-common", ] related_packages_list = {i.key: i.version for i in installed_packages if i.key in related_packages} return related_packages_list @@ -123,9 +148,10 @@ def get_related_packages(self) -> List[str]: def get_onnxruntime_info(self) -> Dict: try: import onnxruntime + return { "version": onnxruntime.__version__, - "support_gpu": 'CUDAExecutionProvider' in onnxruntime.get_available_providers() + "support_gpu": "CUDAExecutionProvider" in onnxruntime.get_available_providers(), } except ImportError as error: if not self.silent: @@ -139,7 +165,12 @@ def get_onnxruntime_info(self) -> Dict: def get_pytorch_info(self) -> Dict: try: import torch - return {"version": torch.__version__, "support_gpu": torch.cuda.is_available(), "cuda": torch.version.cuda} + + return { + "version": torch.__version__, + "support_gpu": torch.cuda.is_available(), + "cuda": torch.version.cuda, + } except ImportError as error: if not self.silent: self.logger.exception(error) @@ -152,10 +183,11 @@ def get_pytorch_info(self) -> Dict: def get_tensorflow_info(self) -> Dict: try: import tensorflow as tf + return { "version": tf.version.VERSION, "git_version": tf.version.GIT_VERSION, - "support_gpu": tf.test.is_built_with_cuda() + "support_gpu": tf.test.is_built_with_cuda(), } except ImportError as error: if not self.silent: @@ -170,7 +202,12 @@ def get_tensorflow_info(self) -> Dict: def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--silent', required=False, action='store_true', help="Do not print error message") + parser.add_argument( + "--silent", + required=False, + action="store_true", + help="Do not print error message", + ) parser.set_defaults(silent=False) args = parser.parse_args() @@ -182,6 +219,6 @@ def get_machine_info(silent=True) -> str: return json.dumps(machine.machine_info, indent=2) -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() print(get_machine_info(args.silent)) diff --git a/onnxruntime/python/tools/transformers/models/__init__.py b/onnxruntime/python/tools/transformers/models/__init__.py index 7c2a88f4d9554..cc667396a2622 100644 --- a/onnxruntime/python/tools/transformers/models/__init__.py +++ b/onnxruntime/python/tools/transformers/models/__init__.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- diff --git a/onnxruntime/python/tools/transformers/models/gpt2/__init__.py b/onnxruntime/python/tools/transformers/models/gpt2/__init__.py index 7c2a88f4d9554..cc667396a2622 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/__init__.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/__init__.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- diff --git a/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py b/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py index f1600017ca4e7..f8cbdaaa32e70 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/benchmark_gpt2.py @@ -6,74 +6,90 @@ # This script benchmarks gpt2 model with past state. # For gpt2 model without past state, use benchmark.py to measure performance. +import argparse import csv +import logging +import os +import sys from datetime import datetime + import psutil -import argparse -import logging import torch +from gpt2_beamsearch_helper import MODEL_CLASSES, Gpt2HelperFactory +from gpt2_helper import DEFAULT_TOLERANCE, PRETRAINED_GPT2_MODELS, Gpt2Helper from packaging import version from transformers import AutoConfig -from gpt2_helper import Gpt2Helper, DEFAULT_TOLERANCE, PRETRAINED_GPT2_MODELS -from gpt2_beamsearch_helper import Gpt2HelperFactory, MODEL_CLASSES - -import sys -import os - -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger from quantize_helper import QuantizeHelper -from benchmark_helper import create_onnxruntime_session, setup_logger, prepare_environment, Precision -logger = logging.getLogger('') +logger = logging.getLogger("") def parse_arguments(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('-m', - '--model_name_or_path', - required=True, - type=str, - help='Model path, or pretrained model name selected in the list: ' + - ', '.join(PRETRAINED_GPT2_MODELS)) - - parser.add_argument('--model_class', - required=False, - type=str, - default='GPT2LMHeadModel', - choices=list(MODEL_CLASSES.keys()), - help='Model type selected in the list: ' + ', '.join(MODEL_CLASSES.keys())) - - parser.add_argument('--cache_dir', - required=False, - type=str, - default=os.path.join('.', 'cache_models'), - help='Directory to cache pre-trained models') - - parser.add_argument('--onnx_dir', - required=False, - type=str, - default=os.path.join('.', 'onnx_models'), - help='Directory to store onnx models') - - parser.add_argument('--test_times', - required=False, - default=100, - type=int, - help='Number of repeat times to get average inference latency.') - - parser.add_argument('-v', '--validate_onnx', required=False, action='store_true', help='Validate ONNX model') - - parser.add_argument('-o', - '--optimize_onnx', - required=False, - action='store_true', - help='Use optimizer.py to optimize onnx model') + parser.add_argument( + "-m", + "--model_name_or_path", + required=True, + type=str, + help="Model path, or pretrained model name selected in the list: " + ", ".join(PRETRAINED_GPT2_MODELS), + ) + + parser.add_argument( + "--model_class", + required=False, + type=str, + default="GPT2LMHeadModel", + choices=list(MODEL_CLASSES.keys()), + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=os.path.join(".", "cache_models"), + help="Directory to cache pre-trained models", + ) + + parser.add_argument( + "--onnx_dir", + required=False, + type=str, + default=os.path.join(".", "onnx_models"), + help="Directory to store onnx models", + ) + + parser.add_argument( + "--test_times", + required=False, + default=100, + type=int, + help="Number of repeat times to get average inference latency.", + ) + + parser.add_argument( + "-v", + "--validate_onnx", + required=False, + action="store_true", + help="Validate ONNX model", + ) + + parser.add_argument( + "-o", + "--optimize_onnx", + required=False, + action="store_true", + help="Use optimizer.py to optimize onnx model", + ) parser.set_defaults(optimize_onnx=False) - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference") + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") parser.set_defaults(use_gpu=False) parser.add_argument( @@ -82,70 +98,100 @@ def parse_arguments(argv=None): type=Precision, default=Precision.FLOAT32, choices=list(Precision), - help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization") + help="Precision of model to run. fp32 for full precision, fp16 for half precision, and int8 for quantization", + ) - parser.add_argument('--torchscript', required=False, action='store_true', help="use Torchscript") + parser.add_argument("--torchscript", required=False, action="store_true", help="use Torchscript") parser.set_defaults(torchscript=False) - parser.add_argument('-b', '--batch_sizes', nargs='+', type=int, default=[1], help="batch size") - parser.add_argument('--beam_size', type=int, default=4, help='Beam size if greedy/top-p/top-k sampling is needed') + parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1], help="batch size") + parser.add_argument( + "--beam_size", + type=int, + default=4, + help="Beam size if greedy/top-p/top-k sampling is needed", + ) - parser.add_argument('--sequence_lengths', - nargs='+', - type=int, - default=[1], - help="sequence lengths (excluding past)") + parser.add_argument( + "--sequence_lengths", + nargs="+", + type=int, + default=[1], + help="sequence lengths (excluding past)", + ) - parser.add_argument('-s', - '--past_sequence_lengths', - nargs='+', - type=int, - default=[8, 16, 32, 64, 128, 256], - help="past sequence lengths") + parser.add_argument( + "-s", + "--past_sequence_lengths", + nargs="+", + type=int, + default=[8, 16, 32, 64, 128, 256], + help="past sequence lengths", + ) - parser.add_argument("-r", "--result_csv", required=False, default=None, help="CSV file for saving summary results.") + parser.add_argument( + "-r", + "--result_csv", + required=False, + default=None, + help="CSV file for saving summary results.", + ) parser.add_argument("--thread_num", required=False, type=int, default=-1, help="Threads to use") - parser.add_argument('--include_copy_output_latency', required=False, action='store_true') + parser.add_argument("--include_copy_output_latency", required=False, action="store_true") parser.set_defaults(include_copy_output_latency=False) - parser.add_argument('--verbose', required=False, action='store_true') + parser.add_argument("--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) search_option_group = parser.add_argument_group("configurable one step search options") - search_option_group.add_argument('--ignore_eos', - type=bool, - default=False, - help='If ignore end of sentence token in model inference.') - search_option_group.add_argument('--repetition_penalty', - type=float, - default=1, - help='Positive. >1 to penalize and <1 to encorage.') - search_option_group.add_argument('--temperature', - type=float, - default=1, - help='Softmax temperature for output logits.') - search_option_group.add_argument('--excluded_token_ids', - required=False, - nargs='+', - type=float, - help='A list of token ids to be excluded in inference.') - search_option_group.add_argument('--length_penalty', - type=float, - default=1, - help='Positive. >1 to penalize and <1 to encorage short sentence.') + search_option_group.add_argument( + "--ignore_eos", + type=bool, + default=False, + help="If ignore end of sentence token in model inference.", + ) + search_option_group.add_argument( + "--repetition_penalty", + type=float, + default=1, + help="Positive. >1 to penalize and <1 to encorage.", + ) + search_option_group.add_argument( + "--temperature", + type=float, + default=1, + help="Softmax temperature for output logits.", + ) + search_option_group.add_argument( + "--excluded_token_ids", + required=False, + nargs="+", + type=float, + help="A list of token ids to be excluded in inference.", + ) + search_option_group.add_argument( + "--length_penalty", + type=float, + default=1, + help="Positive. >1 to penalize and <1 to encorage short sentence.", + ) sampling_option_group = parser.add_argument_group("one step sampling options") - sampling_option_group.add_argument('--do_sample', - action='store_true', - help='If to do sampling instead of beam search or greedy.') - sampling_option_group.add_argument('--do_sample_top_p', - type=float, - default=0.95, - help='Nuclear/top-p sampling accumulation probability.') - sampling_option_group.add_argument('--do_sample_top_k', type=int, default=0, help='Use top-k if non-zero.') + sampling_option_group.add_argument( + "--do_sample", + action="store_true", + help="If to do sampling instead of beam search or greedy.", + ) + sampling_option_group.add_argument( + "--do_sample_top_p", + type=float, + default=0.95, + help="Nuclear/top-p sampling accumulation probability.", + ) + sampling_option_group.add_argument("--do_sample_top_k", type=int, default=0, help="Use top-k if non-zero.") args = parser.parse_args(argv) @@ -154,8 +200,10 @@ def parse_arguments(argv=None): def main(args): from transformers import __version__ as transformers_version + if version.parse(transformers_version) < version.parse( - "3.1.0"): # past_key_values name does not exist in 3.0.2 or older + "3.1.0" + ): # past_key_values name does not exist in 3.0.2 or older raise RuntimeError("This tool requires transformers 3.1.0 or later.") logger.info(f"Arguments:{args}") @@ -182,61 +230,71 @@ def main(args): gpt2helper = Gpt2HelperFactory.create_helper(model_type) config = AutoConfig.from_pretrained(args.model_name_or_path, torchscript=args.torchscript, cache_dir=cache_dir) - if model_type == 'beam_search_step': - model = model_class.from_pretrained(args.model_name_or_path, - config=config, - batch_size=1, - beam_size=args.beam_size, - cache_dir=cache_dir) - elif model_type == 'configurable_one_step_search': - model = model_class.from_pretrained(args.model_name_or_path, - config=config, - batch_size=1, - beam_size=args.beam_size, - ignore_eos=args.ignore_eos, - temperature=args.temperature, - repetition_penalty=args.repetition_penalty, - excluded_token_ids=args.excluded_token_ids, - length_penalty=args.length_penalty, - do_sample=args.do_sample, - do_sample_top_p=args.do_sample_top_p, - do_sample_top_k=args.do_sample_top_k, - cache_dir=cache_dir) + if model_type == "beam_search_step": + model = model_class.from_pretrained( + args.model_name_or_path, + config=config, + batch_size=1, + beam_size=args.beam_size, + cache_dir=cache_dir, + ) + elif model_type == "configurable_one_step_search": + model = model_class.from_pretrained( + args.model_name_or_path, + config=config, + batch_size=1, + beam_size=args.beam_size, + ignore_eos=args.ignore_eos, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + excluded_token_ids=args.excluded_token_ids, + length_penalty=args.length_penalty, + do_sample=args.do_sample, + do_sample_top_p=args.do_sample_top_p, + do_sample_top_k=args.do_sample_top_k, + cache_dir=cache_dir, + ) else: model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir) # This scirpt does not support float16 for PyTorch. - #if args.float16: + # if args.float16: # model.half() device = torch.device("cuda:0" if args.use_gpu else "cpu") model.to(device) - use_external_data_format = (config.n_layer > 24) #TODO: find a way to check model size > 2GB - onnx_model_paths = gpt2helper.get_onnx_paths(output_dir, - args.model_name_or_path, - args.model_class, - has_past=True, - new_folder=use_external_data_format) + use_external_data_format = config.n_layer > 24 # TODO: find a way to check model size > 2GB + onnx_model_paths = gpt2helper.get_onnx_paths( + output_dir, + args.model_name_or_path, + args.model_class, + has_past=True, + new_folder=use_external_data_format, + ) onnx_model_path = onnx_model_paths["raw"] use_padding = MODEL_CLASSES[args.model_class][2] - gpt2helper.export_onnx(model, - device, - onnx_model_path, - args.verbose, - use_external_data_format, - has_position_ids=use_padding, - has_attention_mask=use_padding) + gpt2helper.export_onnx( + model, + device, + onnx_model_path, + args.verbose, + use_external_data_format, + has_position_ids=use_padding, + has_attention_mask=use_padding, + ) if args.optimize_onnx or args.precision != Precision.FLOAT32: - onnx_model_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else 'fp32'] - gpt2helper.optimize_onnx(onnx_model_paths["raw"], - onnx_model_path, - args.precision == Precision.FLOAT16, - model.config.num_attention_heads, - model.config.hidden_size, - use_external_data_format, - auto_mixed_precision=True) + onnx_model_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"] + gpt2helper.optimize_onnx( + onnx_model_paths["raw"], + onnx_model_path, + args.precision == Precision.FLOAT16, + model.config.num_attention_heads, + model.config.hidden_size, + use_external_data_format, + auto_mixed_precision=True, + ) if args.precision == Precision.INT8: logger.info("quantizing model...") @@ -246,44 +304,64 @@ def main(args): onnx_model_path = onnx_model_paths["int8"] if args.torchscript: - model = gpt2helper.torchscript(model, - config, - device, - has_position_ids=use_padding, - has_attention_mask=use_padding) - - session = create_onnxruntime_session(onnx_model_path, - args.use_gpu, - enable_all_optimization=False, - num_threads=args.thread_num, - verbose=args.verbose) + model = gpt2helper.torchscript( + model, + config, + device, + has_position_ids=use_padding, + has_attention_mask=use_padding, + ) + + session = create_onnxruntime_session( + onnx_model_path, + args.use_gpu, + enable_all_optimization=False, + num_threads=args.thread_num, + verbose=args.verbose, + ) if session is None: return # Allocate output buffers for IO Binding - if model_type == 'beam_search_step' or model_type == 'configurable_one_step_search': - max_output_shapes = gpt2helper.get_output_shapes(max(args.batch_sizes), - context_len=max(args.past_sequence_lengths), - past_sequence_length=max(args.past_sequence_lengths), - sequence_length=max(args.sequence_lengths), - beam_size=args.beam_size, - step=0, - config=config, - model_class=args.model_class) + if model_type == "beam_search_step" or model_type == "configurable_one_step_search": + max_output_shapes = gpt2helper.get_output_shapes( + max(args.batch_sizes), + context_len=max(args.past_sequence_lengths), + past_sequence_length=max(args.past_sequence_lengths), + sequence_length=max(args.sequence_lengths), + beam_size=args.beam_size, + step=0, + config=config, + model_class=args.model_class, + ) output_buffers = gpt2helper.get_output_buffers(max_output_shapes, device, args.precision == Precision.FLOAT16) else: - max_output_shapes = gpt2helper.get_output_shapes(max(args.batch_sizes), max(args.past_sequence_lengths), - max(args.sequence_lengths), config, args.model_class) + max_output_shapes = gpt2helper.get_output_shapes( + max(args.batch_sizes), + max(args.past_sequence_lengths), + max(args.sequence_lengths), + config, + args.model_class, + ) output_buffers = gpt2helper.get_output_buffers(max_output_shapes, device, args.precision == Precision.FLOAT16) csv_filename = args.result_csv or "benchmark_result_{}.csv".format(datetime.now().strftime("%Y%m%d-%H%M%S")) - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = [ - "model_name", "model_class", "gpu", "precision", "optimizer", "torchscript", "batch_size", - "sequence_length", "past_sequence_length", "torch_latency", "onnxruntime_latency", - "onnxruntime_io_binding_latency" + "model_name", + "model_class", + "gpu", + "precision", + "optimizer", + "torchscript", + "batch_size", + "sequence_length", + "past_sequence_length", + "torch_latency", + "onnxruntime_latency", + "onnxruntime_io_binding_latency", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) csv_writer.writeheader() @@ -295,35 +373,51 @@ def main(args): logger.debug( f"Running test for batch_size={batch_size} sequence_length={sequence_length} past_sequence_length={past_sequence_length}..." ) - if model_type == 'beam_search_step' or model_type == 'configurable_one_step_search': - dummy_inputs = gpt2helper.get_dummy_inputs(batch_size, - past_sequence_length, - sequence_length, - config.num_attention_heads, - config.hidden_size, - config.n_layer, - config.vocab_size, - device, - float16=(args.precision == Precision.FLOAT16), - has_position_ids=use_padding, - has_attention_mask=use_padding) - output_shapes = gpt2helper.get_output_shapes(batch_size, past_sequence_length, - past_sequence_length, sequence_length, - args.beam_size, 0, config, args.model_class) + if model_type == "beam_search_step" or model_type == "configurable_one_step_search": + dummy_inputs = gpt2helper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + config.num_attention_heads, + config.hidden_size, + config.n_layer, + config.vocab_size, + device, + float16=(args.precision == Precision.FLOAT16), + has_position_ids=use_padding, + has_attention_mask=use_padding, + ) + output_shapes = gpt2helper.get_output_shapes( + batch_size, + past_sequence_length, + past_sequence_length, + sequence_length, + args.beam_size, + 0, + config, + args.model_class, + ) else: - dummy_inputs = gpt2helper.get_dummy_inputs(batch_size, - past_sequence_length, - sequence_length, - config.num_attention_heads, - config.hidden_size, - config.n_layer, - config.vocab_size, - device, - float16=(args.precision == Precision.FLOAT16), - has_position_ids=use_padding, - has_attention_mask=use_padding) - output_shapes = gpt2helper.get_output_shapes(batch_size, past_sequence_length, sequence_length, - config, args.model_class) + dummy_inputs = gpt2helper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + config.num_attention_heads, + config.hidden_size, + config.n_layer, + config.vocab_size, + device, + float16=(args.precision == Precision.FLOAT16), + has_position_ids=use_padding, + has_attention_mask=use_padding, + ) + output_shapes = gpt2helper.get_output_shapes( + batch_size, + past_sequence_length, + sequence_length, + config, + args.model_class, + ) try: outputs, torch_latency = gpt2helper.pytorch_inference(model, dummy_inputs, args.test_times) @@ -335,26 +429,30 @@ def main(args): else: logger.debug(f"torch output {i} shape {value.shape}") - ort_outputs, ort_latency = gpt2helper.onnxruntime_inference(session, dummy_inputs, - args.test_times) + ort_outputs, ort_latency = gpt2helper.onnxruntime_inference( + session, dummy_inputs, args.test_times + ) - ort_io_outputs, ort_io_latency = gpt2helper.onnxruntime_inference_with_binded_io( + (ort_io_outputs, ort_io_latency,) = gpt2helper.onnxruntime_inference_with_binded_io( session, dummy_inputs, output_buffers, output_shapes, args.test_times, return_numpy=False, - include_copy_output_latency=args.include_copy_output_latency) + include_copy_output_latency=args.include_copy_output_latency, + ) if args.validate_onnx: - if gpt2helper.compare_outputs(outputs, - ort_outputs, - model_class=args.model_class, - rtol=DEFAULT_TOLERANCE[args.precision], - atol=DEFAULT_TOLERANCE[args.precision]): + if gpt2helper.compare_outputs( + outputs, + ort_outputs, + model_class=args.model_class, + rtol=DEFAULT_TOLERANCE[args.precision], + atol=DEFAULT_TOLERANCE[args.precision], + ): logger.info( - f'Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]}).' + f"Pytorch and ONNX Runtime outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]})." ) # Results of IO binding might be in GPU. Copy outputs to CPU for comparison. @@ -362,13 +460,15 @@ def main(args): for output in ort_io_outputs: copy_outputs.append(output.cpu().numpy()) - if gpt2helper.compare_outputs(outputs, - copy_outputs, - model_class=args.model_class, - rtol=DEFAULT_TOLERANCE[args.precision], - atol=DEFAULT_TOLERANCE[args.precision]): + if gpt2helper.compare_outputs( + outputs, + copy_outputs, + model_class=args.model_class, + rtol=DEFAULT_TOLERANCE[args.precision], + atol=DEFAULT_TOLERANCE[args.precision], + ): logger.info( - f'Pytorch and ONNX Runtime IO Binding outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]}).' + f"Pytorch and ONNX Runtime IO Binding outputs are all close (tolerance={DEFAULT_TOLERANCE[args.precision]})." ) logger.info( @@ -387,7 +487,7 @@ def main(args): "past_sequence_length": past_sequence_length, "torch_latency": f"{torch_latency:.2f}", "onnxruntime_latency": f"{ort_latency:.2f}", - "onnxruntime_io_binding_latency": f"{ort_io_latency:.2f}" + "onnxruntime_io_binding_latency": f"{ort_io_latency:.2f}", } csv_writer.writerow(row) except: @@ -398,7 +498,7 @@ def main(args): return csv_filename -if __name__ == '__main__': +if __name__ == "__main__": args = parse_arguments() setup_logger(args.verbose) main(args) diff --git a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py index 88999b0c385d6..231f08cfc4862 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py @@ -16,78 +16,92 @@ """ import argparse -import logging -import torch -import numpy import json +import logging +import os +import sys from pathlib import Path + +import numpy +import torch +from gpt2_beamsearch_helper import MODEL_CLASSES, Gpt2HelperFactory +from gpt2_beamsearch_tester import Gpt2TesterFactory +from gpt2_helper import DEFAULT_TOLERANCE, PRETRAINED_GPT2_MODELS from packaging import version from transformers import AutoConfig -from gpt2_helper import DEFAULT_TOLERANCE, PRETRAINED_GPT2_MODELS -from gpt2_beamsearch_helper import Gpt2HelperFactory, MODEL_CLASSES -from gpt2_beamsearch_tester import Gpt2TesterFactory - -import sys -import os -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger from quantize_helper import QuantizeHelper -from benchmark_helper import create_onnxruntime_session, setup_logger, prepare_environment, Precision -logger = logging.getLogger('') +logger = logging.getLogger("") def parse_arguments(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('-m', - '--model_name_or_path', - required=True, - type=str, - help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS)) - - parser.add_argument('--model_class', - required=False, - type=str, - default='GPT2LMHeadModel', - choices=list(MODEL_CLASSES.keys()), - help='Model type selected in the list: ' + ', '.join(MODEL_CLASSES.keys())) - - parser.add_argument('--cache_dir', - required=False, - type=str, - default=os.path.join('.', 'cache_models'), - help='Directory to cache pre-trained models') - - parser.add_argument('--output', - required=False, - type=str, - default=os.path.join('.', 'onnx_models'), - help='Output directory, or model path ends with .onnx') - - parser.add_argument('-o', - '--optimize_onnx', - required=False, - action='store_true', - help='Use optimizer.py to optimize onnx model') + parser.add_argument( + "-m", + "--model_name_or_path", + required=True, + type=str, + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS), + ) + + parser.add_argument( + "--model_class", + required=False, + type=str, + default="GPT2LMHeadModel", + choices=list(MODEL_CLASSES.keys()), + help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()), + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=os.path.join(".", "cache_models"), + help="Directory to cache pre-trained models", + ) + + parser.add_argument( + "--output", + required=False, + type=str, + default=os.path.join(".", "onnx_models"), + help="Output directory, or model path ends with .onnx", + ) + + parser.add_argument( + "-o", + "--optimize_onnx", + required=False, + action="store_true", + help="Use optimizer.py to optimize onnx model", + ) parser.set_defaults(optimize_onnx=False) - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference") + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") parser.set_defaults(use_gpu=False) - parser.add_argument('--tolerance', - required=False, - type=float, - default=0, - help="the aboslute and relative tolerance for parity verification") + parser.add_argument( + "--tolerance", + required=False, + type=float, + default=0, + help="the aboslute and relative tolerance for parity verification", + ) - parser.add_argument('--input_test_file', - '-i', - required=False, - type=str, - default='', - help='Path to the file with inputs to test with') + parser.add_argument( + "--input_test_file", + "-i", + required=False, + type=str, + default="", + help="Path to the file with inputs to test with", + ) parser.add_argument( "-p", @@ -96,110 +110,143 @@ def parse_arguments(argv=None): type=Precision, default=Precision.FLOAT32, choices=list(Precision), - help= - "Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization" + help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision, and int8 for quantization", + ) + + parser.add_argument( + "-t", + "--test_cases", + required=False, + type=int, + default=1000, + help="Number of test cases per run for parity", + ) + parser.add_argument( + "-r", + "--test_runs", + required=False, + type=int, + default=10, + help="Number of runs for parity. It is used for significance test.", ) - parser.add_argument("-t", - "--test_cases", - required=False, - type=int, - default=1000, - help="Number of test cases per run for parity") - parser.add_argument("-r", - "--test_runs", - required=False, - type=int, - default=10, - help="Number of runs for parity. It is used for significance test.") - - parser.add_argument('--verbose', required=False, action='store_true') + parser.add_argument("--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) - parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') + parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") parser.set_defaults(use_external_data_format=False) - parser.add_argument('--use_int32_inputs', - required=False, - action='store_true', - help='Use int32 instead of int64 for input_ids, position_ids and attention_mask.') + parser.add_argument( + "--use_int32_inputs", + required=False, + action="store_true", + help="Use int32 instead of int64 for input_ids, position_ids and attention_mask.", + ) parser.set_defaults(use_int32_inputs=False) - parser.add_argument('--beam_size', type=int, default=4, help='Beam size if greedy/top-p/top-k sampling is needed') + parser.add_argument( + "--beam_size", + type=int, + default=4, + help="Beam size if greedy/top-p/top-k sampling is needed", + ) search_option_group = parser.add_argument_group("configurable one step search options") - search_option_group.add_argument('--ignore_eos', - type=bool, - default=False, - help='If ignore end of sentence token in model inference.') - search_option_group.add_argument('--repetition_penalty', - type=float, - default=1, - help='Positive. >1 to penalize and <1 to encorage.') - search_option_group.add_argument('--temperature', - type=float, - default=1, - help='Softmax temperature for output logits.') - search_option_group.add_argument('--excluded_token_ids', - required=False, - nargs='+', - type=float, - help='A list of token ids to be excluded in inference.') - search_option_group.add_argument('--length_penalty', - type=float, - default=1, - help='Positive. >1 to penalize and <1 to encorage short sentence.') + search_option_group.add_argument( + "--ignore_eos", + type=bool, + default=False, + help="If ignore end of sentence token in model inference.", + ) + search_option_group.add_argument( + "--repetition_penalty", + type=float, + default=1, + help="Positive. >1 to penalize and <1 to encorage.", + ) + search_option_group.add_argument( + "--temperature", + type=float, + default=1, + help="Softmax temperature for output logits.", + ) + search_option_group.add_argument( + "--excluded_token_ids", + required=False, + nargs="+", + type=float, + help="A list of token ids to be excluded in inference.", + ) + search_option_group.add_argument( + "--length_penalty", + type=float, + default=1, + help="Positive. >1 to penalize and <1 to encorage short sentence.", + ) sampling_option_group = parser.add_argument_group("one step sampling options") - sampling_option_group.add_argument('--do_sample', - action='store_true', - help='If to do sampling instead of beam search or greedy.') - sampling_option_group.add_argument('--do_sample_top_p', - type=float, - default=0.95, - help='Nuclear/top-p sampling accumulation probability.') - sampling_option_group.add_argument('--do_sample_top_k', type=int, default=0, help='Use top-k if non-zero.') + sampling_option_group.add_argument( + "--do_sample", + action="store_true", + help="If to do sampling instead of beam search or greedy.", + ) + sampling_option_group.add_argument( + "--do_sample_top_p", + type=float, + default=0.95, + help="Nuclear/top-p sampling accumulation probability.", + ) + sampling_option_group.add_argument("--do_sample_top_k", type=int, default=0, help="Use top-k if non-zero.") fp16_option_group = parser.add_argument_group( - "float to float16 conversion parameters that works when \"--precision fp16\" is specified") + 'float to float16 conversion parameters that works when "--precision fp16" is specified' + ) fp16_option_group.add_argument( - '-a', - '--auto_mixed_precision', + "-a", + "--auto_mixed_precision", required=False, - action='store_true', - help='Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.') + action="store_true", + help="Convert to mixed precision automatically. Other float16 conversion parameters will be ignored.", + ) fp16_option_group.set_defaults(auto_mixed_precision=False) - fp16_option_group.add_argument('--keep_io_types', - required=False, - action='store_true', - help='Use float32 for past inputs, present and logits outputs.') + fp16_option_group.add_argument( + "--keep_io_types", + required=False, + action="store_true", + help="Use float32 for past inputs, present and logits outputs.", + ) fp16_option_group.set_defaults(keep_io_types=False) - fp16_option_group.add_argument('--io_block_list', - nargs='+', - default=[], - help='List of inputs or outputs in float32 instead of float16') + fp16_option_group.add_argument( + "--io_block_list", + nargs="+", + default=[], + help="List of inputs or outputs in float32 instead of float16", + ) fp16_option_group.add_argument( - '--op_block_list', - nargs='+', + "--op_block_list", + nargs="+", default=[], - help= - 'List of operators (like Attention Gather Add LayerNormalization FastGelu MatMul) to compute in float32 instead of float16.' + help="List of operators (like Attention Gather Add LayerNormalization FastGelu MatMul) to compute in float32 instead of float16.", ) - fp16_option_group.add_argument('--node_block_list', - nargs='+', - default=[], - help='List of node names to compute in float32 instead of float16.') + fp16_option_group.add_argument( + "--node_block_list", + nargs="+", + default=[], + help="List of node names to compute in float32 instead of float16.", + ) - fp16_option_group.add_argument('--force_fp16_initializers', - required=False, - action='store_true', - help='Convert all float initializers to float16.') + fp16_option_group.add_argument( + "--force_fp16_initializers", + required=False, + action="store_true", + help="Convert all float initializers to float16.", + ) fp16_option_group.set_defaults(force_fp16_initializers=False) args = parser.parse_args(argv) @@ -211,7 +258,7 @@ def get_onnx_model_size(onnx_path: str, use_external_data_format: bool): if not use_external_data_format: return os.path.getsize(onnx_path) else: - return sum([f.stat().st_size for f in Path(onnx_path).parent.rglob('*')]) + return sum([f.stat().st_size for f in Path(onnx_path).parent.rglob("*")]) def get_latency_name(): @@ -221,8 +268,10 @@ def get_latency_name(): def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_results.csv"): result = {} from transformers import __version__ as transformers_version + if version.parse(transformers_version) < version.parse( - "3.1.0"): # past_key_values name does not exist in 3.0.2 or older + "3.1.0" + ): # past_key_values name does not exist in 3.0.2 or older raise RuntimeError("This tool requires transformers 3.1.0 or later.") args = parse_arguments(argv) @@ -230,6 +279,7 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu if not experiment_name: import sys + experiment_name = " ".join(argv if argv else sys.argv[1:]) if args.tolerance == 0: @@ -251,7 +301,7 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu assert not args.use_gpu, "quantization only supports CPU" if args.use_external_data_format: - assert not args.output.endswith('.onnx'), "output shall be a directory for --use_external_data_format" + assert not args.output.endswith(".onnx"), "output shall be a directory for --use_external_data_format" model_class = MODEL_CLASSES[args.model_class][0] use_padding = MODEL_CLASSES[args.model_class][2] @@ -266,26 +316,30 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu gpt2helper = Gpt2HelperFactory.create_helper(model_type) gpt2tester = Gpt2TesterFactory.create_tester(model_type) config = AutoConfig.from_pretrained(args.model_name_or_path, cache_dir=cache_dir) - if model_type == 'beam_search_step': - model = model_class.from_pretrained(args.model_name_or_path, - config=config, - batch_size=1, - beam_size=args.beam_size, - cache_dir=cache_dir) - elif model_type == 'configurable_one_step_search': - model = model_class.from_pretrained(args.model_name_or_path, - config=config, - batch_size=1, - beam_size=args.beam_size, - ignore_eos=args.ignore_eos, - temperature=args.temperature, - repetition_penalty=args.repetition_penalty, - excluded_token_ids=args.excluded_token_ids, - length_penalty=args.length_penalty, - do_sample=args.do_sample, - do_sample_top_p=args.do_sample_top_p, - do_sample_top_k=args.do_sample_top_k, - cache_dir=cache_dir) + if model_type == "beam_search_step": + model = model_class.from_pretrained( + args.model_name_or_path, + config=config, + batch_size=1, + beam_size=args.beam_size, + cache_dir=cache_dir, + ) + elif model_type == "configurable_one_step_search": + model = model_class.from_pretrained( + args.model_name_or_path, + config=config, + batch_size=1, + beam_size=args.beam_size, + ignore_eos=args.ignore_eos, + temperature=args.temperature, + repetition_penalty=args.repetition_penalty, + excluded_token_ids=args.excluded_token_ids, + length_penalty=args.length_penalty, + do_sample=args.do_sample, + do_sample_top_p=args.do_sample_top_p, + do_sample_top_k=args.do_sample_top_k, + cache_dir=cache_dir, + ) else: model = model_class.from_pretrained(args.model_name_or_path, config=config, cache_dir=cache_dir) @@ -300,7 +354,8 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu args.model_name_or_path, args.model_class, new_folder=args.use_external_data_format, - remove_existing=["fp32", "fp16", "int8"]) # Do not remove raw model to save time in parity test + remove_existing=["fp32", "fp16", "int8"], + ) # Do not remove raw model to save time in parity test raw_onnx_model = onnx_model_paths["raw"] @@ -308,16 +363,18 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu logger.warning(f"Skip exporting ONNX model since it existed: {raw_onnx_model}") else: logger.info(f"Exporting ONNX model to {raw_onnx_model}") - gpt2helper.export_onnx(model, - device, - raw_onnx_model, - args.verbose, - args.use_external_data_format, - has_position_ids=use_padding, - has_attention_mask=use_padding, - input_ids_dtype=torch.int32 if args.use_int32_inputs else torch.int64, - position_ids_dtype=torch.int32 if args.use_int32_inputs else torch.int64, - attention_mask_dtype=torch.int32 if args.use_int32_inputs else torch.int64) + gpt2helper.export_onnx( + model, + device, + raw_onnx_model, + args.verbose, + args.use_external_data_format, + has_position_ids=use_padding, + has_attention_mask=use_padding, + input_ids_dtype=torch.int32 if args.use_int32_inputs else torch.int64, + position_ids_dtype=torch.int32 if args.use_int32_inputs else torch.int64, + attention_mask_dtype=torch.int32 if args.use_int32_inputs else torch.int64, + ) fp16_params = {"keep_io_types": args.keep_io_types} if args.io_block_list: @@ -329,32 +386,35 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu if args.force_fp16_initializers: fp16_params["force_fp16_initializers"] = args.force_fp16_initializers - is_io_float16 = (args.precision == Precision.FLOAT16 and not args.keep_io_types) + is_io_float16 = args.precision == Precision.FLOAT16 and not args.keep_io_types if args.optimize_onnx or args.precision != Precision.FLOAT32: - output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else 'fp32'] + output_path = onnx_model_paths[str(args.precision) if args.precision != Precision.INT8 else "fp32"] logger.info(f"Optimizing model to {output_path}") - gpt2helper.optimize_onnx(raw_onnx_model, - output_path, - args.precision == Precision.FLOAT16, - model.config.num_attention_heads, - model.config.hidden_size, - args.use_external_data_format, - auto_mixed_precision=args.auto_mixed_precision, - **fp16_params) + gpt2helper.optimize_onnx( + raw_onnx_model, + output_path, + args.precision == Precision.FLOAT16, + model.config.num_attention_heads, + model.config.hidden_size, + args.use_external_data_format, + auto_mixed_precision=args.auto_mixed_precision, + **fp16_params, + ) else: output_path = raw_onnx_model if args.precision == Precision.INT8: logger.info("quantizing model...") - QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths['int8'], args.use_external_data_format) + QuantizeHelper.quantize_onnx_model(output_path, onnx_model_paths["int8"], args.use_external_data_format) model = QuantizeHelper.quantize_torch_model(model) logger.info("finished quantizing model") - output_path = onnx_model_paths['int8'] + output_path = onnx_model_paths["int8"] - if args.output.endswith('.onnx') and output_path != args.output and not args.use_external_data_format: + if args.output.endswith(".onnx") and output_path != args.output and not args.use_external_data_format: import shutil + shutil.move(output_path, args.output) output_path = args.output @@ -378,7 +438,8 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu attention_mask_dtype=torch.int32 if args.use_int32_inputs else torch.int64, test_cases_per_run=args.test_cases, total_runs=args.test_runs, - verbose=args.verbose) + verbose=args.verbose, + ) latency = gpt2helper.test_performance( session, @@ -395,23 +456,49 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu attention_mask_dtype=torch.int32 if args.use_int32_inputs else torch.int64, batch_size=8, sequence_length=1, - past_sequence_length=32) + past_sequence_length=32, + ) if args.precision == Precision.FLOAT16: logger.info(f"fp16 conversion parameters:{fp16_params}") # Write results to file import csv + from onnxruntime import __version__ as ort_version + latency_name = get_latency_name() csv_file_existed = os.path.exists(csv_filename) - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = [ - "experiment", "run_id", "model_name", "model_class", "gpu", "precision", "optimizer", "test_cases", - "runs", "keep_io_types", "io_block_list", "op_block_list", "node_block_list", "force_fp16_initializers", - "auto_mixed_precision", "ORT_TRANSFORMER_OPTIONS", "ORT_CUDA_GEMM_OPTIONS", "onnxruntime", latency_name, - "top1_match_rate", "onnx_size_in_MB", "diff_50_percentile", "diff_90_percentile", "diff_95_percentile", - "diff_99_percentile", "diff_pass_rate", "nan_rate", "top1_match_rate_per_run" + "experiment", + "run_id", + "model_name", + "model_class", + "gpu", + "precision", + "optimizer", + "test_cases", + "runs", + "keep_io_types", + "io_block_list", + "op_block_list", + "node_block_list", + "force_fp16_initializers", + "auto_mixed_precision", + "ORT_TRANSFORMER_OPTIONS", + "ORT_CUDA_GEMM_OPTIONS", + "onnxruntime", + latency_name, + "top1_match_rate", + "onnx_size_in_MB", + "diff_50_percentile", + "diff_90_percentile", + "diff_95_percentile", + "diff_99_percentile", + "diff_pass_rate", + "nan_rate", + "top1_match_rate_per_run", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) if not csv_file_existed: @@ -432,8 +519,8 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu "node_block_list": args.node_block_list, "force_fp16_initializers": args.force_fp16_initializers, "auto_mixed_precision": args.auto_mixed_precision, - "ORT_TRANSFORMER_OPTIONS": os.getenv('ORT_TRANSFORMER_OPTIONS'), - "ORT_CUDA_GEMM_OPTIONS": os.getenv('ORT_CUDA_GEMM_OPTIONS'), + "ORT_TRANSFORMER_OPTIONS": os.getenv("ORT_TRANSFORMER_OPTIONS"), + "ORT_CUDA_GEMM_OPTIONS": os.getenv("ORT_CUDA_GEMM_OPTIONS"), "onnxruntime": ort_version, latency_name: f"{latency:.2f}", "diff_50_percentile": parity_result["max_diff_percentile_50"], @@ -463,24 +550,26 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu if use_padding: if "attention_mask" in data: numpy_float = numpy.float16 if is_io_float16 else numpy.float32 - attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], - dtype=numpy_float)).to(device) + attention_mask = torch.from_numpy(numpy.asarray(data["attention_mask"], dtype=numpy_float)).to( + device + ) else: padding = -1 attention_mask = (input_ids != padding).type(torch.float16 if is_io_float16 else torch.float32) input_ids.masked_fill_(input_ids == padding, 0) if "position_ids" in data: - position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], - dtype=numpy.int64)).to(device) + position_ids = torch.from_numpy(numpy.asarray(data["position_ids"], dtype=numpy.int64)).to( + device + ) else: - position_ids = (attention_mask.long().cumsum(-1) - 1) + position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(position_ids < 0, 0) inputs = { "input_ids": input_ids.to(torch.int32) if args.use_int32_inputs else input_ids, "position_ids": position_ids.to(torch.int32) if args.use_int32_inputs else position_ids, - "attention_mask": attention_mask.to(torch.int32) if args.use_int32_inputs else attention_mask + "attention_mask": attention_mask.to(torch.int32) if args.use_int32_inputs else attention_mask, } else: inputs = {"input_ids": input_ids.to(torch.int32) if args.use_int32_inputs else input_ids} @@ -490,31 +579,35 @@ def main(argv=None, experiment_name="", run_id=0, csv_filename="gpt2_parity_resu input_log_probs = torch.zeros([input_ids.shape[0], 1]) input_unfinished_sents = torch.ones([input_ids.shape[0], 1], dtype=torch.bool) - inputs.update({ - "beam_select_idx": beam_select_idx, - "input_log_probs": input_log_probs, - "input_unfinished_sents": input_unfinished_sents, - }) + inputs.update( + { + "beam_select_idx": beam_select_idx, + "input_log_probs": input_log_probs, + "input_unfinished_sents": input_unfinished_sents, + } + ) test_inputs.append(inputs) - gpt2tester.test_generation(session, - model, - device, - test_inputs, - precision=args.precision, - model_class=args.model_class, - top_k=20, - top_k_no_order=True, - max_steps=24, - max_inputs=0, - verbose=args.verbose, - save_test_data=3, - save_test_data_dir=Path(output_path).parent) + gpt2tester.test_generation( + session, + model, + device, + test_inputs, + precision=args.precision, + model_class=args.model_class, + top_k=20, + top_k_no_order=True, + max_steps=24, + max_inputs=0, + verbose=args.verbose, + save_test_data=3, + save_test_data_dir=Path(output_path).parent, + ) logger.info(f"Done. Output model: {output_path}") return result -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_helper.py index 463b2e7829ac6..65729cf068d3e 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_helper.py @@ -5,19 +5,19 @@ # -------------------------------------------------------------------------- # This script helps onnx conversion and validation for GPT2 model with past state. import logging -import torch +import os import random -import numpy +import sys import time from pathlib import Path -from typing import List, Dict, Union -from transformers import GPT2LMHeadModel, GPT2Config -from gpt2_helper import Gpt2Helper, Gpt2Inputs, MyGPT2Model, MyGPT2LMHeadModel, MyGPT2LMHeadModel_NoPadding +from typing import Dict, List, Union -import sys -import os +import numpy +import torch +from gpt2_helper import Gpt2Helper, Gpt2Inputs, MyGPT2LMHeadModel, MyGPT2LMHeadModel_NoPadding, MyGPT2Model +from transformers import GPT2Config, GPT2LMHeadModel -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from io_binding_helper import TypeHelper from torch_onnx_export_helper import torch_onnx_export @@ -27,7 +27,6 @@ class Gpt2HelperFactory: - @staticmethod def create_helper(helper_type="default"): helpers = { @@ -72,11 +71,13 @@ def forward( logits_flat, present_flat = MyGPT2Model.post_process(result, self.config.n_layer) next_token_logits = logits_flat[:, -1].view(self.config.batch_size, -1, logits_flat.size(-1)) next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1) - next_token_log_probs, next_token_ids = torch.topk(next_token_log_probs, - self.config.beam_size, - dim=-1, - largest=True, - sorted=True) + next_token_log_probs, next_token_ids = torch.topk( + next_token_log_probs, + self.config.beam_size, + dim=-1, + largest=True, + sorted=True, + ) # finished sentences is always with EOS, and all but the first one has -inf, so that they will be automatically dropped in the round of beam search. finished_sents = ~input_unfinished_sents @@ -87,34 +88,35 @@ def forward( # select N sequences from beams of each input, sorted by sequence probability output_log_probs = output_log_probs.view(self.config.batch_size, -1) # shape=(batch, beam_size^2) - output_log_probs, selected_index_flat = output_log_probs.topk(self.config.beam_size, - dim=-1, - largest=True, - sorted=True) # output shape=(batch, beam_size) + output_log_probs, selected_index_flat = output_log_probs.topk( + self.config.beam_size, dim=-1, largest=True, sorted=True + ) # output shape=(batch, beam_size) # select the correspondent sentences/next tokens - selected_input_seq = torch.div(selected_index_flat, self.config.beam_size, - rounding_mode='trunc') # selected_index_flat // self.config.beam_size + selected_input_seq = torch.div( + selected_index_flat, self.config.beam_size, rounding_mode="trunc" + ) # selected_index_flat // self.config.beam_size next_token_ids = next_token_ids.view(self.config.batch_size, -1).gather(-1, selected_index_flat) prev_step_results = prev_step_results.view(self.config.batch_size, -1, prev_step_results.size(-1)) prev_step_results = prev_step_results.gather( - 1, - selected_input_seq.unsqueeze(-1).repeat(1, 1, prev_step_results.size(-1))) + 1, selected_input_seq.unsqueeze(-1).repeat(1, 1, prev_step_results.size(-1)) + ) output_unfinished_sents = input_unfinished_sents.gather(1, selected_input_seq) # Add ones_like to walkaround error like Shape mismatch attempting to re-use buffer. {1,1} != {1,4} - output_unfinished_sents = (output_unfinished_sents & next_token_ids.ne( - torch.ones_like(next_token_ids, dtype=torch.int) * self.config.eos_token_id)) + output_unfinished_sents = output_unfinished_sents & next_token_ids.ne( + torch.ones_like(next_token_ids, dtype=torch.int) * self.config.eos_token_id + ) # get the next full input_ids current_step_results = torch.cat([prev_step_results, next_token_ids.unsqueeze(-1)], dim=-1).contiguous() prev_step_scores = prev_step_scores.view(self.config.batch_size, -1, prev_step_scores.size(-1)) prev_step_scores = prev_step_scores.gather( - 1, - selected_input_seq.unsqueeze(-1).repeat(1, 1, prev_step_scores.size(-1))) + 1, selected_input_seq.unsqueeze(-1).repeat(1, 1, prev_step_scores.size(-1)) + ) current_step_scores = torch.cat([prev_step_scores, output_log_probs.unsqueeze(-1)], dim=-1).contiguous() return ( @@ -132,18 +134,20 @@ class GPT2LMHeadModel_ConfigurableOneStepSearch(GPT2LMHeadModel): """Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and one step beam search with configuration support.""" - def __init__(self, - config, - batch_size, - beam_size, - ignore_eos=False, - temperature=1.0, - repetition_penalty=1.0, - excluded_token_ids=None, - length_penalty=1.0, - do_sample=False, - do_sample_top_p=1, - do_sample_top_k=0): + def __init__( + self, + config, + batch_size, + beam_size, + ignore_eos=False, + temperature=1.0, + repetition_penalty=1.0, + excluded_token_ids=None, + length_penalty=1.0, + do_sample=False, + do_sample_top_p=1, + do_sample_top_k=0, + ): super().__init__(config) self.config.batch_size = batch_size self.config.beam_size = beam_size @@ -162,16 +166,25 @@ def collapse_first_two_dims(tensor): @staticmethod def top_k_top_p_filtering(log_probs, top_p=1.0, top_k=0): - '''Set tail event (out of top_p) to a big negative number''' + """Set tail event (out of top_p) to a big negative number""" sorted_log_probs, sorted_indices = torch.sort(log_probs, descending=True) cumulative_probs = torch.cumsum(sorted_log_probs.exp(), dim=-1) sorted_indices_to_remove = cumulative_probs >= top_p sorted_indices_to_remove = torch.cat( - [torch.zeros_like(sorted_indices_to_remove[..., :1]), sorted_indices_to_remove[..., :-1]], dim=-1) + [ + torch.zeros_like(sorted_indices_to_remove[..., :1]), + sorted_indices_to_remove[..., :-1], + ], + dim=-1, + ) if top_k > 0: sorted_indices_to_remove = torch.cat( - [sorted_indices_to_remove[..., :top_k], - torch.ones_like(sorted_indices_to_remove[..., top_k:])], dim=-1) + [ + sorted_indices_to_remove[..., :top_k], + torch.ones_like(sorted_indices_to_remove[..., top_k:]), + ], + dim=-1, + ) sorted_log_probs.masked_fill_(sorted_indices_to_remove, BIG_NEG) return log_probs.scatter(-1, sorted_indices, sorted_log_probs) @@ -188,8 +201,8 @@ def forward( input_num_seq_per_sample = input_ids.size(1) input_ids_unfinished_flat = self.collapse_first_two_dims(input_ids).index_select( - 0, - input_unfinished_sents.view(-1).nonzero(as_tuple=False).view(-1)) + 0, input_unfinished_sents.view(-1).nonzero(as_tuple=False).view(-1) + ) if self.config.ignore_eos: attention_mask = (input_ids_unfinished_flat != self.config.eos_token_id).float() @@ -203,8 +216,9 @@ def forward( input_ids_unfinished_flat = input_ids_unfinished_flat[:, last_seq_len:] position_ids = position_ids[:, last_seq_len:] - unfinished_index_relative_to_last_unfinished = beam_select_idx.view(-1)[input_unfinished_sents.view( - -1).nonzero(as_tuple=False).view(-1)] + unfinished_index_relative_to_last_unfinished = beam_select_idx.view(-1)[ + input_unfinished_sents.view(-1).nonzero(as_tuple=False).view(-1) + ] past = tuple([p.index_select(1, unfinished_index_relative_to_last_unfinished) for p in past]) @@ -218,23 +232,37 @@ def forward( logits_flat, present_flat = MyGPT2Model.post_process(result, self.config.n_layer) # insert finished sequence back to form a square shape of (batch_size, beam_size) - next_token_logits = logits_flat.new_zeros(input_ids.size()[:2] + (logits_flat.size(-1), )) - next_token_logits.index_fill_(2, torch.LongTensor([self.config.eos_token_id]).to(input_ids.device), -BIG_NEG) + next_token_logits = logits_flat.new_zeros(input_ids.size()[:2] + (logits_flat.size(-1),)) + next_token_logits.index_fill_( + 2, + torch.LongTensor([self.config.eos_token_id]).to(input_ids.device), + -BIG_NEG, + ) next_token_logits.masked_scatter_( - input_unfinished_sents.unsqueeze(-1).expand_as(next_token_logits), logits_flat[:, -1]) + input_unfinished_sents.unsqueeze(-1).expand_as(next_token_logits), + logits_flat[:, -1], + ) # repetition penalty from CTRL paper (https://arxiv.org/abs/1909.05858) if self.config.repetition_penalty != 1.0: _pen = next_token_logits.gather(2, input_ids) - _pen = torch.where(_pen > 0, _pen / self.config.repetition_penalty, _pen * self.config.repetition_penalty) + _pen = torch.where( + _pen > 0, + _pen / self.config.repetition_penalty, + _pen * self.config.repetition_penalty, + ) next_token_logits.scatter_(2, input_ids, _pen) # similar way to encourage short sentence if self.config.length_penalty != 1.0: _pen = next_token_logits[..., self.config.eos_token_id] # if eos > 0, increase it, else, decrease it. - _pen = torch.where(_pen > 0, _pen * self.config.length_penalty, _pen / self.config.length_penalty) + _pen = torch.where( + _pen > 0, + _pen * self.config.length_penalty, + _pen / self.config.length_penalty, + ) next_token_logits[..., self.config.eos_token_id] = _pen if self.config.temperature != 1.0: @@ -242,62 +270,76 @@ def forward( # exclude excluded_token_ids if self.config.excluded_token_ids is not None: - next_token_logits.index_fill_(2, self.config.excluded_token_ids.to(next_token_logits.device), - BIG_NEG) # batch x beams/sequences x vocab_size + next_token_logits.index_fill_( + 2, self.config.excluded_token_ids.to(next_token_logits.device), BIG_NEG + ) # batch x beams/sequences x vocab_size next_token_log_probs = torch.log_softmax(next_token_logits, dim=-1) if self.config.do_sample: vocab_size = next_token_log_probs.size(-1) - _next_token_log_probs = self.top_k_top_p_filtering(next_token_log_probs.view(-1, vocab_size), - top_k=self.config.do_sample_top_k, - top_p=self.config.do_sample_top_p) - next_token_ids = torch.multinomial(_next_token_log_probs.exp(), - num_samples=self.config.beam_size, - replacement=False) + _next_token_log_probs = self.top_k_top_p_filtering( + next_token_log_probs.view(-1, vocab_size), + top_k=self.config.do_sample_top_k, + top_p=self.config.do_sample_top_p, + ) + next_token_ids = torch.multinomial( + _next_token_log_probs.exp(), + num_samples=self.config.beam_size, + replacement=False, + ) next_token_ids = next_token_ids.view(self.config.batch_size, input_num_seq_per_sample, -1) next_token_log_probs = next_token_log_probs.gather(-1, next_token_ids) else: - next_token_log_probs, next_token_ids = torch.topk(next_token_log_probs, - self.config.beam_size, - dim=-1, - largest=True, - sorted=True) + next_token_log_probs, next_token_ids = torch.topk( + next_token_log_probs, + self.config.beam_size, + dim=-1, + largest=True, + sorted=True, + ) output_log_probs = input_log_probs.unsqueeze(-1) + next_token_log_probs # select N sequences from beams of each input, sorted by sequence probability output_log_probs = output_log_probs.view(self.config.batch_size, -1) # shape=(batch, beam_size^2) - output_log_probs, selected_index_flat = output_log_probs.topk(self.config.beam_size, - dim=-1, - largest=True, - sorted=True) # output shape=(batch, beam_size) + output_log_probs, selected_index_flat = output_log_probs.topk( + self.config.beam_size, dim=-1, largest=True, sorted=True + ) # output shape=(batch, beam_size) # select the correspondent sentences/next tokens - selected_input_seq = torch.div(selected_index_flat, self.config.beam_size, - rounding_mode='trunc') # selected_index_flat // self.config.beam_size + selected_input_seq = torch.div( + selected_index_flat, self.config.beam_size, rounding_mode="trunc" + ) # selected_index_flat // self.config.beam_size next_token_ids = next_token_ids.view(self.config.batch_size, -1).gather(-1, selected_index_flat) prev_step_results = input_ids.view(self.config.batch_size, -1, input_ids.size(-1)).contiguous() prev_step_results = prev_step_results.gather( 1, - selected_input_seq.unsqueeze(-1).expand(selected_input_seq.shape + (prev_step_results.size(-1), ))) + selected_input_seq.unsqueeze(-1).expand(selected_input_seq.shape + (prev_step_results.size(-1),)), + ) output_unfinished_sents = input_unfinished_sents.gather(1, selected_input_seq) - output_unfinished_sents = (output_unfinished_sents & next_token_ids.ne(self.config.eos_token_id)) + output_unfinished_sents = output_unfinished_sents & next_token_ids.ne(self.config.eos_token_id) current_step_results = torch.cat([prev_step_results, next_token_ids.unsqueeze(-1)], dim=-1).contiguous() prev_step_scores = prev_step_scores.view(self.config.batch_size, -1, prev_step_scores.size(-1)) prev_step_scores = prev_step_scores.gather( 1, - selected_input_seq.unsqueeze(-1).expand(selected_input_seq.shape + (prev_step_scores.size(-1), ))) + selected_input_seq.unsqueeze(-1).expand(selected_input_seq.shape + (prev_step_scores.size(-1),)), + ) current_step_scores = torch.cat([prev_step_scores, output_log_probs.unsqueeze(-1)], dim=-1).contiguous() # For next past state - index_relative_to_last_unfinished = (input_unfinished_sents.view(-1).float().cumsum(-1) - 1).clamp( - min=0).long().reshape_as(input_unfinished_sents).gather(1, selected_input_seq) + index_relative_to_last_unfinished = ( + (input_unfinished_sents.view(-1).float().cumsum(-1) - 1) + .clamp(min=0) + .long() + .reshape_as(input_unfinished_sents) + .gather(1, selected_input_seq) + ) return ( current_step_results.view(self.config.batch_size * self.config.beam_size, -1), @@ -311,16 +353,23 @@ def forward( # Maps model class name to a tuple of model class, name of first output and use padding or not MODEL_CLASSES = { - 'GPT2LMHeadModel': (MyGPT2LMHeadModel, 'logits', True), - 'GPT2LMHeadModel_NoPadding': (MyGPT2LMHeadModel_NoPadding, 'logits', False), - 'GPT2Model': (MyGPT2Model, 'last_state', True), - "GPT2LMHeadModel_BeamSearchStep": (GPT2LMHeadModel_BeamSearchStep, "last_state", True), - "GPT2LMHeadModel_ConfigurableOneStepSearch": (GPT2LMHeadModel_ConfigurableOneStepSearch, "last_state", False), + "GPT2LMHeadModel": (MyGPT2LMHeadModel, "logits", True), + "GPT2LMHeadModel_NoPadding": (MyGPT2LMHeadModel_NoPadding, "logits", False), + "GPT2Model": (MyGPT2Model, "last_state", True), + "GPT2LMHeadModel_BeamSearchStep": ( + GPT2LMHeadModel_BeamSearchStep, + "last_state", + True, + ), + "GPT2LMHeadModel_ConfigurableOneStepSearch": ( + GPT2LMHeadModel_ConfigurableOneStepSearch, + "last_state", + False, + ), } class Gpt2BeamSearchInputs(Gpt2Inputs): - def __init__( self, input_ids, @@ -345,10 +394,18 @@ def __init__( def to_list(self) -> List: input_list = [ - v for v in [ - self.input_ids, self.position_ids, self.attention_mask, self.beam_select_idx, self.input_log_probs, - self.input_unfinished_sents, self.prev_step_results, self.prev_step_scores - ] if v is not None + v + for v in [ + self.input_ids, + self.position_ids, + self.attention_mask, + self.beam_select_idx, + self.input_log_probs, + self.input_unfinished_sents, + self.prev_step_results, + self.prev_step_scores, + ] + if v is not None ] if self.past: input_list.extend(self.past) @@ -356,8 +413,9 @@ def to_list(self) -> List: def to_fp32(self): past = [p.to(dtype=torch.float32) for p in self.past] - attention_mask = self.attention_mask.to( - dtype=torch.float32) if self.attention_mask is not None else self.attention_mask + attention_mask = ( + self.attention_mask.to(dtype=torch.float32) if self.attention_mask is not None else self.attention_mask + ) return Gpt2BeamSearchInputs( self.input_ids, past, @@ -375,36 +433,39 @@ class Gpt2BeamSearchHelper(Gpt2Helper): """A helper class for Gpt2 model conversion, inference and verification.""" @staticmethod - def get_dummy_inputs(batch_size: int, - past_sequence_length: int, - sequence_length: int, - num_attention_heads: int, - hidden_size: int, - num_layer: int, - vocab_size: int, - device: torch.device, - float16: bool = False, - has_position_ids: bool = True, - has_attention_mask: bool = True, - input_ids_dtype: torch.dtype = torch.int64, - position_ids_dtype: torch.dtype = torch.int64, - attention_mask_dtype: torch.dtype = torch.int64) -> Gpt2BeamSearchInputs: - """Create random inputs for GPT2 beam search. - """ - gpt2_dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size, - past_sequence_length, - sequence_length, - num_attention_heads, - hidden_size, - num_layer, - vocab_size, - device, - float16, - has_position_ids, - has_attention_mask, - input_ids_dtype=input_ids_dtype, - position_ids_dtype=position_ids_dtype, - attention_mask_dtype=attention_mask_dtype) + def get_dummy_inputs( + batch_size: int, + past_sequence_length: int, + sequence_length: int, + num_attention_heads: int, + hidden_size: int, + num_layer: int, + vocab_size: int, + device: torch.device, + float16: bool = False, + has_position_ids: bool = True, + has_attention_mask: bool = True, + input_ids_dtype: torch.dtype = torch.int64, + position_ids_dtype: torch.dtype = torch.int64, + attention_mask_dtype: torch.dtype = torch.int64, + ) -> Gpt2BeamSearchInputs: + """Create random inputs for GPT2 beam search.""" + gpt2_dummy_inputs = Gpt2Helper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + num_attention_heads, + hidden_size, + num_layer, + vocab_size, + device, + float16, + has_position_ids, + has_attention_mask, + input_ids_dtype=input_ids_dtype, + position_ids_dtype=position_ids_dtype, + attention_mask_dtype=attention_mask_dtype, + ) float_type = torch.float16 if float16 else torch.float32 beam_select_idx = torch.zeros([1, batch_size], device=device).long() @@ -436,15 +497,17 @@ def get_dummy_inputs(batch_size: int, ) @staticmethod - def get_output_shapes(batch_size: int, - context_len: int, - past_sequence_length: int, - sequence_length: int, - beam_size: int, - step: int, - config: GPT2Config, - model_class: str = "GPT2LMHeadModel_BeamSearchStep", - num_seq: int = 0) -> Dict[str, List[int]]: + def get_output_shapes( + batch_size: int, + context_len: int, + past_sequence_length: int, + sequence_length: int, + beam_size: int, + step: int, + config: GPT2Config, + model_class: str = "GPT2LMHeadModel_BeamSearchStep", + num_seq: int = 0, + ) -> Dict[str, List[int]]: """Returns a dictionary with output name as key, and shape as value.""" num_attention_heads = config.num_attention_heads hidden_size = config.hidden_size @@ -456,7 +519,10 @@ def get_output_shapes(batch_size: int, if model_class == "GPT2LMHeadModel_BeamSearchStep": last_state_shape = [batch_size, beam_size] else: - last_state_shape = [batch_size * beam_size, past_sequence_length - context_len + sequence_length + 1] + last_state_shape = [ + batch_size * beam_size, + past_sequence_length - context_len + sequence_length + 1, + ] if model_class == "GPT2LMHeadModel_BeamSearchStep": if step == 0: @@ -497,9 +563,13 @@ def get_output_shapes(batch_size: int, output_shapes["output_unfinished_sents"] = [batch_size, beam_size] if model_class == "GPT2LMHeadModel_BeamSearchStep": output_shapes["current_step_results"] = [ - batch_size * beam_size, past_sequence_length - context_len + sequence_length + 1 + batch_size * beam_size, + past_sequence_length - context_len + sequence_length + 1, ] - output_shapes["current_step_scores"] = [batch_size * beam_size, past_sequence_length - context_len + 2] + output_shapes["current_step_scores"] = [ + batch_size * beam_size, + past_sequence_length - context_len + 2, + ] print("output_shapes", output_shapes) return output_shapes @@ -510,7 +580,7 @@ def get_output_buffers(output_shapes, device, is_float16=False): output_buffers = {} for name, shape in output_shapes.items(): - if (name == "output_selected_indices" or name == "current_step_results" or name == "last_state"): + if name == "output_selected_indices" or name == "current_step_results" or name == "last_state": output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch.long, device=device) elif name == "output_unfinished_sents": output_buffers[name] = torch.empty(numpy.prod(shape), dtype=torch.bool, device=device) @@ -519,11 +589,13 @@ def get_output_buffers(output_shapes, device, is_float16=False): return output_buffers @staticmethod - def compare_outputs(torch_outputs, - ort_outputs, - model_class="GPT2LMHeadModel_BeamSearchStep", - rtol=1e-03, - atol=1e-03): + def compare_outputs( + torch_outputs, + ort_outputs, + model_class="GPT2LMHeadModel_BeamSearchStep", + rtol=1e-03, + atol=1e-03, + ): """Returns True if torch and ORT outputs are close for given thresholds, and False otherwise.""" if model_class == "GPT2LMHeadModel_BeamSearchStep": results_id = -4 @@ -532,10 +604,12 @@ def compare_outputs(torch_outputs, results_id = 0 num_layers = len(ort_outputs) - 5 - is_close = numpy.allclose(ort_outputs[results_id], - torch_outputs[results_id].cpu().numpy(), - rtol=rtol, - atol=atol) + is_close = numpy.allclose( + ort_outputs[results_id], + torch_outputs[results_id].cpu().numpy(), + rtol=rtol, + atol=atol, + ) logger.debug(f"PyTorch and OnnxRuntime output 0 (last_state) are close: {is_close}") is_all_close = is_close @@ -556,29 +630,36 @@ def compare_outputs(torch_outputs, return is_all_close @staticmethod - def export_onnx(model, - device, - onnx_model_path: str, - verbose: bool = False, - use_external_data_format: bool = False, - has_position_ids: bool = True, - has_attention_mask: bool = True): + def export_onnx( + model, + device, + onnx_model_path: str, + verbose: bool = False, + use_external_data_format: bool = False, + has_position_ids: bool = True, + has_attention_mask: bool = True, + ): """Export GPT-2 model with past state to ONNX model.""" - assert isinstance(model, (GPT2LMHeadModel_BeamSearchStep, GPT2LMHeadModel_ConfigurableOneStepSearch)) + assert isinstance( + model, + (GPT2LMHeadModel_BeamSearchStep, GPT2LMHeadModel_ConfigurableOneStepSearch), + ) config: GPT2Config = model.config num_layer = config.n_layer - dummy_inputs = Gpt2BeamSearchHelper.get_dummy_inputs(batch_size=1, - past_sequence_length=1, - sequence_length=2, - num_attention_heads=config.num_attention_heads, - hidden_size=config.hidden_size, - num_layer=num_layer, - vocab_size=config.vocab_size, - device=device, - float16=False, - has_position_ids=has_position_ids, - has_attention_mask=has_attention_mask) + dummy_inputs = Gpt2BeamSearchHelper.get_dummy_inputs( + batch_size=1, + past_sequence_length=1, + sequence_length=2, + num_attention_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + num_layer=num_layer, + vocab_size=config.vocab_size, + device=device, + float16=False, + has_position_ids=has_position_ids, + has_attention_mask=has_attention_mask, + ) input_list = dummy_inputs.to_list() with torch.no_grad(): @@ -606,14 +687,8 @@ def export_onnx(model, ] dynamic_axes = { - "input_ids": { - 0: "batch_size", - 1: "seq_len" - }, - output_names[0]: { - 0: "batch_size", - 1: "seq_len" - }, + "input_ids": {0: "batch_size", 1: "seq_len"}, + output_names[0]: {0: "batch_size", 1: "seq_len"}, } for name in past_names: dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"} @@ -641,12 +716,15 @@ def export_onnx(model, input_names.extend(past_names) # add dynamic output axes - present_axes = {1: 'batch_size', 3: 'cur_seq_len'} + present_axes = {1: "batch_size", 3: "cur_seq_len"} if isinstance(model, GPT2LMHeadModel_BeamSearchStep): dynamic_axes["last_state"] = {0: "batch_size", 1: "beam_size"} else: - dynamic_axes["last_state"] = {0: "batch_size * beam_size", 1: "total_seq_len"} + dynamic_axes["last_state"] = { + 0: "batch_size * beam_size", + 1: "total_seq_len", + } for i in range(num_layer): dynamic_axes["present_" + str(i)] = present_axes @@ -656,7 +734,10 @@ def export_onnx(model, dynamic_axes["output_unfinished_sents"] = {0: "batch_size", 1: "beam_size"} if "current_step_results" in output_names: - dynamic_axes["current_step_results"] = {0: "batch_size * beam_size", 1: "total_seq_len"} + dynamic_axes["current_step_results"] = { + 0: "batch_size * beam_size", + 1: "total_seq_len", + } dynamic_axes["current_step_scores"] = {0: "batch_size * beam_size"} @@ -720,28 +801,32 @@ def onnxruntime_inference(ort_session, inputs: Gpt2BeamSearchInputs, total_runs: return ort_outputs, average_latency @staticmethod - def prepare_io_binding(ort_session, - input_ids, - position_ids, - attention_mask, - past, - output_buffers, - output_shapes, - beam_select_idx=None, - input_log_probs=None, - input_unfinished_sents=None, - prev_step_results=None, - prev_step_scores=None): + def prepare_io_binding( + ort_session, + input_ids, + position_ids, + attention_mask, + past, + output_buffers, + output_shapes, + beam_select_idx=None, + input_log_probs=None, + input_unfinished_sents=None, + prev_step_results=None, + prev_step_scores=None, + ): """Returnas IO binding object for a session.""" # Bind (input_ids, position_ids, attention_mask and past_*) and all outputs - io_binding = Gpt2Helper.prepare_io_binding(ort_session, - input_ids, - position_ids, - attention_mask, - past=past, - output_buffers=output_buffers, - output_shapes=output_shapes) + io_binding = Gpt2Helper.prepare_io_binding( + ort_session, + input_ids, + position_ids, + attention_mask, + past=past, + output_buffers=output_buffers, + output_shapes=output_shapes, + ) # Bind the remaining inputs other_inputs = { @@ -767,15 +852,16 @@ def prepare_io_binding(ort_session, return io_binding @staticmethod - def onnxruntime_inference_with_binded_io(ort_session, - inputs: Gpt2BeamSearchInputs, - output_buffers: Dict[str, torch.Tensor], - output_shapes: Dict[str, List[int]], - total_runs: int = 0, - return_numpy: bool = True, - include_copy_output_latency: bool = False): - """Inference with IO binding. Returns outputs, and optional latency when total_runs > 0. - """ + def onnxruntime_inference_with_binded_io( + ort_session, + inputs: Gpt2BeamSearchInputs, + output_buffers: Dict[str, torch.Tensor], + output_shapes: Dict[str, List[int]], + total_runs: int = 0, + return_numpy: bool = True, + include_copy_output_latency: bool = False, + ): + """Inference with IO binding. Returns outputs, and optional latency when total_runs > 0.""" logger.debug(f"start onnxruntime_inference_with_binded_io") # Bind inputs and outputs to onnxruntime session @@ -798,8 +884,9 @@ def onnxruntime_inference_with_binded_io(ort_session, ort_session.run_with_iobinding(io_binding) # Copy results to cpu for verification - ort_outputs = Gpt2BeamSearchHelper.get_outputs_from_io_binding_buffer(ort_session, output_buffers, - output_shapes, return_numpy) + ort_outputs = Gpt2BeamSearchHelper.get_outputs_from_io_binding_buffer( + ort_session, output_buffers, output_shapes, return_numpy + ) if total_runs == 0: return ort_outputs @@ -810,8 +897,9 @@ def onnxruntime_inference_with_binded_io(ort_session, # Run onnxruntime with io binding ort_session.run_with_iobinding(io_binding) if include_copy_output_latency: - _ = Gpt2BeamSearchHelper.get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, - return_numpy) + _ = Gpt2BeamSearchHelper.get_outputs_from_io_binding_buffer( + ort_session, output_buffers, output_shapes, return_numpy + ) latency.append(time.time() - start) average_latency = sum(latency) * 1000 / len(latency) @@ -820,17 +908,19 @@ def onnxruntime_inference_with_binded_io(ort_session, return ort_outputs, average_latency @staticmethod - def test_parity(ort_session, - model, - device, - is_float16=False, - rtol=5e-4, - atol=5e-4, - total_test_cases=100, - use_io_binding=True, - model_class="GPT2LMHeadModel_BeamSearchStep", - has_position_ids=True, - has_attention_mask=True): + def test_parity( + ort_session, + model, + device, + is_float16=False, + rtol=5e-4, + atol=5e-4, + total_test_cases=100, + use_io_binding=True, + model_class="GPT2LMHeadModel_BeamSearchStep", + has_position_ids=True, + has_attention_mask=True, + ): """Generate random inputs and compare the results of PyTorch and Onnx Runtime.""" config: GPT2Config = model.config @@ -865,11 +955,21 @@ def test_parity(ort_session, batch_size = random.randint(1, max_batch_size) logger.debug( - f"Running parity test for batch_size={batch_size} past_sequence_length={past_sequence_length}...") - dummy_inputs = Gpt2BeamSearchHelper.get_dummy_inputs(batch_size, past_sequence_length, sequence_length, - config.num_attention_heads, config.hidden_size, - config.n_layer, config.vocab_size, device, is_float16, - has_position_ids, has_attention_mask) + f"Running parity test for batch_size={batch_size} past_sequence_length={past_sequence_length}..." + ) + dummy_inputs = Gpt2BeamSearchHelper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + config.num_attention_heads, + config.hidden_size, + config.n_layer, + config.vocab_size, + device, + is_float16, + has_position_ids, + has_attention_mask, + ) outputs = Gpt2BeamSearchHelper.pytorch_inference(model, dummy_inputs) if use_io_binding: @@ -886,13 +986,12 @@ def test_parity(ort_session, model_class, ) ort_outputs = Gpt2BeamSearchHelper.onnxruntime_inference_with_binded_io( - ort_session, dummy_inputs, output_buffers, output_shapes) + ort_session, dummy_inputs, output_buffers, output_shapes + ) - is_all_close = Gpt2BeamSearchHelper.compare_outputs(outputs, - ort_outputs, - model_class=model_class, - rtol=rtol, - atol=atol) + is_all_close = Gpt2BeamSearchHelper.compare_outputs( + outputs, ort_outputs, model_class=model_class, rtol=rtol, atol=atol + ) if is_all_close: passed_test_cases += 1 logger.info(f"Parity Test Cases={total_test_cases}; Passed={passed_test_cases}") diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_tester.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_tester.py index 540b96189c023..3a8c17a3b7c7e 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_tester.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_beamsearch_tester.py @@ -5,16 +5,16 @@ # -------------------------------------------------------------------------- # This script helps evaluation of GPT-2 model. import logging -import torch -import numpy +import os +import sys import timeit -from gpt2_tester import Gpt2Tester, Gpt2Metric -from gpt2_beamsearch_helper import Gpt2BeamSearchHelper, Gpt2BeamSearchInputs -import sys -import os +import numpy +import torch +from gpt2_beamsearch_helper import Gpt2BeamSearchHelper, Gpt2BeamSearchInputs +from gpt2_tester import Gpt2Metric, Gpt2Tester -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from benchmark_helper import Precision @@ -22,7 +22,6 @@ class Gpt2TesterFactory: - @staticmethod def create_tester(tester_type="default"): testers = { @@ -35,7 +34,6 @@ def create_tester(tester_type="default"): class Gpt2BeamSearchTester(Gpt2Tester): - def __init__( self, input_ids, @@ -55,16 +53,18 @@ def __init__( top_k=20, top_k_required_order=False, ): - super().__init__(input_ids, - position_ids, - attention_mask, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - num_layer=num_layer, - device=device, - is_fp16=is_fp16, - top_k=top_k, - top_k_required_order=top_k_required_order) + super().__init__( + input_ids, + position_ids, + attention_mask, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + num_layer=num_layer, + device=device, + is_fp16=is_fp16, + top_k=top_k, + top_k_required_order=top_k_required_order, + ) self.input_length = input_ids.shape[-1] self.n_layer = num_layer self.beam_size = beam_size @@ -97,17 +97,27 @@ def update(self, output, step, device): """ Update the inputs for next inference. """ - self.last_state = (torch.from_numpy(output[0]).to(device) - if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu()) + self.last_state = ( + torch.from_numpy(output[0]).to(device) + if isinstance(output[0], numpy.ndarray) + else output[0].clone().detach().cpu() + ) self.input_ids = self.last_state.view(self.batch_size * self.beam_size, -1).to(device) if self.position_ids is not None: input_unfinished_sents_id = -3 - self.prev_step_results = (torch.from_numpy(output[-2]).to(device) if isinstance(output[-2], numpy.ndarray) - else output[-2].clone().detach().to(device)) - self.position_ids = (torch.tensor([self.input_length + step - 1 - ]).unsqueeze(0).repeat(self.batch_size * self.beam_size, 1).to(device)) + self.prev_step_results = ( + torch.from_numpy(output[-2]).to(device) + if isinstance(output[-2], numpy.ndarray) + else output[-2].clone().detach().to(device) + ) + self.position_ids = ( + torch.tensor([self.input_length + step - 1]) + .unsqueeze(0) + .repeat(self.batch_size * self.beam_size, 1) + .to(device) + ) if self.attention_mask.size(0) != (self.batch_size * self.beam_size): self.attention_mask = self.attention_mask.repeat(self.batch_size * self.beam_size, 1) @@ -121,17 +131,26 @@ def update(self, output, step, device): else: input_unfinished_sents_id = -2 - self.beam_select_idx = (torch.from_numpy(output[input_unfinished_sents_id - 2]).to(device) if isinstance( - output[input_unfinished_sents_id - 2], numpy.ndarray) else output[input_unfinished_sents_id - - 2].clone().detach().to(device)) - self.input_log_probs = (torch.from_numpy(output[input_unfinished_sents_id - 1]).to(device) if isinstance( - output[input_unfinished_sents_id - 1], numpy.ndarray) else output[input_unfinished_sents_id - - 1].clone().detach().to(device)) - self.input_unfinished_sents = (torch.from_numpy(output[input_unfinished_sents_id]).to(device) if isinstance( - output[input_unfinished_sents_id], numpy.ndarray) else - output[input_unfinished_sents_id].clone().detach().to(device)) - self.prev_step_scores = (torch.from_numpy(output[-1]).to(device) - if isinstance(output[-1], numpy.ndarray) else output[-1].clone().detach().to(device)) + self.beam_select_idx = ( + torch.from_numpy(output[input_unfinished_sents_id - 2]).to(device) + if isinstance(output[input_unfinished_sents_id - 2], numpy.ndarray) + else output[input_unfinished_sents_id - 2].clone().detach().to(device) + ) + self.input_log_probs = ( + torch.from_numpy(output[input_unfinished_sents_id - 1]).to(device) + if isinstance(output[input_unfinished_sents_id - 1], numpy.ndarray) + else output[input_unfinished_sents_id - 1].clone().detach().to(device) + ) + self.input_unfinished_sents = ( + torch.from_numpy(output[input_unfinished_sents_id]).to(device) + if isinstance(output[input_unfinished_sents_id], numpy.ndarray) + else output[input_unfinished_sents_id].clone().detach().to(device) + ) + self.prev_step_scores = ( + torch.from_numpy(output[-1]).to(device) + if isinstance(output[-1], numpy.ndarray) + else output[-1].clone().detach().to(device) + ) self.top_1_tokens = self.input_ids[0] self.top_k_tokens = self.last_state @@ -141,24 +160,29 @@ def update(self, output, step, device): self.past = list(output[1]) else: for i in range(self.n_layer): - past_i = (torch.from_numpy(output[i + 1]) - if isinstance(output[i + 1], numpy.ndarray) else output[i + 1].clone().detach()) + past_i = ( + torch.from_numpy(output[i + 1]) + if isinstance(output[i + 1], numpy.ndarray) + else output[i + 1].clone().detach() + ) self.past.append(past_i.to(device)) @staticmethod - def test_generation(session, - model, - device, - test_inputs, - precision=Precision.FLOAT32, - model_class="GPT2LMHeadModel_BeamSearchStep", - top_k=20, - top_k_no_order=True, - max_steps=24, - max_inputs=0, - verbose=False, - save_test_data=0, - save_test_data_dir="."): + def test_generation( + session, + model, + device, + test_inputs, + precision=Precision.FLOAT32, + model_class="GPT2LMHeadModel_BeamSearchStep", + top_k=20, + top_k_no_order=True, + max_steps=24, + max_inputs=0, + verbose=False, + save_test_data=0, + save_test_data_dir=".", + ): """ Test Generation using beam search to compare PyTorch and ONNX model. It will print top 1 and top k errors on the given test inputs. @@ -208,9 +232,9 @@ def test_generation(session, print(f"{i}") input_ids = inputs["input_ids"] position_ids = inputs["position_ids"] if "position_ids" in inputs else None - attention_mask = (inputs["attention_mask"] if "attention_mask" in inputs else None) - beam_select_idx = (inputs["beam_select_idx"] if "beam_select_idx" in inputs else None) - input_log_probs = (inputs["input_log_probs"] if "input_log_probs" in inputs else None) + attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None + beam_select_idx = inputs["beam_select_idx"] if "beam_select_idx" in inputs else None + input_log_probs = inputs["input_log_probs"] if "input_log_probs" in inputs else None input_unfinished_sents = inputs["input_unfinished_sents"] if model_class == "GPT2LMHeadModel_BeamSearchStep": prev_step_results = inputs["input_ids"] @@ -324,10 +348,7 @@ def test_generation(session, Gpt2BeamSearchHelper.auto_increase_buffer_size(output_buffers, output_shapes) - ( - onnx_io_output, - avg_latency_ms, - ) = Gpt2BeamSearchHelper.onnxruntime_inference_with_binded_io( + (onnx_io_output, avg_latency_ms,) = Gpt2BeamSearchHelper.onnxruntime_inference_with_binded_io( session, onnx_io_runner.get_inputs(), output_buffers, @@ -345,8 +366,9 @@ def test_generation(session, onnx_io_runner.update(onnx_io_output, step, device) - if ((not onnx_runner.input_unfinished_sents.any()) - or (not torch_runner.input_unfinished_sents.any())): + if (not onnx_runner.input_unfinished_sents.any()) or ( + not torch_runner.input_unfinished_sents.any() + ): print("break at step: ", step) break @@ -407,7 +429,7 @@ def pprint_results( # remove EOS for k, t in enumerate(seq): if t == eos_token_id: - seq = seq[:k + 1] + seq = seq[: k + 1] break print("-" * 40) result = ",".join([str(token_id) for token_id in sample]) diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py index a7a42d7da816f..51f8bfb174d5b 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_helper.py @@ -5,38 +5,41 @@ # -------------------------------------------------------------------------- # This script helps onnx conversion and validation for GPT2 model with past state. import logging -import torch -import shutil +import os +import pickle import random -import numpy +import shutil +import sys import time -import pickle from pathlib import Path -from typing import List, Dict, Tuple, Union -from transformers import GPT2Model, GPT2LMHeadModel, GPT2Config, TFGPT2Model +from typing import Dict, List, Tuple, Union -import sys -import os +import numpy +import torch +from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model, TFGPT2Model -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from benchmark_helper import Precision from float16 import float_to_float16_max_diff -from onnx_model import OnnxModel from fusion_utils import FusionUtils -from benchmark_helper import Precision from io_binding_helper import IOBindingHelper +from onnx_model import OnnxModel from torch_onnx_export_helper import torch_onnx_export logger = logging.getLogger(__name__) -PRETRAINED_GPT2_MODELS = ['distilgpt2', 'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'] +PRETRAINED_GPT2_MODELS = ["distilgpt2", "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"] -DEFAULT_TOLERANCE = {Precision.FLOAT32: 0.0005, Precision.FLOAT16: 0.2, Precision.INT8: 3.0} +DEFAULT_TOLERANCE = { + Precision.FLOAT32: 0.0005, + Precision.FLOAT16: 0.2, + Precision.INT8: 3.0, +} class GPT2ModelNoPastState(GPT2Model): - """ Here we wrap a class to disable past state output. - """ + """Here we wrap a class to disable past state output.""" def __init__(self, config): super().__init__(config) @@ -46,8 +49,7 @@ def forward(self, input_ids): class TFGPT2ModelNoPastState(TFGPT2Model): - """ Here we wrap a class to disable past state output. - """ + """Here we wrap a class to disable past state output.""" def __init__(self, config): config.use_cache = False @@ -58,8 +60,7 @@ def forward(self, input_ids): class MyGPT2Model(GPT2Model): - """ Here we wrap a class for Onnx model conversion for GPT2Model with past state. - """ + """Here we wrap a class for Onnx model conversion for GPT2Model with past state.""" def __init__(self, config): super().__init__(config) @@ -68,46 +69,54 @@ def __init__(self, config): def post_process(result, num_layer): if isinstance(result[1][0], tuple) or isinstance(result[1][0], list): assert len(result[1]) == num_layer and len(result[1][0]) == 2 - #assert len(result[1][0][0].shape) == 4 and result[1][0][0].shape == result[1][0][1].shape + # assert len(result[1][0][0].shape) == 4 and result[1][0][0].shape == result[1][0][1].shape present = [] for i in range(num_layer): # Since transformers v4.*, past key and values are separated outputs. # Here we concate them into one tensor to be compatible with Attention operator. - present.append(torch.cat((result[1][i][0].unsqueeze(0), result[1][i][1].unsqueeze(0)), dim=0)) + present.append( + torch.cat( + (result[1][i][0].unsqueeze(0), result[1][i][1].unsqueeze(0)), + dim=0, + ) + ) return (result[0], tuple(present)) return result def forward(self, input_ids, position_ids, attention_mask, *past): - result = super().forward(input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past, - return_dict=False) + result = super().forward( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past, + return_dict=False, + ) return MyGPT2Model.post_process(result, self.config.n_layer) class MyGPT2LMHeadModel(GPT2LMHeadModel): - """ Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state. - """ + """Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state.""" def __init__(self, config): super().__init__(config) def forward(self, input_ids, position_ids, attention_mask, *past): - result = super().forward(input_ids, - position_ids=position_ids, - attention_mask=attention_mask, - past_key_values=past, - return_dict=False) + result = super().forward( + input_ids, + position_ids=position_ids, + attention_mask=attention_mask, + past_key_values=past, + return_dict=False, + ) return MyGPT2Model.post_process(result, self.config.n_layer) class MyGPT2LMHeadModel_NoPadding(GPT2LMHeadModel): - """ Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and no padding. - When you always use batch_size=1 in inference, there is no padding in inputs. In such case, position_ids - and attention_mask need no be in inputs. + """Here we wrap a class for Onnx model conversion for GPT2LMHeadModel with past state and no padding. + When you always use batch_size=1 in inference, there is no padding in inputs. In such case, position_ids + and attention_mask need no be in inputs. """ def __init__(self, config): @@ -121,14 +130,13 @@ def forward(self, input_ids, *past): # Maps model class name to a tuple of model class, name of first output and use padding or not MODEL_CLASSES = { - 'GPT2LMHeadModel': (MyGPT2LMHeadModel, 'logits', True), - 'GPT2LMHeadModel_NoPadding': (MyGPT2LMHeadModel_NoPadding, 'logits', False), - 'GPT2Model': (MyGPT2Model, 'last_state', True), + "GPT2LMHeadModel": (MyGPT2LMHeadModel, "logits", True), + "GPT2LMHeadModel_NoPadding": (MyGPT2LMHeadModel_NoPadding, "logits", False), + "GPT2Model": (MyGPT2Model, "last_state", True), } class Gpt2Inputs: - def __init__(self, input_ids, position_ids, attention_mask, past): self.input_ids: torch.LongTensor = input_ids self.position_ids: torch.LongTensor = position_ids @@ -149,49 +157,65 @@ def to_fp32(self): # For attention mask, only convert fp16 to fp32, and keep the original type if it is integer. attention_mask = None if self.attention_mask is not None: - attention_mask = self.attention_mask.to( - dtype=torch.float32) if (self.attention_mask.dtype == torch.float16) else self.attention_mask + attention_mask = ( + self.attention_mask.to(dtype=torch.float32) + if (self.attention_mask.dtype == torch.float16) + else self.attention_mask + ) past = [p.to(dtype=torch.float32) for p in self.past] return Gpt2Inputs(self.input_ids, self.position_ids, attention_mask, past) class Gpt2Helper: - """ A helper class for Gpt2 model conversion, inference and verification. - """ + """A helper class for Gpt2 model conversion, inference and verification.""" @staticmethod - def get_dummy_inputs(batch_size: int, - past_sequence_length: int, - sequence_length: int, - num_attention_heads: int, - hidden_size: int, - num_layer: int, - vocab_size: int, - device: torch.device, - float16: bool = False, - has_position_ids: bool = True, - has_attention_mask: bool = True, - input_ids_dtype: torch.dtype = torch.int32, - position_ids_dtype: torch.dtype = torch.int32, - attention_mask_dtype: torch.dtype = torch.int32) -> Gpt2Inputs: - """ Create random inputs for GPT2 model. + def get_dummy_inputs( + batch_size: int, + past_sequence_length: int, + sequence_length: int, + num_attention_heads: int, + hidden_size: int, + num_layer: int, + vocab_size: int, + device: torch.device, + float16: bool = False, + has_position_ids: bool = True, + has_attention_mask: bool = True, + input_ids_dtype: torch.dtype = torch.int32, + position_ids_dtype: torch.dtype = torch.int32, + attention_mask_dtype: torch.dtype = torch.int32, + ) -> Gpt2Inputs: + """Create random inputs for GPT2 model. Returns torch tensors of input_ids, position_ids, attention_mask and a list of past state tensors. """ float_type = torch.float16 if float16 else torch.float32 - past_shape = [2, batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)] + past_shape = [ + 2, + batch_size, + num_attention_heads, + past_sequence_length, + int(hidden_size / num_attention_heads), + ] past = [(torch.rand(past_shape, dtype=float_type, device=device) * 2.0 - 1.0) for _ in range(num_layer)] - input_ids = torch.randint(low=0, - high=vocab_size - 1, - size=(batch_size, sequence_length), - dtype=input_ids_dtype, - device=device) + input_ids = torch.randint( + low=0, + high=vocab_size - 1, + size=(batch_size, sequence_length), + dtype=input_ids_dtype, + device=device, + ) attention_mask = None if has_attention_mask: total_sequence_length = past_sequence_length + sequence_length - attention_mask = torch.ones([batch_size, total_sequence_length], dtype=attention_mask_dtype, device=device) + attention_mask = torch.ones( + [batch_size, total_sequence_length], + dtype=attention_mask_dtype, + device=device, + ) if total_sequence_length >= 2: padding_position = random.randint(0, total_sequence_length - 1) # test input with padding. attention_mask[:, padding_position] = 0 @@ -199,20 +223,21 @@ def get_dummy_inputs(batch_size: int, # Deduce position_ids from attention mask position_ids = None if has_position_ids: - position_ids = (attention_mask.long().cumsum(-1) - 1) + position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(position_ids < 0, 0) position_ids = position_ids[:, past_sequence_length:].to(position_ids_dtype) return Gpt2Inputs(input_ids, position_ids, attention_mask, past) @staticmethod - def get_output_shapes(batch_size: int, - past_sequence_length: int, - sequence_length: int, - config: GPT2Config, - model_class: str = "GPT2LMHeadModel") -> Dict[str, List[int]]: - """ Returns a dictionary with output name as key, and shape as value. - """ + def get_output_shapes( + batch_size: int, + past_sequence_length: int, + sequence_length: int, + config: GPT2Config, + model_class: str = "GPT2LMHeadModel", + ) -> Dict[str, List[int]]: + """Returns a dictionary with output name as key, and shape as value.""" num_attention_heads = config.num_attention_heads hidden_size = config.hidden_size num_layer = config.num_hidden_layers @@ -220,10 +245,17 @@ def get_output_shapes(batch_size: int, output_name = MODEL_CLASSES[model_class][1] - last_state_shape = [batch_size, sequence_length, vocab_size if output_name == "logits" else hidden_size] + last_state_shape = [ + batch_size, + sequence_length, + vocab_size if output_name == "logits" else hidden_size, + ] present_state_shape = [ - 2, batch_size, num_attention_heads, past_sequence_length + sequence_length, - int(hidden_size / num_attention_heads) + 2, + batch_size, + num_attention_heads, + past_sequence_length + sequence_length, + int(hidden_size / num_attention_heads), ] output_shapes = {output_name: last_state_shape} @@ -238,14 +270,15 @@ def auto_increase_buffer_size(output_buffers, output_shapes): assert key in output_buffers buffer = output_buffers[key] if numpy.prod(output_shapes[key]) > buffer.nelement(): - output_buffers[key] = torch.empty(numpy.prod(output_shapes[key]), - dtype=buffer.dtype, - device=buffer.device) + output_buffers[key] = torch.empty( + numpy.prod(output_shapes[key]), + dtype=buffer.dtype, + device=buffer.device, + ) @staticmethod def get_output_buffers(output_shapes, device, is_float16=False): - """ Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape. - """ + """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape.""" data_type = torch.float16 if is_float16 else torch.float32 output_buffers = {} @@ -255,8 +288,7 @@ def get_output_buffers(output_shapes, device, is_float16=False): @staticmethod def diff_outputs(torch_outputs, ort_outputs, relative=False): - """ Returns the maximum difference between PyTorch and OnnxRuntime outputs. - """ + """Returns the maximum difference between PyTorch and OnnxRuntime outputs.""" expected_outputs = torch_outputs[0].cpu().numpy() diff = numpy.abs(expected_outputs - ort_outputs[0]) if relative: @@ -266,26 +298,28 @@ def diff_outputs(torch_outputs, ort_outputs, relative=False): @staticmethod def compare_outputs(torch_outputs, ort_outputs, rtol=1e-03, atol=1e-03, **kwargs): - """ Returns True if torch and ORT outputs are close for given thresholds, and False otherwise. - Note: need kwargs since Gpt2BeamSearchHelper.compare_outputs has an extra parameter model_class + """Returns True if torch and ORT outputs are close for given thresholds, and False otherwise. + Note: need kwargs since Gpt2BeamSearchHelper.compare_outputs has an extra parameter model_class """ is_close = numpy.allclose(ort_outputs[0], torch_outputs[0].cpu().numpy(), rtol=rtol, atol=atol) - logger.debug(f'PyTorch and OnnxRuntime output 0 (last_state) are close: {is_close}') + logger.debug(f"PyTorch and OnnxRuntime output 0 (last_state) are close: {is_close}") is_all_close = is_close num_layers = len(ort_outputs) - 1 for layer in range(num_layers): - is_close = numpy.allclose(ort_outputs[1 + layer], - torch_outputs[1][layer].cpu().numpy(), - rtol=rtol, - atol=atol) - logger.debug(f'PyTorch and OnnxRuntime layer {layer} state (present_{layer}) are close:{is_close}') + is_close = numpy.allclose( + ort_outputs[1 + layer], + torch_outputs[1][layer].cpu().numpy(), + rtol=rtol, + atol=atol, + ) + logger.debug(f"PyTorch and OnnxRuntime layer {layer} state (present_{layer}) are close:{is_close}") is_all_close = is_all_close and is_close if not is_all_close: max_abs_diff = Gpt2Helper.diff_outputs(torch_outputs, ort_outputs) - logger.info(f'PyTorch and OnnxRuntime results are not all close: max_abs_diff={max_abs_diff:.5f}') + logger.info(f"PyTorch and OnnxRuntime results are not all close: max_abs_diff={max_abs_diff:.5f}") return is_all_close @@ -315,18 +349,19 @@ def compare_outputs_v2(torch_outputs, ort_outputs, atol=1e-06): is_all_close = is_all_close and is_close if numpy.isnan(torch_output).any(): - logger.debug(f'PyTorch output {i} has nan') + logger.debug(f"PyTorch output {i} has nan") if numpy.isinf(torch_output).any(): - logger.debug(f'PyTorch output {i} has inf') + logger.debug(f"PyTorch output {i} has inf") if numpy.isnan(ort_output).any(): - logger.debug(f'ORT output {i} has nan') + logger.debug(f"ORT output {i} has nan") if numpy.isinf(ort_output).any(): - logger.debug(f'ORT output {i} has inf') + logger.debug(f"ORT output {i} has inf") diff = numpy.fabs(ort_output - torch_output) idx = numpy.unravel_index(diff.argmax(), diff.shape) messages.append( - f'diff={diff[idx]:.9f} index={idx} ort={ort_output[idx]:.9f} torch={float(torch_output[idx]):.9f}') + f"diff={diff[idx]:.9f} index={idx} ort={ort_output[idx]:.9f} torch={float(torch_output[idx]):.9f}" + ) if i == 0: # logits ort_max_index = numpy.unravel_index(numpy.argmax(ort_output, axis=None), ort_output.shape) @@ -334,44 +369,53 @@ def compare_outputs_v2(torch_outputs, ort_outputs, atol=1e-06): is_top1_matched = numpy.array_equal(ort_max_index, torch_max_index) max_diff_output_index = max_diffs.index(max(max_diffs)) - return is_all_close, max(max_diffs), max_diff_output_index, messages, is_top1_matched + return ( + is_all_close, + max(max_diffs), + max_diff_output_index, + messages, + is_top1_matched, + ) @staticmethod - def export_onnx(model, - device, - onnx_model_path: str, - verbose: bool = False, - use_external_data_format: bool = False, - has_position_ids: bool = True, - has_attention_mask: bool = True, - input_ids_dtype: torch.dtype = torch.int32, - position_ids_dtype: torch.dtype = torch.int32, - attention_mask_dtype: torch.dtype = torch.int32): - """ Export GPT-2 model with past state to ONNX model. - """ + def export_onnx( + model, + device, + onnx_model_path: str, + verbose: bool = False, + use_external_data_format: bool = False, + has_position_ids: bool = True, + has_attention_mask: bool = True, + input_ids_dtype: torch.dtype = torch.int32, + position_ids_dtype: torch.dtype = torch.int32, + attention_mask_dtype: torch.dtype = torch.int32, + ): + """Export GPT-2 model with past state to ONNX model.""" config: GPT2Config = model.config num_layer = config.n_layer - dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size=1, - past_sequence_length=1, - sequence_length=1, - num_attention_heads=config.num_attention_heads, - hidden_size=config.hidden_size, - num_layer=num_layer, - vocab_size=config.vocab_size, - device=device, - float16=False, - has_position_ids=has_position_ids, - has_attention_mask=has_attention_mask, - input_ids_dtype=input_ids_dtype, - position_ids_dtype=position_ids_dtype, - attention_mask_dtype=attention_mask_dtype) + dummy_inputs = Gpt2Helper.get_dummy_inputs( + batch_size=1, + past_sequence_length=1, + sequence_length=1, + num_attention_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + num_layer=num_layer, + vocab_size=config.vocab_size, + device=device, + float16=False, + has_position_ids=has_position_ids, + has_attention_mask=has_attention_mask, + input_ids_dtype=input_ids_dtype, + position_ids_dtype=position_ids_dtype, + attention_mask_dtype=attention_mask_dtype, + ) input_list = dummy_inputs.to_list() with torch.no_grad(): outputs = model(*input_list) - past_names = [f'past_{i}' for i in range(num_layer)] - present_names = [f'present_{i}' for i in range(num_layer)] + past_names = [f"past_{i}" for i in range(num_layer)] + present_names = [f"present_{i}" for i in range(num_layer)] # GPT2Model outputs last_state; GPT2LMHeadModel outputs logits (prediction_scores) assert outputs[0].shape[2] == config.vocab_size or outputs[0].shape[2] == config.hidden_size @@ -385,19 +429,22 @@ def export_onnx(model, # last_state: (batch_size, seq_len, hidden_size) # or logits: (batch_size, seq_len, vocab_size) # present_{i}: (2, batch_size, num_heads, past_seq_len + seq_len, hidden_size/num_heads) - dynamic_axes = {'input_ids': {0: 'batch_size', 1: 'seq_len'}, output_names[0]: {0: 'batch_size', 1: 'seq_len'}} + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + output_names[0]: {0: "batch_size", 1: "seq_len"}, + } for name in past_names: - dynamic_axes[name] = {1: 'batch_size', 3: 'past_seq_len'} + dynamic_axes[name] = {1: "batch_size", 3: "past_seq_len"} for name in present_names: - dynamic_axes[name] = {1: 'batch_size', 3: 'total_seq_len'} + dynamic_axes[name] = {1: "batch_size", 3: "total_seq_len"} - input_names = ['input_ids'] + input_names = ["input_ids"] if has_position_ids: - dynamic_axes['position_ids'] = {0: 'batch_size', 1: 'seq_len'} - input_names.append('position_ids') + dynamic_axes["position_ids"] = {0: "batch_size", 1: "seq_len"} + input_names.append("position_ids") if has_attention_mask: - dynamic_axes['attention_mask'] = {0: 'batch_size', 1: 'total_seq_len'} - input_names.append('attention_mask') + dynamic_axes["attention_mask"] = {0: "batch_size", 1: "total_seq_len"} + input_names.append("attention_mask") input_names.extend(past_names) assert len(outputs) == 2 and len(outputs[1]) == num_layer @@ -408,42 +455,47 @@ def export_onnx(model, Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export(model, - args=tuple(input_list), - f=onnx_model_path, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=11, - do_constant_folding=True, - use_external_data_format=use_external_data_format, - verbose=verbose) + torch_onnx_export( + model, + args=tuple(input_list), + f=onnx_model_path, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=11, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + verbose=verbose, + ) @staticmethod - def optimize_onnx(onnx_model_path, - optimized_model_path, - is_float16, - num_attention_heads, - hidden_size, - use_external_data_format=False, - auto_mixed_precision=False, - **kwargs): - """ Optimize ONNX model with an option to convert it to use mixed precision. - """ + def optimize_onnx( + onnx_model_path, + optimized_model_path, + is_float16, + num_attention_heads, + hidden_size, + use_external_data_format=False, + auto_mixed_precision=False, + **kwargs, + ): + """Optimize ONNX model with an option to convert it to use mixed precision.""" + from fusion_options import FusionOptions from optimizer import optimize_model - from fusion_options import FusionOptions - optimization_options = FusionOptions('gpt2') - #optimization_options.enable_gelu = False - #optimization_options.enable_layer_norm = False - #optimization_options.enable_attention = False - m = optimize_model(onnx_model_path, - model_type='gpt2', - num_heads=num_attention_heads, - hidden_size=hidden_size, - opt_level=0, - optimization_options=optimization_options, - use_gpu=False) + optimization_options = FusionOptions("gpt2") + # optimization_options.enable_gelu = False + # optimization_options.enable_layer_norm = False + # optimization_options.enable_attention = False + m = optimize_model( + onnx_model_path, + model_type="gpt2", + num_heads=num_attention_heads, + hidden_size=hidden_size, + opt_level=0, + optimization_options=optimization_options, + use_gpu=False, + ) if is_float16: if auto_mixed_precision: @@ -456,8 +508,10 @@ def optimize_onnx(onnx_model_path, m.save_model_to_file(optimized_model_path, use_external_data_format) @staticmethod - def auto_mixed_precision(onnx_model: OnnxModel, - op_block_list: List[str] = ['Add', 'LayerNormalization', 'FastGelu']): + def auto_mixed_precision( + onnx_model: OnnxModel, + op_block_list: List[str] = ["Add", "LayerNormalization", "FastGelu"], + ): """Convert GPT-2 model to mixed precision. It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically. Args: @@ -493,7 +547,7 @@ def auto_mixed_precision(onnx_model: OnnxModel, # we can deduce that the weights are stored in float16 precision. max_diff = float_to_float16_max_diff(initializer) logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}") - is_weight_fp16_precision = (max_diff < 1E-6) + is_weight_fp16_precision = max_diff < 1e-6 else: logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}") @@ -508,7 +562,7 @@ def auto_mixed_precision(onnx_model: OnnxModel, "keep_io_types": keep_io_types, "op_block_list": op_block_list, "node_block_list": node_block_list, - "force_fp16_initializers": is_weight_fp16_precision + "force_fp16_initializers": is_weight_fp16_precision, } logger.info(f"auto_mixed_precision parameters: {parameters}") @@ -522,8 +576,7 @@ def auto_mixed_precision(onnx_model: OnnxModel, @staticmethod def pytorch_inference(model, inputs: Gpt2Inputs, total_runs: int = 0): - """ Run inference of PyTorch model, and returns average latency in ms when total_runs > 0 besides outputs. - """ + """Run inference of PyTorch model, and returns average latency in ms when total_runs > 0 besides outputs.""" logger.debug("start pytorch_inference") # Convert it to fp32 as the PyTroch model cannot deal with half input. @@ -543,27 +596,26 @@ def pytorch_inference(model, inputs: Gpt2Inputs, total_runs: int = 0): latency.append(time.time() - start) average_latency = sum(latency) * 1000 / len(latency) - logger.debug("PyTorch inference time = {} ms".format(format(average_latency, '.2f'))) + logger.debug("PyTorch inference time = {} ms".format(format(average_latency, ".2f"))) return outputs, average_latency @staticmethod def onnxruntime_inference(ort_session, inputs: Gpt2Inputs, total_runs: int = 0): - """ Run inference of ONNX model, and returns average latency in ms when total_runs > 0 besides outputs. - """ + """Run inference of ONNX model, and returns average latency in ms when total_runs > 0 besides outputs.""" logger.debug(f"start onnxruntime_inference") - ort_inputs = {'input_ids': numpy.ascontiguousarray(inputs.input_ids.cpu().numpy())} + ort_inputs = {"input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy())} if inputs.past is not None: for i, past_i in enumerate(inputs.past): - ort_inputs[f'past_{i}'] = numpy.ascontiguousarray(past_i.cpu().numpy()) + ort_inputs[f"past_{i}"] = numpy.ascontiguousarray(past_i.cpu().numpy()) if inputs.attention_mask is not None: - ort_inputs['attention_mask'] = numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()) + ort_inputs["attention_mask"] = numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()) if inputs.position_ids is not None: - ort_inputs['position_ids'] = numpy.ascontiguousarray(inputs.position_ids.cpu().numpy()) + ort_inputs["position_ids"] = numpy.ascontiguousarray(inputs.position_ids.cpu().numpy()) ort_outputs = ort_session.run(None, ort_inputs) if total_runs == 0: @@ -576,46 +628,69 @@ def onnxruntime_inference(ort_session, inputs: Gpt2Inputs, total_runs: int = 0): latency.append(time.time() - start) average_latency = sum(latency) * 1000 / len(latency) - logger.debug("OnnxRuntime Inference time = {} ms".format(format(average_latency, '.2f'))) + logger.debug("OnnxRuntime Inference time = {} ms".format(format(average_latency, ".2f"))) return ort_outputs, average_latency @staticmethod - def prepare_io_binding(ort_session, input_ids, position_ids, attention_mask, past, output_buffers, output_shapes): - """ Returnas IO binding object for a session. - """ - return IOBindingHelper.prepare_io_binding(ort_session, input_ids, position_ids, attention_mask, past, - output_buffers, output_shapes) + def prepare_io_binding( + ort_session, + input_ids, + position_ids, + attention_mask, + past, + output_buffers, + output_shapes, + ): + """Returnas IO binding object for a session.""" + return IOBindingHelper.prepare_io_binding( + ort_session, + input_ids, + position_ids, + attention_mask, + past, + output_buffers, + output_shapes, + ) @staticmethod def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True): - """ Copy results to cpu. Returns a list of numpy array. - """ - return IOBindingHelper.get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, - return_numpy) + """Copy results to cpu. Returns a list of numpy array.""" + return IOBindingHelper.get_outputs_from_io_binding_buffer( + ort_session, output_buffers, output_shapes, return_numpy + ) @staticmethod - def onnxruntime_inference_with_binded_io(ort_session, - inputs: Gpt2Inputs, - output_buffers: Dict[str, torch.Tensor], - output_shapes: Dict[str, List[int]], - total_runs: int = 0, - return_numpy: bool = True, - include_copy_output_latency: bool = False): - """ Inference with IO binding. Returns outputs, and optional latency when total_runs > 0. - """ + def onnxruntime_inference_with_binded_io( + ort_session, + inputs: Gpt2Inputs, + output_buffers: Dict[str, torch.Tensor], + output_shapes: Dict[str, List[int]], + total_runs: int = 0, + return_numpy: bool = True, + include_copy_output_latency: bool = False, + ): + """Inference with IO binding. Returns outputs, and optional latency when total_runs > 0.""" logger.debug(f"start onnxruntime_inference_with_binded_io") # Bind inputs and outputs to onnxruntime session - io_binding = Gpt2Helper.prepare_io_binding(ort_session, inputs.input_ids, inputs.position_ids, - inputs.attention_mask, inputs.past, output_buffers, output_shapes) + io_binding = Gpt2Helper.prepare_io_binding( + ort_session, + inputs.input_ids, + inputs.position_ids, + inputs.attention_mask, + inputs.past, + output_buffers, + output_shapes, + ) # Run onnxruntime with io binding ort_session.run_with_iobinding(io_binding) # Copy results to cpu for verification - ort_outputs = Gpt2Helper.get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, - return_numpy) + ort_outputs = Gpt2Helper.get_outputs_from_io_binding_buffer( + ort_session, output_buffers, output_shapes, return_numpy + ) if total_runs == 0: return ort_outputs @@ -626,51 +701,53 @@ def onnxruntime_inference_with_binded_io(ort_session, # Run onnxruntime with io binding ort_session.run_with_iobinding(io_binding) if include_copy_output_latency: - _ = Gpt2Helper.get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, - return_numpy) + _ = Gpt2Helper.get_outputs_from_io_binding_buffer( + ort_session, output_buffers, output_shapes, return_numpy + ) latency.append(time.time() - start) average_latency = sum(latency) * 1000 / len(latency) - logger.debug("OnnxRuntime with IO binding inference time = {} ms".format(format(average_latency, '.2f'))) + logger.debug("OnnxRuntime with IO binding inference time = {} ms".format(format(average_latency, ".2f"))) return ort_outputs, average_latency @staticmethod def save_outputs(i, ort_outputs, torch_outputs): - with open(f'ort_outputs_{i}.pickle', 'wb') as f: + with open(f"ort_outputs_{i}.pickle", "wb") as f: pickle.dump(ort_outputs, f) logger.info(f"ORT output are saved to ort_outputs_{i}.pickle") - with open(f'torch_outputs_{i}.pickle', 'wb') as f: + with open(f"torch_outputs_{i}.pickle", "wb") as f: pickle.dump(torch_outputs, f) logger.info(f"Torch output are saved to torch_outputs_{i}.pickle") @staticmethod def save_inputs(i, dummy_inputs, ort_outputs, torch_outputs): - with open(f'dummy_inputs_{i}.pickle', 'wb') as f: + with open(f"dummy_inputs_{i}.pickle", "wb") as f: pickle.dump(dummy_inputs, f) logger.info(f"inputs are saved to dummy_inputs_{i}.pickle") @staticmethod - def test_parity(ort_session, - model, - device, - is_float16=False, - rtol=5e-4, - atol=5e-4, - test_cases_per_run=10000, - total_runs=1, - use_io_binding=True, - model_class="GPT2LMHeadModel", - has_position_ids=True, - has_attention_mask=True, - input_ids_dtype=torch.int32, - position_ids_dtype=torch.int32, - attention_mask_dtype=torch.int32, - verbose=False, - enable_pickle_output=False): - """ Generate random inputs and compare the results of PyTorch and Onnx Runtime. - """ + def test_parity( + ort_session, + model, + device, + is_float16=False, + rtol=5e-4, + atol=5e-4, + test_cases_per_run=10000, + total_runs=1, + use_io_binding=True, + model_class="GPT2LMHeadModel", + has_position_ids=True, + has_attention_mask=True, + input_ids_dtype=torch.int32, + position_ids_dtype=torch.int32, + attention_mask_dtype=torch.int32, + verbose=False, + enable_pickle_output=False, + ): + """Generate random inputs and compare the results of PyTorch and Onnx Runtime.""" config: GPT2Config = model.config @@ -684,8 +761,9 @@ def test_parity(ort_session, output_buffers = None if use_io_binding: - max_output_shapes = Gpt2Helper.get_output_shapes(max_batch_size, max_past_seq_len, max_seq_len, config, - model_class) + max_output_shapes = Gpt2Helper.get_output_shapes( + max_batch_size, max_past_seq_len, max_seq_len, config, model_class + ) output_buffers = Gpt2Helper.get_output_buffers(max_output_shapes, device, is_float16) passed_test_cases = 0 @@ -701,32 +779,46 @@ def test_parity(ort_session, batch_size = random.randint(1, max_batch_size) logger.debug( - f"Running parity test for batch_size={batch_size} past_sequence_length={past_sequence_length}...") - dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size, - past_sequence_length, - sequence_length, - config.num_attention_heads, - config.hidden_size, - config.n_layer, - config.vocab_size, - device, - is_float16, - has_position_ids, - has_attention_mask, - input_ids_dtype=input_ids_dtype, - position_ids_dtype=position_ids_dtype, - attention_mask_dtype=attention_mask_dtype) + f"Running parity test for batch_size={batch_size} past_sequence_length={past_sequence_length}..." + ) + dummy_inputs = Gpt2Helper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + config.num_attention_heads, + config.hidden_size, + config.n_layer, + config.vocab_size, + device, + is_float16, + has_position_ids, + has_attention_mask, + input_ids_dtype=input_ids_dtype, + position_ids_dtype=position_ids_dtype, + attention_mask_dtype=attention_mask_dtype, + ) outputs = Gpt2Helper.pytorch_inference(model, dummy_inputs) if use_io_binding: ort_outputs = Gpt2Helper.onnxruntime_inference(ort_session, dummy_inputs) else: - output_shapes = Gpt2Helper.get_output_shapes(batch_size, past_sequence_length, sequence_length, config, - model_class) - ort_outputs = Gpt2Helper.onnxruntime_inference_with_binded_io(ort_session, dummy_inputs, output_buffers, - output_shapes) + output_shapes = Gpt2Helper.get_output_shapes( + batch_size, + past_sequence_length, + sequence_length, + config, + model_class, + ) + ort_outputs = Gpt2Helper.onnxruntime_inference_with_binded_io( + ort_session, dummy_inputs, output_buffers, output_shapes + ) - is_all_close, max_abs_diff, max_diff_output_index, messages, is_top1_matched = Gpt2Helper.compare_outputs_v2( - outputs, ort_outputs, atol=atol) + ( + is_all_close, + max_abs_diff, + max_diff_output_index, + messages, + is_top1_matched, + ) = Gpt2Helper.compare_outputs_v2(outputs, ort_outputs, atol=atol) if not numpy.isnan(max_abs_diff): max_abs_diff_list.append(max_abs_diff) if is_all_close: @@ -770,88 +862,95 @@ def test_parity(ort_session, return result @staticmethod - def test_performance(ort_session, - model, - device, - is_float16=False, - total_runs=100, - use_io_binding=True, - model_class="GPT2LMHeadModel", - has_position_ids=True, - has_attention_mask=True, - input_ids_dtype=torch.int32, - position_ids_dtype=torch.int32, - attention_mask_dtype=torch.int32, - batch_size=8, - sequence_length=1, - past_sequence_length=32): - """ Generate random inputs and measure average latency of Onnx Runtime. - """ + def test_performance( + ort_session, + model, + device, + is_float16=False, + total_runs=100, + use_io_binding=True, + model_class="GPT2LMHeadModel", + has_position_ids=True, + has_attention_mask=True, + input_ids_dtype=torch.int32, + position_ids_dtype=torch.int32, + attention_mask_dtype=torch.int32, + batch_size=8, + sequence_length=1, + past_sequence_length=32, + ): + """Generate random inputs and measure average latency of Onnx Runtime.""" config: GPT2Config = model.config output_buffers = None if use_io_binding: - output_shapes = Gpt2Helper.get_output_shapes(batch_size, past_sequence_length, sequence_length, config, - model_class) + output_shapes = Gpt2Helper.get_output_shapes( + batch_size, past_sequence_length, sequence_length, config, model_class + ) output_buffers = Gpt2Helper.get_output_buffers(output_shapes, device, is_float16) - dummy_inputs = Gpt2Helper.get_dummy_inputs(batch_size, - past_sequence_length, - sequence_length, - config.num_attention_heads, - config.hidden_size, - config.n_layer, - config.vocab_size, - device, - is_float16, - has_position_ids, - has_attention_mask, - input_ids_dtype=input_ids_dtype, - position_ids_dtype=position_ids_dtype, - attention_mask_dtype=attention_mask_dtype) + dummy_inputs = Gpt2Helper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + config.num_attention_heads, + config.hidden_size, + config.n_layer, + config.vocab_size, + device, + is_float16, + has_position_ids, + has_attention_mask, + input_ids_dtype=input_ids_dtype, + position_ids_dtype=position_ids_dtype, + attention_mask_dtype=attention_mask_dtype, + ) if use_io_binding: _, latency = Gpt2Helper.onnxruntime_inference(ort_session, dummy_inputs, total_runs) else: - _, latency = Gpt2Helper.onnxruntime_inference_with_binded_io(ort_session, dummy_inputs, output_buffers, - output_shapes, total_runs) + _, latency = Gpt2Helper.onnxruntime_inference_with_binded_io( + ort_session, dummy_inputs, output_buffers, output_shapes, total_runs + ) return latency @staticmethod def torchscript(model, config, device, has_position_ids=True, has_attention_mask=True): - """ JIT trace for TorchScript. - """ - input_list = Gpt2Helper.get_dummy_inputs(batch_size=1, - past_sequence_length=1, - sequence_length=1, - num_attention_heads=config.num_attention_heads, - hidden_size=config.hidden_size, - num_layer=config.n_layer, - vocab_size=config.vocab_size, - device=device, - float16=False, - has_position_ids=has_position_ids, - has_attention_mask=has_attention_mask).to_list() + """JIT trace for TorchScript.""" + input_list = Gpt2Helper.get_dummy_inputs( + batch_size=1, + past_sequence_length=1, + sequence_length=1, + num_attention_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + num_layer=config.n_layer, + vocab_size=config.vocab_size, + device=device, + float16=False, + has_position_ids=has_position_ids, + has_attention_mask=has_attention_mask, + ).to_list() return torch.jit.trace(model, input_list) @staticmethod - def get_onnx_paths(output_dir, - model_name_or_path, - model_class: str = 'GPT2LMHeadModel', - has_past=True, - new_folder=False, - remove_existing=["raw", "fp32", "fp16", "int8"]): - """ Build a path name for given model based on given attributes. - """ + def get_onnx_paths( + output_dir, + model_name_or_path, + model_class: str = "GPT2LMHeadModel", + has_past=True, + new_folder=False, + remove_existing=["raw", "fp32", "fp16", "int8"], + ): + """Build a path name for given model based on given attributes.""" model_name = model_name_or_path if os.path.isdir(model_name_or_path): model_name = Path(model_name_or_path).parts[-1] else: - model_name.split('/')[-1] + model_name.split("/")[-1] - if model_class != 'GPT2LMHeadModel': + if model_class != "GPT2LMHeadModel": model_name += "_" + model_class if has_past: @@ -863,7 +962,7 @@ def get_onnx_paths(output_dir, for model_type in ["raw", "fp32", "fp16", "int8"]: new_dir = os.path.join(output_dir, model_name + suffix[model_type]) if os.path.exists(new_dir): - if (model_type in remove_existing): + if model_type in remove_existing: try: shutil.rmtree(new_dir) logger.info(f"Removed the existed directory: {new_dir}") @@ -875,14 +974,23 @@ def get_onnx_paths(output_dir, # store each model to its own directory (for external data format). return { "raw": os.path.join(os.path.join(output_dir, model_name), model_name + ".onnx"), - "fp32": os.path.join(os.path.join(output_dir, model_name + "_fp32"), model_name + "_fp32.onnx"), - "fp16": os.path.join(os.path.join(output_dir, model_name + "_fp16"), model_name + "_fp16.onnx"), - "int8": os.path.join(os.path.join(output_dir, model_name + "_int8"), model_name + "_int8.onnx") + "fp32": os.path.join( + os.path.join(output_dir, model_name + "_fp32"), + model_name + "_fp32.onnx", + ), + "fp16": os.path.join( + os.path.join(output_dir, model_name + "_fp16"), + model_name + "_fp16.onnx", + ), + "int8": os.path.join( + os.path.join(output_dir, model_name + "_int8"), + model_name + "_int8.onnx", + ), } return { "raw": os.path.join(output_dir, model_name + ".onnx"), "fp32": os.path.join(output_dir, model_name + "_fp32.onnx"), "fp16": os.path.join(output_dir, model_name + "_fp16.onnx"), - "int8": os.path.join(output_dir, model_name + "_int8.onnx") + "int8": os.path.join(output_dir, model_name + "_int8.onnx"), } diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py index 7147e36d6edc0..4c52aaaf2e84a 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_parity.py @@ -11,62 +11,78 @@ # User could use this script to select the best mixed precision model according to these metrics. import argparse -import logging -from onnx_model import OnnxModel -import onnx import csv import datetime +import logging +import os +import sys + +import onnx import scipy.stats import torch - +from convert_to_onnx import get_latency_name, main from gpt2_helper import PRETRAINED_GPT2_MODELS, Gpt2Helper -from convert_to_onnx import main, get_latency_name - -import sys -import os +from onnx_model import OnnxModel -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from benchmark_helper import setup_logger -logger = logging.getLogger('') +logger = logging.getLogger("") def parse_arguments(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('-m', - '--model_name_or_path', - required=True, - type=str, - help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS)) - - parser.add_argument('--csv', - required=False, - type=str, - default='gpt2_parity_results.csv', - help='path of csv file to save the result') - - parser.add_argument('--test_cases', required=False, type=int, default=500, help="number of test cases per run") - - parser.add_argument('--runs', required=False, type=int, default=40, help="number of repeated runs") - - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference") + parser.add_argument( + "-m", + "--model_name_or_path", + required=True, + type=str, + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_GPT2_MODELS), + ) + + parser.add_argument( + "--csv", + required=False, + type=str, + default="gpt2_parity_results.csv", + help="path of csv file to save the result", + ) + + parser.add_argument( + "--test_cases", + required=False, + type=int, + default=500, + help="number of test cases per run", + ) + + parser.add_argument("--runs", required=False, type=int, default=40, help="number of repeated runs") + + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") parser.set_defaults(use_gpu=False) - parser.add_argument('--all', required=False, action='store_true', help="run all combinations of mixed precision") + parser.add_argument( + "--all", + required=False, + action="store_true", + help="run all combinations of mixed precision", + ) parser.set_defaults(all=False) - parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') + parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") parser.set_defaults(use_external_data_format=False) - parser.add_argument('--verbose', required=False, action='store_true') + parser.add_argument("--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) - parser.add_argument('--skip_test', - required=False, - action='store_true', - help="do not run test, and only rank experiments based on existing csv file") + parser.add_argument( + "--skip_test", + required=False, + action="store_true", + help="do not run test, and only rank experiments based on existing csv file", + ) parser.set_defaults(skip_test=False) args = parser.parse_args(argv) @@ -75,7 +91,6 @@ def parse_arguments(argv=None): class ParityTask: - def __init__(self, test_cases, total_runs, csv_path): self.total_runs = total_runs self.test_cases = test_cases @@ -84,15 +99,17 @@ def __init__(self, test_cases, total_runs, csv_path): self.run_id = 0 def run(self, argv, experiment_name): - start_time = datetime.datetime.now().strftime('%Y%m%d%H%M%S') + start_time = datetime.datetime.now().strftime("%Y%m%d%H%M%S") run_id = f"{start_time}_{self.run_id}" self.run_id += 1 try: - result = main(argv + ["-t", f"{self.test_cases}", "-r", f"{self.total_runs}"], - experiment_name=experiment_name, - run_id=run_id, - csv_filename=self.csv_path) + result = main( + argv + ["-t", f"{self.test_cases}", "-r", f"{self.total_runs}"], + experiment_name=experiment_name, + run_id=run_id, + csv_filename=self.csv_path, + ) except: logger.exception(f"Failed to run experiment {experiment_name}") @@ -103,7 +120,8 @@ def run(self, argv, experiment_name): def load_results_from_csv(csv_path): rows = [] import csv - with open(csv_path, newline='') as csvfile: + + with open(csv_path, newline="") as csvfile: reader = csv.DictReader(csvfile) for row in reader: rows.append(row) @@ -116,7 +134,7 @@ def score(row): top1_match_rate = float(row["top1_match_rate"]) onnx_size_in_MB = float(row["onnx_size_in_MB"]) # A simple scoring function: cost of 0.1ms latency ~ 0.1% match rate ~ 100MB size - return (top1_match_rate * 1000 - latency_in_ms * 10 - onnx_size_in_MB / 100) + return top1_match_rate * 1000 - latency_in_ms * 10 - onnx_size_in_MB / 100 def print_wins(wins, rows, test_name): @@ -127,7 +145,13 @@ def print_wins(wins, rows, test_name): for row in rows: row_map[row["run_id"]] = row - sorted_wins = dict(sorted(wins.items(), key=lambda item: (item[1], score(row_map[item[0]])), reverse=True)) + sorted_wins = dict( + sorted( + wins.items(), + key=lambda item: (item[1], score(row_map[item[0]])), + reverse=True, + ) + ) logger.debug(f"{test_name} Wins:{sorted_wins}") logger.info(f"Based on {test_name} wins and a scoring function, the ranking:") @@ -143,17 +167,24 @@ def print_wins(wins, rows, test_name): for row in rows: if row["run_id"] == key: logger.info( - "{:02d}: WINs={:02d}, run_id={}, latency={:5.2f} top1_match={:.4f} size={}_MB experiment={} {}". - format( - rank, value, key, float(row[get_latency_name()]), float(row["top1_match_rate"]), - row["onnx_size_in_MB"], row["experiment"], " (Half2 Disabled)" if - (row['ORT_CUDA_GEMM_OPTIONS'] == "4" and "Half2" not in row["experiment"]) else "")) + "{:02d}: WINs={:02d}, run_id={}, latency={:5.2f} top1_match={:.4f} size={}_MB experiment={} {}".format( + rank, + value, + key, + float(row[get_latency_name()]), + float(row["top1_match_rate"]), + row["onnx_size_in_MB"], + row["experiment"], + " (Half2 Disabled)" + if (row["ORT_CUDA_GEMM_OPTIONS"] == "4" and "Half2" not in row["experiment"]) + else "", + ) + ) break def run_significance_test(rows, output_csv_path): - """Run U test and T test. - """ + """Run U test and T test.""" utest_wins = {} ttest_wins = {} for row in rows: @@ -161,10 +192,19 @@ def run_significance_test(rows, output_csv_path): utest_wins[run_id] = 0 ttest_wins[run_id] = 0 - with open(output_csv_path, 'w', newline='') as csvfile: + with open(output_csv_path, "w", newline="") as csvfile: column_names = [ - 'model_name', 'run_id_1', 'experiment_1', 'top1_match_rate_1', 'run_id_2', 'experiment_2', - 'top1_match_rate_2', 'U_statistic', 'U_pvalue', "T_statistic", "T_pvalue" + "model_name", + "run_id_1", + "experiment_1", + "top1_match_rate_1", + "run_id_2", + "experiment_2", + "top1_match_rate_2", + "U_statistic", + "U_pvalue", + "T_statistic", + "T_pvalue", ] writer = csv.DictWriter(csvfile, fieldnames=column_names) @@ -180,7 +220,7 @@ def run_significance_test(rows, output_csv_path): all_matched = True for column in required_match_columns: - if (result1[column] != result2[column]): + if result1[column] != result2[column]: all_matched = False break if not all_matched: @@ -188,6 +228,7 @@ def run_significance_test(rows, output_csv_path): if isinstance(result1["top1_match_rate_per_run"], str): import json + a = json.loads(result1["top1_match_rate_per_run"]) b = json.loads(result2["top1_match_rate_per_run"]) else: @@ -197,8 +238,8 @@ def run_significance_test(rows, output_csv_path): try: utest_statistic, utest_pvalue = scipy.stats.mannwhitneyu( a, b, use_continuity=True, alternative="two-sided" - ) #TODO: shall we use one-sided: less or greater according to "top1_match_rate" - except ValueError: #ValueError: All numbers are identical in mannwhitneyu + ) # TODO: shall we use one-sided: less or greater according to "top1_match_rate" + except ValueError: # ValueError: All numbers are identical in mannwhitneyu utest_statistic = None utest_pvalue = None ttest_statistic, ttest_pvalue = scipy.stats.ttest_ind(a, b, axis=None, equal_var=True) @@ -216,17 +257,17 @@ def run_significance_test(rows, output_csv_path): ttest_wins[result2["run_id"]] += 1 row = { - 'model_name': result1["model_name"], - 'run_id_1': result1["run_id"], - 'experiment_1': result1["experiment"], - 'top1_match_rate_1': float(result1["top1_match_rate"]), + "model_name": result1["model_name"], + "run_id_1": result1["run_id"], + "experiment_1": result1["experiment"], + "top1_match_rate_1": float(result1["top1_match_rate"]), "run_id_2": result2["run_id"], "experiment_2": result2["experiment"], - 'top1_match_rate_2': float(result2["top1_match_rate"]), - 'U_statistic': utest_statistic, - 'U_pvalue': utest_pvalue, - 'T_statistic': ttest_statistic, - 'T_pvalue': ttest_pvalue + "top1_match_rate_2": float(result2["top1_match_rate"]), + "U_statistic": utest_statistic, + "U_pvalue": utest_pvalue, + "T_statistic": ttest_statistic, + "T_pvalue": ttest_pvalue, } writer.writerow(row) @@ -255,7 +296,12 @@ def get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list): parameters = f"-m {model} -o --use_gpu -p fp16".split() if args.use_external_data_format: parameters.append("--use_external_data_format") - parameters += ["--io_block_list", "logits", "--node_block_list", last_matmul_node_name] + parameters += [ + "--io_block_list", + "logits", + "--node_block_list", + last_matmul_node_name, + ] if op_block_list: parameters.extend(["--op_block_list"] + op_block_list) @@ -263,10 +309,15 @@ def get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list): return parameters -def run_candidate(task: ParityTask, args, last_matmul_node_name, op_block_list=["FastGelu", "LayerNormalization"]): +def run_candidate( + task: ParityTask, + args, + last_matmul_node_name, + op_block_list=["FastGelu", "LayerNormalization"], +): parameters = get_mixed_precision_parameters(args, last_matmul_node_name, op_block_list) - op_block_list_str = ','.join(sorted(op_block_list)) - name_suffix = " (Half2 Disabled)" if os.getenv('ORT_CUDA_GEMM_OPTIONS') == "4" else "" + op_block_list_str = ",".join(sorted(op_block_list)) + name_suffix = " (Half2 Disabled)" if os.getenv("ORT_CUDA_GEMM_OPTIONS") == "4" else "" if op_block_list: name = f"Mixed precision baseline + {op_block_list_str} in FP32{name_suffix}" else: @@ -303,11 +354,13 @@ def run_tuning_step0(task, fp16_baseline): task.run(fp16_baseline + fp32_io, "Graph I/O FP32, Other FP16") op_list = get_all_operators() - #task.run(fp16_baseline + fp32_io + ["--op_block_list"] + [o for o in op_list], "Everthing in FP32") + # task.run(fp16_baseline + fp32_io + ["--op_block_list"] + [o for o in op_list], "Everthing in FP32") # Only weights in FP16 - task.run(fp16_baseline + fp32_io + ["--op_block_list"] + [o for o in op_list] + ['--force_fp16_initializers'], - "FP32 except weights in FP16") + task.run( + fp16_baseline + fp32_io + ["--op_block_list"] + [o for o in op_list] + ["--force_fp16_initializers"], + "FP32 except weights in FP16", + ) for op in op_list: op_block_list = ["--op_block_list"] + [o for o in op_list if o != op] @@ -318,7 +371,10 @@ def run_tuning_step1(task, mixed_precision_baseline): """Step 1 is to figure out which operator in FP32 could benefit most""" for op in get_all_operators(): op_block_list = ["--op_block_list", op] - task.run(mixed_precision_baseline + op_block_list, f"Mixed precision baseline + {op} in FP32") + task.run( + mixed_precision_baseline + op_block_list, + f"Mixed precision baseline + {op} in FP32", + ) def run_tuning_step2(task, mixed_precision_baseline): @@ -326,16 +382,21 @@ def run_tuning_step2(task, mixed_precision_baseline): Step 2 is to figure out a combination of two operators (one is Add from step one) to get better result """ for op in get_all_operators(): - if op not in ['Add']: - op_block_list = ["--op_block_list", 'Add', op] - task.run(mixed_precision_baseline + op_block_list, f"Mixed precision baseline + Add,{op} in FP32") + if op not in ["Add"]: + op_block_list = ["--op_block_list", "Add", op] + task.run( + mixed_precision_baseline + op_block_list, + f"Mixed precision baseline + Add,{op} in FP32", + ) def run_parity_disable_half2(task: ParityTask, args): - onnx_model_paths = Gpt2Helper.get_onnx_paths('onnx_models', - args.model_name_or_path, - new_folder=args.use_external_data_format, - remove_existing=[]) + onnx_model_paths = Gpt2Helper.get_onnx_paths( + "onnx_models", + args.model_name_or_path, + new_folder=args.use_external_data_format, + remove_existing=[], + ) last_matmul_node_name = get_last_matmul_node_name(onnx_model_paths["raw"]) run_candidate(task, args, last_matmul_node_name, op_block_list=[]) run_candidate(task, args, last_matmul_node_name, op_block_list=["Add"]) @@ -343,10 +404,12 @@ def run_parity_disable_half2(task: ParityTask, args): def run_parity(task: ParityTask, args): - onnx_model_paths = Gpt2Helper.get_onnx_paths('onnx_models', - args.model_name_or_path, - new_folder=args.use_external_data_format, - remove_existing=[]) + onnx_model_paths = Gpt2Helper.get_onnx_paths( + "onnx_models", + args.model_name_or_path, + new_folder=args.use_external_data_format, + remove_existing=[], + ) fp32_baseline, fp16_baseline = get_baselines(args) @@ -373,17 +436,36 @@ def run_parity(task: ParityTask, args): run_tuning_step1(task, mixed_precision_baseline) run_tuning_step2(task, mixed_precision_baseline) else: - run_candidate(task, args, last_matmul_node_name, op_block_list=["LayerNormalization", "Add"]) + run_candidate( + task, + args, + last_matmul_node_name, + op_block_list=["LayerNormalization", "Add"], + ) run_candidate(task, args, last_matmul_node_name, op_block_list=["FastGelu", "Add"]) # Run a few good candidates - run_candidate(task, args, last_matmul_node_name, op_block_list=["FastGelu", "LayerNormalization", "Add"]) - run_candidate(task, args, last_matmul_node_name, op_block_list=["FastGelu", "LayerNormalization", "Add", "Gather"]) - run_candidate(task, args, last_matmul_node_name, \ - op_block_list=["FastGelu", "LayerNormalization", "Add", "Gather", "MatMul"]) - - -if __name__ == '__main__': + run_candidate( + task, + args, + last_matmul_node_name, + op_block_list=["FastGelu", "LayerNormalization", "Add"], + ) + run_candidate( + task, + args, + last_matmul_node_name, + op_block_list=["FastGelu", "LayerNormalization", "Add", "Gather"], + ) + run_candidate( + task, + args, + last_matmul_node_name, + op_block_list=["FastGelu", "LayerNormalization", "Add", "Gather", "MatMul"], + ) + + +if __name__ == "__main__": args = parse_arguments() setup_logger(args.verbose) @@ -395,9 +477,10 @@ def run_parity(task: ParityTask, args): task = ParityTask(args.test_cases, args.runs, args.csv) if not args.skip_test: - if (os.getenv('ORT_CUDA_GEMM_OPTIONS') == "4" and args.use_gpu): - assert torch.cuda.get_device_capability( - )[0] >= 7, "half2 kernel is not avaiable in current GPU device. Please set environment variable ORT_CUDA_GEMM_OPTIONS=0 or use supported GPU like V100 or T4" + if os.getenv("ORT_CUDA_GEMM_OPTIONS") == "4" and args.use_gpu: + assert ( + torch.cuda.get_device_capability()[0] >= 7 + ), "half2 kernel is not avaiable in current GPU device. Please set environment variable ORT_CUDA_GEMM_OPTIONS=0 or use supported GPU like V100 or T4" run_parity_disable_half2(task, args) else: run_parity(task, args) @@ -409,5 +492,5 @@ def run_parity(task: ParityTask, args): rows = task.results logger.info("Start running significance tests...") - summary_csv = task.csv_path.replace('.csv', ".stats.csv") + summary_csv = task.csv_path.replace(".csv", ".stats.csv") run_significance_test(rows, summary_csv) diff --git a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py index b2b7532786cf9..be303b4e188bf 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/gpt2_tester.py @@ -5,17 +5,17 @@ # -------------------------------------------------------------------------- # This script helps evaluation of GPT-2 model. import logging -import torch -import numpy -import timeit import math +import os import statistics -from gpt2_helper import Gpt2Helper, Gpt2Inputs - import sys -import os +import timeit + +import numpy +import torch +from gpt2_helper import Gpt2Helper, Gpt2Inputs -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from benchmark_helper import Precision @@ -23,8 +23,7 @@ class Gpt2Metric: - - def __init__(self, treatment_name, baseline_name='Torch', top_k=20): + def __init__(self, treatment_name, baseline_name="Torch", top_k=20): assert top_k > 1 and top_k <= 100 self.baseline = baseline_name self.treatment = treatment_name @@ -64,7 +63,7 @@ def print(self): if key == 0: print("\t{}: \t{:.2f} ms".format(key, average)) else: - print("\t[{}, {}]:\t{:.2f} ms".format(2**key, 2**(key + 1) - 1, average)) + print("\t[{}, {}]:\t{:.2f} ms".format(2**key, 2 ** (key + 1) - 1, average)) total += average * len(self.seq_len_latency[key]) count += len(self.seq_len_latency[key]) print("Average Latency: {:.2f} ms".format(total / count)) @@ -100,9 +99,11 @@ def _eval_topk(self, baseline_topk, treatment_topk, top_k, verbose=True): else: if verbose: print( - f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results") - self.batch_topk_error |= (torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) - > 0) + f"Top {top_k} tokens not matched for {self.name}. This will lead to wrong beam search results" + ) + self.batch_topk_error |= ( + torch.eq(baseline_topk, treatment_topk).logical_not().sum(1).unsqueeze(dim=1) > 0 + ) def end_batch(self): self.top_1_error += self.batch_top1_error.sum() @@ -116,18 +117,19 @@ def add_latency(self, past_seq_len, latency): class Gpt2Tester: - - def __init__(self, - input_ids, - position_ids, - attention_mask, - num_attention_heads, - hidden_size, - num_layer, - device, - is_fp16=False, - top_k=20, - top_k_required_order=False): + def __init__( + self, + input_ids, + position_ids, + attention_mask, + num_attention_heads, + hidden_size, + num_layer, + device, + is_fp16=False, + top_k=20, + top_k_required_order=False, + ): self.batch_size = input_ids.shape[0] self.input_length = input_ids.shape[1] @@ -142,7 +144,13 @@ def __init__(self, # Emtpy past state for first inference self.past = [] - past_shape = [2, self.batch_size, num_attention_heads, 0, hidden_size // num_attention_heads] + past_shape = [ + 2, + self.batch_size, + num_attention_heads, + 0, + hidden_size // num_attention_heads, + ] for i in range(num_layer): empty_past = torch.empty(past_shape).type(torch.float16 if is_fp16 else torch.float32) self.past.append(empty_past.to(device)) @@ -159,7 +167,7 @@ def get_inputs(self) -> Gpt2Inputs: def save_test_data(self, session, output, save_test_data_dir, test_case_id): from onnx import numpy_helper - path = os.path.join(save_test_data_dir, 'test_data_set_' + str(test_case_id)) + path = os.path.join(save_test_data_dir, "test_data_set_" + str(test_case_id)) if os.path.exists(path): print(f"Directory {path} existed. Skip saving test data") return @@ -179,17 +187,18 @@ def add_tensor(input_tensors, torch_tensor, name): add_tensor(input_tensors, self.attention_mask, "attention_mask") for i in range(self.n_layer): - add_tensor(input_tensors, self.past[i], 'past_' + str(i)) + add_tensor(input_tensors, self.past[i], "past_" + str(i)) for i, tensor in enumerate(input_tensors): - with open(os.path.join(path, 'input_{}.pb'.format(i)), 'wb') as f: + with open(os.path.join(path, "input_{}.pb".format(i)), "wb") as f: f.write(tensor.SerializeToString()) output_names = [output.name for output in session.get_outputs()] for i, name in enumerate(output_names): tensor = numpy_helper.from_array( - output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy()) - with open(os.path.join(path, 'output_{}.pb'.format(i)), 'wb') as f: + output[i] if isinstance(output[i], numpy.ndarray) else output[i].clone().cpu().numpy() + ) + with open(os.path.join(path, "output_{}.pb".format(i)), "wb") as f: f.write(tensor.SerializeToString()) print(f"Test data saved to directory {path}") @@ -198,8 +207,9 @@ def update(self, output, step, device): """ Update the inputs for next inference. """ - self.logits = torch.from_numpy(output[0]) if isinstance(output[0], - numpy.ndarray) else output[0].clone().detach().cpu() + self.logits = ( + torch.from_numpy(output[0]) if isinstance(output[0], numpy.ndarray) else output[0].clone().detach().cpu() + ) self.top_1_tokens = Gpt2Tester.predict_next_token(self.logits) self.top_k_tokens = Gpt2Tester.predict_next_token(self.logits, self.top_k, self.top_k_required_order) @@ -207,13 +217,18 @@ def update(self, output, step, device): self.input_ids = self.top_1_tokens.clone().detach().reshape([self.batch_size, 1]).to(device) if self.has_position_ids: - self.position_ids = torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, - 1).to(device) + self.position_ids = ( + torch.tensor([self.input_length + step - 1]).unsqueeze(0).repeat(self.batch_size, 1).to(device) + ) if self.has_attention_mask: self.attention_mask = torch.cat( - [self.attention_mask, - torch.ones([self.batch_size, 1]).type_as(self.attention_mask)], 1).to(device) + [ + self.attention_mask, + torch.ones([self.batch_size, 1]).type_as(self.attention_mask), + ], + 1, + ).to(device) self.past = [] @@ -221,8 +236,11 @@ def update(self, output, step, device): self.past = list(output[1]) else: for i in range(self.n_layer): - past_i = torch.from_numpy(output[i + 1]) if isinstance( - output[i + 1], numpy.ndarray) else output[i + 1].clone().detach() + past_i = ( + torch.from_numpy(output[i + 1]) + if isinstance(output[i + 1], numpy.ndarray) + else output[i + 1].clone().detach() + ) self.past.append(past_i.to(device)) def diff(self, baseline): @@ -234,18 +252,26 @@ def diff(self, baseline): if self.logits is not None: max_io_diff = (self.logits - baseline.logits).abs().max() if max_io_diff > 1e-4: - print(f'Max logits difference is too large: {max_io_diff}') + print(f"Max logits difference is too large: {max_io_diff}") if not torch.all(self.input_ids == baseline.input_ids): - print('Input_ids is different', self.input_ids, baseline.input_ids) + print("Input_ids is different", self.input_ids, baseline.input_ids) if self.has_position_ids: if not torch.all(self.position_ids == baseline.position_ids): - print('position_ids is different', self.position_ids, baseline.position_ids) + print( + "position_ids is different", + self.position_ids, + baseline.position_ids, + ) if self.has_attention_mask: if not torch.all(self.attention_mask == baseline.attention_mask): - print('attention_mask is different', self.attention_mask, baseline.attention_mask) + print( + "attention_mask is different", + self.attention_mask, + baseline.attention_mask, + ) assert len(self.past) == len(baseline.past) @@ -282,10 +308,16 @@ def diff_present(onnx_output, onnx_io_output, n_layer): """ present_diff_max = [] for i in range(n_layer): - onnx_present_i = torch.from_numpy(onnx_output[i + 1]) if isinstance(onnx_output[i + 1], - numpy.ndarray) else onnx_output[i + 1] - onnx_io_present_i = torch.from_numpy(onnx_io_output[i + 1]) if isinstance( - onnx_io_output[i + 1], numpy.ndarray) else onnx_io_output[i + 1] + onnx_present_i = ( + torch.from_numpy(onnx_output[i + 1]) + if isinstance(onnx_output[i + 1], numpy.ndarray) + else onnx_output[i + 1] + ) + onnx_io_present_i = ( + torch.from_numpy(onnx_io_output[i + 1]) + if isinstance(onnx_io_output[i + 1], numpy.ndarray) + else onnx_io_output[i + 1] + ) max_diff = (onnx_present_i - onnx_io_present_i).abs().max() present_diff_max.append(max_diff) print(f"present_diff_max={present_diff_max}") @@ -296,24 +328,28 @@ def is_quantized_onnx_model(onnx_model_path): Returns True if the ONNX model is quantized. """ from onnx import load + model = load(onnx_model_path) from onnxruntime.quantization.quantize import __producer__ as quantize_producer + return model.producer_name == quantize_producer @staticmethod - def test_generation(session, - model, - device, - test_inputs, - precision=Precision.FLOAT32, - model_class='Gpt2LMHeadModel', - top_k=20, - top_k_no_order=True, - max_steps=24, - max_inputs=0, - verbose=False, - save_test_data=0, - save_test_data_dir='.'): + def test_generation( + session, + model, + device, + test_inputs, + precision=Precision.FLOAT32, + model_class="Gpt2LMHeadModel", + top_k=20, + top_k_no_order=True, + max_steps=24, + max_inputs=0, + verbose=False, + save_test_data=0, + save_test_data_dir=".", + ): """ Test Generation using greedy beam search (without sampling) to compare PyTorch and ONNX model. It will print top 1 and top k errors on the given test inputs. @@ -327,29 +363,31 @@ def test_generation(session, eos_token_id = model.config.eos_token_id test_data_saved = 0 - is_float16 = (precision == Precision.FLOAT16) + is_float16 = precision == Precision.FLOAT16 if is_float16: - assert 'float16' in session.get_outputs()[0].type + assert "float16" in session.get_outputs()[0].type # We will still use fp32 torch model as baseline when onnx model if fp16 model.eval().to(device) # Allocate initial buffers for IO Binding of ONNX Runtimne. The buffer size will automatically increase later. - init_output_shapes = Gpt2Helper.get_output_shapes(batch_size=4, - past_sequence_length=128, - sequence_length=32, - config=model.config, - model_class=model_class) + init_output_shapes = Gpt2Helper.get_output_shapes( + batch_size=4, + past_sequence_length=128, + sequence_length=32, + config=model.config, + model_class=model_class, + ) output_buffers = Gpt2Helper.get_output_buffers(init_output_shapes, device, is_float16=is_float16) - baseline_name = 'Torch' - treatment_name = 'Quantized Onnx' if precision == Precision.INT8 else "Onnx" + baseline_name = "Torch" + treatment_name = "Quantized Onnx" if precision == Precision.INT8 else "Onnx" torch_metric = Gpt2Metric(baseline_name, baseline_name, top_k) onnx_metric = Gpt2Metric(treatment_name, baseline_name, top_k) - onnx_io_metric = Gpt2Metric(treatment_name + ' with IO Binding', baseline_name, top_k) + onnx_io_metric = Gpt2Metric(treatment_name + " with IO Binding", baseline_name, top_k) for i, inputs in enumerate(test_inputs): - if (max_inputs > 0 and i == max_inputs): + if max_inputs > 0 and i == max_inputs: break if i % 10 == 0: print(f"{i}") @@ -357,12 +395,42 @@ def test_generation(session, position_ids = inputs["position_ids"] if "position_ids" in inputs else None attention_mask = inputs["attention_mask"] if "attention_mask" in inputs else None - onnx_runner = Gpt2Tester(input_ids, position_ids, attention_mask, n_head, n_embd, n_layer, device, - is_float16, top_k, not top_k_no_order) - onnx_io_runner = Gpt2Tester(input_ids, position_ids, attention_mask, n_head, n_embd, n_layer, device, - is_float16, top_k, not top_k_no_order) - torch_runner = Gpt2Tester(input_ids, position_ids, attention_mask, n_head, n_embd, n_layer, device, False, - top_k, not top_k_no_order) # Torch model baseline is fp32 + onnx_runner = Gpt2Tester( + input_ids, + position_ids, + attention_mask, + n_head, + n_embd, + n_layer, + device, + is_float16, + top_k, + not top_k_no_order, + ) + onnx_io_runner = Gpt2Tester( + input_ids, + position_ids, + attention_mask, + n_head, + n_embd, + n_layer, + device, + is_float16, + top_k, + not top_k_no_order, + ) + torch_runner = Gpt2Tester( + input_ids, + position_ids, + attention_mask, + n_head, + n_embd, + n_layer, + device, + False, + top_k, + not top_k_no_order, + ) # Torch model baseline is fp32 batch_size = torch_runner.batch_size onnx_metric.start_batch(batch_size) @@ -379,27 +447,30 @@ def test_generation(session, torch_metric.add_latency(past_seq_len, timeit.default_timer() - start_time) torch_runner.update(pytorch_output, step, device) - onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference(session, - onnx_runner.get_inputs(), - total_runs=1) + onnx_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference( + session, onnx_runner.get_inputs(), total_runs=1 + ) onnx_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0) onnx_runner.update(onnx_output, step, device) - output_shapes = Gpt2Helper.get_output_shapes(batch_size, - past_seq_len, - seq_len, - model.config, - model_class=model_class) + output_shapes = Gpt2Helper.get_output_shapes( + batch_size, + past_seq_len, + seq_len, + model.config, + model_class=model_class, + ) Gpt2Helper.auto_increase_buffer_size(output_buffers, output_shapes) - onnx_io_output, avg_latency_ms = Gpt2Helper.onnxruntime_inference_with_binded_io( + (onnx_io_output, avg_latency_ms,) = Gpt2Helper.onnxruntime_inference_with_binded_io( session, onnx_io_runner.get_inputs(), output_buffers, output_shapes, total_runs=1, return_numpy=False, - include_copy_output_latency=True) + include_copy_output_latency=True, + ) onnx_io_metric.add_latency(past_seq_len, avg_latency_ms / 1000.0) if test_data_saved < save_test_data: diff --git a/onnxruntime/python/tools/transformers/models/gpt2/parity_check_helper.py b/onnxruntime/python/tools/transformers/models/gpt2/parity_check_helper.py index 2f878fd3cedac..c122e243293aa 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/parity_check_helper.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/parity_check_helper.py @@ -8,17 +8,16 @@ import math import multiprocessing -import numpy import os -import torch +import sys from pathlib import Path -from onnx import numpy_helper, TensorProto -from gpt2_helper import Gpt2Helper -import sys -import os +import numpy +import torch +from gpt2_helper import Gpt2Helper +from onnx import TensorProto, numpy_helper -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from benchmark_helper import create_onnxruntime_session @@ -47,10 +46,14 @@ def environ_setting_paths(output_path): def environ_reset(): for flag in [ - "ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA", "ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA", - "ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA", "ORT_DEBUG_NODE_IO_NAME_FILTER", "ORT_DEBUG_NODE_IO_OP_TYPE_FILTER", - "ORT_DEBUG_NODE_IO_DUMP_DATA_TO_FILES", "ORT_DEBUG_NODE_IO_OUTPUT_DIR", - "ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK" + "ORT_DEBUG_NODE_IO_DUMP_SHAPE_DATA", + "ORT_DEBUG_NODE_IO_DUMP_INPUT_DATA", + "ORT_DEBUG_NODE_IO_DUMP_OUTPUT_DATA", + "ORT_DEBUG_NODE_IO_NAME_FILTER", + "ORT_DEBUG_NODE_IO_OP_TYPE_FILTER", + "ORT_DEBUG_NODE_IO_DUMP_DATA_TO_FILES", + "ORT_DEBUG_NODE_IO_OUTPUT_DIR", + "ORT_DEBUG_NODE_IO_DUMPING_DATA_TO_FILES_FOR_ALL_NODES_IS_OK", ]: if flag in os.environ: del os.environ[flag] @@ -68,6 +71,7 @@ def generate_outputs_files(model_path, dummy_inputs, outputs_path, use_gpu): dir_path = Path(outputs_path) if dir_path.exists() and dir_path.is_dir(): import shutil + shutil.rmtree(outputs_path) dir_path.mkdir(parents=True, exist_ok=True) @@ -82,15 +86,16 @@ def post_processing(outputs_path, outputs_path_other): if_close = {} import glob - for filename in glob.glob(os.path.join(outputs_path, '*.tensorproto')): + + for filename in glob.glob(os.path.join(outputs_path, "*.tensorproto")): filename_other = os.path.join(outputs_path_other, Path(filename).name) if not os.path.exists(filename_other): continue - with open(filename, 'rb') as f: + with open(filename, "rb") as f: tensor = TensorProto() tensor.ParseFromString(f.read()) array = numpy_helper.to_array(tensor) - with open(filename_other, 'rb') as f: + with open(filename_other, "rb") as f: tensor_other = TensorProto() tensor_other.ParseFromString(f.read()) array_other = numpy_helper.to_array(tensor_other) @@ -109,29 +114,31 @@ def post_processing(outputs_path, outputs_path_other): print(line) -if __name__ == '__main__': +if __name__ == "__main__": # Below example shows how to use this helper to investigate parity issue of gpt-2 fp32 and fp16 onnx model # Please build ORT with --cmake_extra_defines onnxruntime_DEBUG_NODE_INPUTS_OUTPUTS=ON !! - multiprocessing.set_start_method('spawn') + multiprocessing.set_start_method("spawn") # Generate Inputs sequence_length = 8 past_sequence_length = 8 batch_size = 5 - dummy_inputs_fp16 = Gpt2Helper.get_dummy_inputs(batch_size, - past_sequence_length, - sequence_length, - 12, - 768, - 12, - 50257, - device=torch.device("cpu"), - float16=True) + dummy_inputs_fp16 = Gpt2Helper.get_dummy_inputs( + batch_size, + past_sequence_length, + sequence_length, + 12, + 768, + 12, + 50257, + device=torch.device("cpu"), + float16=True, + ) dummy_inputs_fp32 = dummy_inputs_fp16.to_fp32() # Get GPT-2 model from huggingface using convert_to_onnx.py - os.system('python convert_to_onnx.py -m gpt2 --output gpt2_fp32.onnx -o -p fp32 --use_gpu') - os.system('python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu') + os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp32.onnx -o -p fp32 --use_gpu") + os.system("python convert_to_onnx.py -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu") # Specify the directory to dump the node's I/O outputs_path_fp32_gpu = "./fp32_gpu" diff --git a/onnxruntime/python/tools/transformers/models/longformer/__init__.py b/onnxruntime/python/tools/transformers/models/longformer/__init__.py index 7c2a88f4d9554..cc667396a2622 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/__init__.py +++ b/onnxruntime/python/tools/transformers/models/longformer/__init__.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- diff --git a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py index 5223521fe345a..5fdd9c9b119b6 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py +++ b/onnxruntime/python/tools/transformers/models/longformer/benchmark_longformer.py @@ -26,25 +26,35 @@ # python benchmark_longformer.py --model longformer-base-4096 --batch_sizes 1 --sequence_lengths 4096 --global_lengths 8 --onnx_dir . --memory -t 10 # By default, compact memory kernel is not enabled. You need set an environment variable ORT_LONGFORMER_COMPACT_MEMORY=1 to enable it. -import timeit -from datetime import datetime -import csv import argparse +import csv +import math import os import sys -import torch -import onnxruntime +import timeit +from datetime import datetime + import numpy as np -import math +import torch +from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper -from longformer_helper import LongformerHelper, PRETRAINED_LONGFORMER_MODELS +import onnxruntime -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) import benchmark_helper -def test_torch_latency(device, model, model_name, batch_sizes, sequence_lengths, global_lengths, test_times, - num_threads, verbose): +def test_torch_latency( + device, + model, + model_name, + batch_sizes, + sequence_lengths, + global_lengths, + test_times, + num_threads, + verbose, +): if num_threads > 0: torch.set_num_threads(num_threads) @@ -53,14 +63,15 @@ def test_torch_latency(device, model, model_name, batch_sizes, sequence_lengths, for sequence_length in sequence_lengths: for global_length in global_lengths: print(f"batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}...") - inputs: LongforerInputs = LongformerHelper.get_dummy_inputs(batch_size, sequence_length, global_length, - device) + inputs: LongforerInputs = LongformerHelper.get_dummy_inputs( + batch_size, sequence_length, global_length, device + ) input_list = inputs.to_list() _ = model(*input_list) runtimes = timeit.repeat(lambda: model(*input_list), repeat=test_times, number=1) result = { - "engine": "torch", #TODO: test torchscript + "engine": "torch", # TODO: test torchscript "version": torch.__version__, "device": "cuda", "optimizer": "", @@ -87,8 +98,9 @@ def test_parity(device, model, ort_session, batch_size, sequence_length, global_ print( f"Comparing Torch and ORT outputs for batch_size={batch_size} sequence_length={sequence_length} global_length={global_length}..." ) - dummy_inputs: LongforerInputs = LongformerHelper.get_dummy_inputs(batch_size, sequence_length, global_length, - device) + dummy_inputs: LongforerInputs = LongformerHelper.get_dummy_inputs( + batch_size, sequence_length, global_length, device + ) ort_inputs = dummy_inputs.get_ort_inputs() ort_outputs = ort_session.run(None, ort_inputs) input_list = dummy_inputs.to_list() @@ -101,32 +113,36 @@ def test_parity(device, model, ort_session, batch_size, sequence_length, global_ return max_diff -def test_ort_latency(device, - model, - model_name, - description, - ort_session, - batch_sizes, - sequence_lengths, - global_lengths, - test_times, - num_threads, - optimizer=False, - precision='fp32', - validate_onnx=True, - disable_io_binding=False, - verbose=True): +def test_ort_latency( + device, + model, + model_name, + description, + ort_session, + batch_sizes, + sequence_lengths, + global_lengths, + test_times, + num_threads, + optimizer=False, + precision="fp32", + validate_onnx=True, + disable_io_binding=False, + verbose=True, +): results = [] for batch_size in batch_sizes: for sequence_length in sequence_lengths: for global_length in global_lengths: - assert global_length <= model.config.attention_window[ - 0], "Limitation of current implementation: number of global token <= attention_window" + assert ( + global_length <= model.config.attention_window[0] + ), "Limitation of current implementation: number of global token <= attention_window" print( f"Testing batch_size={batch_size} sequence_length={sequence_length} global_length={global_length} optimizer={optimizer}, precision={precision} io_binding={not disable_io_binding}..." ) - dummy_inputs: LongforerInputs = LongformerHelper.get_dummy_inputs(batch_size, sequence_length, - global_length, device) + dummy_inputs: LongforerInputs = LongformerHelper.get_dummy_inputs( + batch_size, sequence_length, global_length, device + ) # Run OnnxRuntime ort_inputs = dummy_inputs.get_ort_inputs() @@ -169,37 +185,57 @@ def test_ort_latency(device, output_buffer_max_sizes=[max_last_state_size, max_pooler_size], batch_size=batch_size, device=device, - data_type=np.longlong, #input data type + data_type=np.longlong, # input data type ) else: - result = benchmark_helper.inference_ort(ort_session, - ort_inputs, - result_template=result_template, - repeat_times=test_times, - batch_size=batch_size) + result = benchmark_helper.inference_ort( + ort_session, + ort_inputs, + result_template=result_template, + repeat_times=test_times, + batch_size=batch_size, + ) if validate_onnx: - max_diff = test_parity(device, model, ort_session, batch_size, sequence_length, global_length, - verbose) + max_diff = test_parity( + device, + model, + ort_session, + batch_size, + sequence_length, + global_length, + verbose, + ) result["description"] += f"(max_diff={max_diff})" results.append(result) return results -def test_ort_memory(device, onnx_model_path, batch_size, sequence_length, global_length, test_times, num_threads): +def test_ort_memory( + device, + onnx_model_path, + batch_size, + sequence_length, + global_length, + test_times, + num_threads, +): print( f"Testing memory for model={onnx_model_path}, batch_size={batch_size}, sequence_length={sequence_length}, global_length={global_length}, test_times={test_times}, num_threads={num_threads}" ) def inference(): - session = benchmark_helper.create_onnxruntime_session(onnx_model_path, - use_gpu=True, - enable_all_optimization=True, - num_threads=num_threads) - - dummy_inputs: LongforerInputs = LongformerHelper.get_dummy_inputs(batch_size, sequence_length, global_length, - device) + session = benchmark_helper.create_onnxruntime_session( + onnx_model_path, + use_gpu=True, + enable_all_optimization=True, + num_threads=num_threads, + ) + + dummy_inputs: LongforerInputs = LongformerHelper.get_dummy_inputs( + batch_size, sequence_length, global_length, device + ) ort_inputs = dummy_inputs.get_ort_inputs() for _ in range(test_times): ort_outputs = session.run(None, ort_inputs) @@ -213,24 +249,27 @@ def inference(): "global_length": global_length, "test_times": test_times, "num_threads": num_threads, - "memory": memory_used + "memory": memory_used, } def load_torch_model(model_name, device): - torch_model_name_or_dir = PRETRAINED_LONGFORMER_MODELS[ - model_name] if model_name in PRETRAINED_LONGFORMER_MODELS else model_name + torch_model_name_or_dir = ( + PRETRAINED_LONGFORMER_MODELS[model_name] if model_name in PRETRAINED_LONGFORMER_MODELS else model_name + ) from transformers import LongformerModel + model = LongformerModel.from_pretrained(torch_model_name_or_dir) model.to(device) return model -def find_onnx_model(model_name, onnx_dir='.'): +def find_onnx_model(model_name, onnx_dir="."): # Search onnx model in the following order: optimized fp16 model, optimized fp32 model, raw model # TODO: call convert_longformer_to_onnx to export onnx instead. import os.path + onnx_model_path = os.path.join(onnx_dir, model_name + ".onnx") optimized_fp32_model = os.path.join(onnx_dir, model_name + "_fp32.onnx") optimized_fp16_model = os.path.join(onnx_dir, model_name + "_fp16.onnx") @@ -253,8 +292,15 @@ def test_memory(args, device): onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx torch.cuda.empty_cache() - return test_ort_memory(device, onnx_model_path, args.batch_sizes[0], args.sequence_lengths[0], - args.global_lengths[0], args.test_times, args.num_threads) + return test_ort_memory( + device, + onnx_model_path, + args.batch_sizes[0], + args.sequence_lengths[0], + args.global_lengths[0], + args.test_times, + args.num_threads, + ) def test_ort(args, device): @@ -263,32 +309,57 @@ def test_ort(args, device): onnx_model_path = find_onnx_model(model_name) if not args.onnx else args.onnx optimized = onnx_model_path.endswith("_fp16.onnx") or onnx_model_path.endswith("_fp32.onnx") - precision = 'fp32' if not onnx_model_path.endswith("_fp16.onnx") else 'fp16' + precision = "fp32" if not onnx_model_path.endswith("_fp16.onnx") else "fp16" model = load_torch_model(model_name, device) num_threads = args.num_threads - session = benchmark_helper.create_onnxruntime_session(onnx_model_path, - use_gpu=True, - enable_all_optimization=True, - num_threads=num_threads) + session = benchmark_helper.create_onnxruntime_session( + onnx_model_path, + use_gpu=True, + enable_all_optimization=True, + num_threads=num_threads, + ) if session is None: raise RuntimeError(f"Failed to create ORT session from ONNX file {onnx_model_path}") description = onnx_model_path - if (os.environ.get('ORT_LONGFORMER_COMPACT_MEMORY', '0') == "1"): + if os.environ.get("ORT_LONGFORMER_COMPACT_MEMORY", "0") == "1": description += "[compact_memory]" - return test_ort_latency(device, model, model_name, description, session, args.batch_sizes, args.sequence_lengths, - args.global_lengths, args.test_times, num_threads, optimized, precision, args.validate_onnx, - args.disable_io_binding, args.verbose) + return test_ort_latency( + device, + model, + model_name, + description, + session, + args.batch_sizes, + args.sequence_lengths, + args.global_lengths, + args.test_times, + num_threads, + optimized, + precision, + args.validate_onnx, + args.disable_io_binding, + args.verbose, + ) def test_torch(args, device): model = load_torch_model(args.model, device) - return test_torch_latency(device, model, args.model, args.batch_sizes, args.sequence_lengths, args.global_lengths, - args.test_times, args.num_threads, args.verbose) + return test_torch_latency( + device, + model, + args.model, + args.batch_sizes, + args.sequence_lengths, + args.global_lengths, + args.test_times, + args.num_threads, + args.verbose, + ) def test_latency(args, device): @@ -303,28 +374,34 @@ def test_latency(args, device): def parse_arguments(argv=None): parser = argparse.ArgumentParser() - parser.add_argument("-m", - "--model", - required=False, - type=str, - default="longformer-base-4096", - help="Checkpoint directory or pre-trained model names in the list: " + - ", ".join(PRETRAINED_LONGFORMER_MODELS.keys())) - - parser.add_argument("-e", - "--engine", - required=False, - type=str, - default='onnxruntime', - choices=['onnxruntime', 'torch'], - help="Engine to benchmark.") - - parser.add_argument("-t", - "--test_times", - required=False, - default=1000, - type=int, - help="Number of repeat times to get average inference latency.") + parser.add_argument( + "-m", + "--model", + required=False, + type=str, + default="longformer-base-4096", + help="Checkpoint directory or pre-trained model names in the list: " + + ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()), + ) + + parser.add_argument( + "-e", + "--engine", + required=False, + type=str, + default="onnxruntime", + choices=["onnxruntime", "torch"], + help="Engine to benchmark.", + ) + + parser.add_argument( + "-t", + "--test_times", + required=False, + default=1000, + type=int, + help="Number of repeat times to get average inference latency.", + ) parser.add_argument("-b", "--batch_sizes", nargs="+", type=int, default=[1]) @@ -336,30 +413,50 @@ def parse_arguments(argv=None): nargs="+", type=int, default=[512, 1024, 2048, 4096], - help= - "Sequence lengths. It could have multiple values in latency test. If --export_padding is not used in exporting onnx model, sequence length shall be multiple of window size." + help="Sequence lengths. It could have multiple values in latency test. If --export_padding is not used in exporting onnx model, sequence length shall be multiple of window size.", ) parser.add_argument("--onnx", required=False, type=str, default=None, help="Onnx model path") - parser.add_argument("-g", - "--global_lengths", - nargs="+", - type=int, - default=[0], - help="Number of global tokens. It could have multiple values in latency test.") + parser.add_argument( + "-g", + "--global_lengths", + nargs="+", + type=int, + default=[0], + help="Number of global tokens. It could have multiple values in latency test.", + ) - parser.add_argument("-n", "--num_threads", required=False, type=int, default=0, help="Threads to use.") + parser.add_argument( + "-n", + "--num_threads", + required=False, + type=int, + default=0, + help="Threads to use.", + ) - parser.add_argument("-v", - "--validate_onnx", - required=False, - action="store_true", - help="Validate that ONNX model generates same output as PyTorch model.") + parser.add_argument( + "-v", + "--validate_onnx", + required=False, + action="store_true", + help="Validate that ONNX model generates same output as PyTorch model.", + ) - parser.add_argument("--disable_io_binding", required=False, action="store_true", help="Do not use IO Binding.") + parser.add_argument( + "--disable_io_binding", + required=False, + action="store_true", + help="Do not use IO Binding.", + ) - parser.add_argument("--memory", required=False, action="store_true", help="Test memory usage instead of latency.") + parser.add_argument( + "--memory", + required=False, + action="store_true", + help="Test memory usage instead of latency.", + ) parser.add_argument("--verbose", required=False, action="store_true", help="Print more information.") @@ -369,17 +466,35 @@ def parse_arguments(argv=None): def output_details(results, csv_filename): - latency_results = [result for result in results if 'average_latency_ms' in result] + latency_results = [result for result in results if "average_latency_ms" in result] if len(latency_results) == 0: print("No latency results for output.") return - with open(csv_filename, mode="a", newline='') as csv_file: + with open(csv_filename, mode="a", newline="") as csv_file: column_names = [ - "engine", "version", "device", "precision", "optimizer", "io_binding", "model_name", "inputs", "threads", - "datetime", "test_times", "description", "batch_size", "sequence_length", "global_length", "memory", "QPS", - "average_latency_ms", "latency_variance", "latency_90_percentile", "latency_95_percentile", - "latency_99_percentile" + "engine", + "version", + "device", + "precision", + "optimizer", + "io_binding", + "model_name", + "inputs", + "threads", + "datetime", + "test_times", + "description", + "batch_size", + "sequence_length", + "global_length", + "memory", + "QPS", + "average_latency_ms", + "latency_variance", + "latency_90_percentile", + "latency_95_percentile", + "latency_99_percentile", ] csv_writer = csv.DictWriter(csv_file, fieldnames=column_names) @@ -399,10 +514,10 @@ def run(args): torch.set_grad_enabled(False) # set random seed manully to get deterministic results - #benchmark_helper.set_random_seed(123) + # benchmark_helper.set_random_seed(123) # Currently, the longformer attention operator could only run in GPU (no CPU implementation yet). - device = torch.device('cuda:0') + device = torch.device("cuda:0") if args.memory: return test_memory(args, device) @@ -414,18 +529,23 @@ def test_all(): results = [] test_times = 100 sequence_lengths = [512, 1024, 2048, 4096] - for model_name in ['longformer-base-4096']: + for model_name in ["longformer-base-4096"]: for batch_size in [1]: for sequence_length in sequence_lengths: for global_length in [8]: - engine_name = 'torch' + engine_name = "torch" args = parse_arguments( - f"-e {engine_name} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} -t {test_times} -m {model_name}" - .split(' ')) + f"-e {engine_name} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} -t {test_times} -m {model_name}".split( + " " + ) + ) results += run(args) - engine_name = 'onnxruntime' - onnx_paths = [f"{model_name}_fp32.onnx", f"{model_name}_fp16.onnx"] # optimized models + engine_name = "onnxruntime" + onnx_paths = [ + f"{model_name}_fp32.onnx", + f"{model_name}_fp16.onnx", + ] # optimized models for onnx_path in onnx_paths: if os.path.exists(onnx_path): for compact_memory in ["0", "1"]: @@ -433,14 +553,18 @@ def test_all(): print("ORT_LONGFORMER_COMPACT_MEMORY=", compact_memory) args = parse_arguments( - f"--disable_io_binding -e {engine_name} --onnx {onnx_path} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} -t 10 -m {model_name} --memory" - .split(' ')) + f"--disable_io_binding -e {engine_name} --onnx {onnx_path} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} -t 10 -m {model_name} --memory".split( + " " + ) + ) memory_results = run(args) print(memory_results) args = parse_arguments( - f"--disable_io_binding -e {engine_name} --onnx {onnx_path} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} -t {test_times} -m {model_name} --validate_onnx" - .split(' ')) + f"--disable_io_binding -e {engine_name} --onnx {onnx_path} -t {test_times} -b {batch_size} -s {sequence_length} -g {global_length} -t {test_times} -m {model_name} --validate_onnx".split( + " " + ) + ) latency_results = run(args) if len(latency_results) == 1: latency_results[0]["memory"] = memory_results["memory"] diff --git a/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py index 02e9473ef42e8..f94a4154240c0 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/longformer/convert_to_onnx.py @@ -12,7 +12,7 @@ # # It is tested in Ubuntu 18.04 with python 3.8, onnxruntime-gpu 1.11.0, PyTorch 1.9.0, transformers 4.18.0. # Warning: Using newer version (1.10 or 1.11) of PyTorch might encounter issue in exporting, but they are fine for benchmarking. -# +# # Example commands for exporting longformer base model in Linux: # cd ./torch_extensions # python setup.py install @@ -23,78 +23,98 @@ # # For inference of the onnx model, you will need onnxruntime-gpu 1.7.0 or newer version. -import sys +import argparse import os -import torch +import sys +from pathlib import Path + import numpy as np -import argparse +import torch import transformers +from longformer_helper import PRETRAINED_LONGFORMER_MODELS, LongformerHelper +from packaging import version from torch.onnx import register_custom_op_symbolic from torch.onnx.symbolic_helper import parse_args -from packaging import version -from pathlib import Path -from longformer_helper import LongformerHelper, PRETRAINED_LONGFORMER_MODELS -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from torch_onnx_export_helper import torch_onnx_export -@parse_args('v', 'v', 'v', 'v', 'v', 'v', 'v', 'i', 'i') -def my_longformer_attention(g, input, weight, bias, mask, global_weight, global_bias, global_mask, num_heads, window): - return g.op("com.microsoft::LongformerAttention", - input, - weight, - bias, - mask, - global_weight, - global_bias, - global_mask, - num_heads_i=num_heads, - window_i=window) +@parse_args("v", "v", "v", "v", "v", "v", "v", "i", "i") +def my_longformer_attention( + g, + input, + weight, + bias, + mask, + global_weight, + global_bias, + global_mask, + num_heads, + window, +): + return g.op( + "com.microsoft::LongformerAttention", + input, + weight, + bias, + mask, + global_weight, + global_bias, + global_mask, + num_heads_i=num_heads, + window_i=window, + ) # namespace is onnxruntime which is registered in longformer_attention.cpp -register_custom_op_symbolic('onnxruntime::LongformerAttention', my_longformer_attention, 9) +register_custom_op_symbolic("onnxruntime::LongformerAttention", my_longformer_attention, 9) # TODO: search the directory to find correct output filename of "python setup.py install" when python version is not 3.8 torch.ops.load_library( - r'./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so') + r"./torch_extensions/build/lib.linux-x86_64-3.8/longformer_attention.cpython-38-x86_64-linux-gnu.so" +) def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument("-m", - "--model", - required=False, - type=str, - default="longformer-base-4096", - help="Checkpoint directory or pre-trained model names in the list: " + - ", ".join(PRETRAINED_LONGFORMER_MODELS.keys())) + parser.add_argument( + "-m", + "--model", + required=False, + type=str, + default="longformer-base-4096", + help="Checkpoint directory or pre-trained model names in the list: " + + ", ".join(PRETRAINED_LONGFORMER_MODELS.keys()), + ) parser.add_argument( - '--export_padding', + "--export_padding", required=False, - action='store_true', - help= - 'Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.' + action="store_true", + help="Export padding logic to ONNX graph. If not enabled, user need pad input so that sequence length is multiple of window size.", ) parser.set_defaults(export_padding=False) - parser.add_argument('-o', - '--optimize_onnx', - required=False, - action='store_true', - help='Use optimizer.py to optimize onnx model.') + parser.add_argument( + "-o", + "--optimize_onnx", + required=False, + action="store_true", + help="Use optimizer.py to optimize onnx model.", + ) parser.set_defaults(optimize_onnx=False) - parser.add_argument("-p", - "--precision", - required=False, - type=str, - default='fp32', - choices=['fp32', 'fp16'], - help="Precision of model to run: fp32 for full precision, fp16 for mixed precision") + parser.add_argument( + "-p", + "--precision", + required=False, + type=str, + default="fp32", + choices=["fp32", "fp16"], + help="Precision of model to run: fp32 for full precision, fp16 for mixed precision", + ) args = parser.parse_args() return args @@ -120,16 +140,22 @@ def get_dummy_inputs(config, export_padding, device): # A new function to replace LongformerSelfAttention.forward # For transformers 4.0.0 -def my_longformer_self_attention_forward_4(self, - hidden_states, - attention_mask=None, - is_index_masked=None, - is_index_global_attn=None, - is_global_attn=None): +def my_longformer_self_attention_forward_4( + self, + hidden_states, + attention_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, +): global_mask = is_index_global_attn.int() # The following check is based on the dummy inputs (only the first token is global). - assert len(global_mask.shape) == 2 and global_mask.shape[0] == 1 and global_mask.count_nonzero().item( - ) == 1 and global_mask.tolist()[0][0] == 1 + assert ( + len(global_mask.shape) == 2 + and global_mask.shape[0] == 1 + and global_mask.count_nonzero().item() == 1 + and global_mask.tolist()[0][0] == 1 + ) input_mask = is_index_masked.float() input_mask = input_mask.masked_fill(is_index_masked, -10000.0) @@ -139,90 +165,132 @@ def my_longformer_self_attention_forward_4(self, # TODO: add postprocess of ONNX model to use graph input directly: glboal_mask = global_attention_mask # The following check is based on the dummy inputs (only the last token is masked). - assert len(input_mask.shape) == 2 and input_mask.shape[0] == 1 and input_mask.count_nonzero().item( - ) == 1 and input_mask.tolist()[0][-1] == -10000.0 + assert ( + len(input_mask.shape) == 2 + and input_mask.shape[0] == 1 + and input_mask.count_nonzero().item() == 1 + and input_mask.tolist()[0][-1] == -10000.0 + ) weight = torch.stack( - (self.query.weight.transpose(0, 1), \ - self.key.weight.transpose(0, 1), \ - self.value.weight.transpose(0, 1)), dim=1) + ( + self.query.weight.transpose(0, 1), + self.key.weight.transpose(0, 1), + self.value.weight.transpose(0, 1), + ), + dim=1, + ) weight = weight.reshape(self.embed_dim, 3 * self.embed_dim) bias = torch.stack((self.query.bias, self.key.bias, self.value.bias), dim=0) bias = bias.reshape(3 * self.embed_dim) - global_weight = torch.stack((self.query_global.weight.transpose(0, 1), \ - self.key_global.weight.transpose(0, 1), \ - self.value_global.weight.transpose(0, 1)), - dim=1) + global_weight = torch.stack( + ( + self.query_global.weight.transpose(0, 1), + self.key_global.weight.transpose(0, 1), + self.value_global.weight.transpose(0, 1), + ), + dim=1, + ) global_weight = global_weight.reshape(self.embed_dim, 3 * self.embed_dim) global_bias = torch.stack((self.query_global.bias, self.key_global.bias, self.value_global.bias), dim=0) global_bias = global_bias.reshape(3 * self.embed_dim) - attn_output = torch.ops.onnxruntime.LongformerAttention(hidden_states, weight, bias, input_mask, global_weight, - global_bias, global_mask, self.num_heads, - self.one_sided_attn_window_size) + attn_output = torch.ops.onnxruntime.LongformerAttention( + hidden_states, + weight, + bias, + input_mask, + global_weight, + global_bias, + global_mask, + self.num_heads, + self.one_sided_attn_window_size, + ) assert attn_output.size() == hidden_states.size(), "Unexpected size" - outputs = (attn_output, ) + outputs = (attn_output,) return outputs # For transformers 4.3.0 -def my_longformer_self_attention_forward_4_3(self, - hidden_states, - attention_mask=None, - is_index_masked=None, - is_index_global_attn=None, - is_global_attn=None, - output_attentions=False): +def my_longformer_self_attention_forward_4_3( + self, + hidden_states, + attention_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, +): assert output_attentions == False - return my_longformer_self_attention_forward_4(self, hidden_states, attention_mask, is_index_masked, - is_index_global_attn, is_global_attn) + return my_longformer_self_attention_forward_4( + self, + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) # For transformers 4.3.2 -def my_longformer_self_attention_forward_4_3_2(self, - hidden_states, - attention_mask=None, - layer_head_mask=None, - is_index_masked=None, - is_index_global_attn=None, - is_global_attn=None, - output_attentions=False): +def my_longformer_self_attention_forward_4_3_2( + self, + hidden_states, + attention_mask=None, + layer_head_mask=None, + is_index_masked=None, + is_index_global_attn=None, + is_global_attn=None, + output_attentions=False, +): assert output_attentions == False assert layer_head_mask is None - return my_longformer_self_attention_forward_4(self, hidden_states, attention_mask, is_index_masked, - is_index_global_attn, is_global_attn) + return my_longformer_self_attention_forward_4( + self, + hidden_states, + attention_mask, + is_index_masked, + is_index_global_attn, + is_global_attn, + ) def export_longformer(model, onnx_model_path, export_padding): - input_ids, attention_mask, global_attention_mask = get_dummy_inputs(model.config, - export_padding, - device=torch.device('cpu')) + input_ids, attention_mask, global_attention_mask = get_dummy_inputs( + model.config, export_padding, device=torch.device("cpu") + ) - example_outputs = model(input_ids, attention_mask=attention_mask, global_attention_mask=global_attention_mask) + example_outputs = model( + input_ids, + attention_mask=attention_mask, + global_attention_mask=global_attention_mask, + ) if version.parse(transformers.__version__) < version.parse("4.0.0"): raise RuntimeError("This tool requires transformers 4.0.0 or later.") # Here we replace LongformerSelfAttention.forward using our implmentation for exporting ONNX model - from transformers import LongformerSelfAttention import inspect - key = ' '.join(inspect.getfullargspec(LongformerSelfAttention.forward).args) + + from transformers import LongformerSelfAttention + + key = " ".join(inspect.getfullargspec(LongformerSelfAttention.forward).args) args_to_func = { - 'self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions': - my_longformer_self_attention_forward_4_3_2, - 'self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions': - my_longformer_self_attention_forward_4_3, - 'self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn': - my_longformer_self_attention_forward_4, + "self hidden_states attention_mask layer_head_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3_2, + "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn output_attentions": my_longformer_self_attention_forward_4_3, + "self hidden_states attention_mask is_index_masked is_index_global_attn is_global_attn": my_longformer_self_attention_forward_4, } if key not in args_to_func: - print("Current arguments", inspect.getfullargspec(LongformerSelfAttention.forward).args) + print( + "Current arguments", + inspect.getfullargspec(LongformerSelfAttention.forward).args, + ) raise RuntimeError( "LongformerSelfAttention.forward arguments are different. Please install supported version (like transformers 4.3.0)." ) @@ -236,35 +304,22 @@ def export_longformer(model, onnx_model_path, export_padding): Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export(model, - example_inputs, - onnx_model_path, - opset_version=12, - input_names=["input_ids", "attention_mask", "global_attention_mask"], - output_names=["last_state", "pooler"], - dynamic_axes={ - 'input_ids': { - 0: 'batch_size', - 1: 'sequence_length' - }, - 'attention_mask': { - 0: 'batch_size', - 1: 'sequence_length' - }, - 'global_attention_mask': { - 0: 'batch_size', - 1: 'sequence_length' - }, - 'last_state': { - 0: 'batch_size', - 1: 'sequence_length' - }, - 'pooler': { - 0: 'batch_size', - 1: 'sequence_length' - } - }, - custom_opsets={"com.microsoft": 1}) + torch_onnx_export( + model, + example_inputs, + onnx_model_path, + opset_version=12, + input_names=["input_ids", "attention_mask", "global_attention_mask"], + output_names=["last_state", "pooler"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "global_attention_mask": {0: "batch_size", 1: "sequence_length"}, + "last_state": {0: "batch_size", 1: "sequence_length"}, + "pooler": {0: "batch_size", 1: "sequence_length"}, + }, + custom_opsets={"com.microsoft": 1}, + ) print(f"ONNX model exported to {onnx_model_path}") # Restore original implementaiton: @@ -273,7 +328,9 @@ def export_longformer(model, onnx_model_path, export_padding): def optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path=None): from onnx import load_model + from onnxruntime.transformers.onnx_model_bert import BertOnnxModel + model = load_model(onnx_model_path, format=None, load_external_data=True) optimizer = BertOnnxModel(model) optimizer.optimize() @@ -294,13 +351,14 @@ def main(args): onnx_model_path = model_name + ".onnx" from transformers import LongformerModel + model = LongformerModel.from_pretrained(PRETRAINED_LONGFORMER_MODELS[model_name]) export_longformer(model, onnx_model_path, args.export_padding) - if args.optimize_onnx or args.precision != 'fp32': + if args.optimize_onnx or args.precision != "fp32": fp32_model_path = model_name + "_fp32.onnx" - fp16_model_path = model_name + "_fp16.onnx" if args.precision == 'fp16' else None + fp16_model_path = model_name + "_fp16.onnx" if args.precision == "fp16" else None optimize_longformer(onnx_model_path, fp32_model_path, fp16_model_path) diff --git a/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py b/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py index f0f185b8ed93c..379efce27b27a 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py +++ b/onnxruntime/python/tools/transformers/models/longformer/generate_test_data.py @@ -1,63 +1,95 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Generate test data for a longformer model, so that we can use onnxruntime_perf_test.exe to evaluate the inference latency. -import sys import argparse -import numpy as np import os import random +import sys from pathlib import Path + +import numpy as np from onnx import ModelProto, TensorProto, numpy_helper -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) -from onnx_model import OnnxModel +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from bert_test_data import fake_input_ids_data, fake_input_mask_data, output_test_data +from onnx_model import OnnxModel def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--model', required=True, type=str, help="bert onnx model path.") - - parser.add_argument('--output_dir', - required=False, - type=str, - default=None, - help="output test data path. If not specified, .") - - parser.add_argument('--batch_size', required=False, type=int, default=1, help="batch size of input") - - parser.add_argument('--sequence_length', - required=False, - type=int, - default=128, - help="maximum sequence length of input") - - parser.add_argument('--global_tokens', required=False, type=int, default=10, help="number of global tokens") - - parser.add_argument('--input_ids_name', required=False, type=str, default=None, help="input name for input ids") - - parser.add_argument('--input_mask_name', - required=False, - type=str, - default=None, - help="input name for attention mask") - - parser.add_argument('--global_mask_name', - required=False, - type=str, - default=None, - help="input name for global attention mask") - - parser.add_argument('--samples', required=False, type=int, default=1, help="number of test cases to be generated") - - parser.add_argument('--seed', required=False, type=int, default=3, help="random seed") - - parser.add_argument('--verbose', required=False, action='store_true', help="print verbose information") + parser.add_argument("--model", required=True, type=str, help="bert onnx model path.") + + parser.add_argument( + "--output_dir", + required=False, + type=str, + default=None, + help="output test data path. If not specified, .", + ) + + parser.add_argument("--batch_size", required=False, type=int, default=1, help="batch size of input") + + parser.add_argument( + "--sequence_length", + required=False, + type=int, + default=128, + help="maximum sequence length of input", + ) + + parser.add_argument( + "--global_tokens", + required=False, + type=int, + default=10, + help="number of global tokens", + ) + + parser.add_argument( + "--input_ids_name", + required=False, + type=str, + default=None, + help="input name for input ids", + ) + + parser.add_argument( + "--input_mask_name", + required=False, + type=str, + default=None, + help="input name for attention mask", + ) + + parser.add_argument( + "--global_mask_name", + required=False, + type=str, + default=None, + help="input name for global attention mask", + ) + + parser.add_argument( + "--samples", + required=False, + type=int, + default=1, + help="number of test cases to be generated", + ) + + parser.add_argument("--seed", required=False, type=int, default=3, help="random seed") + + parser.add_argument( + "--verbose", + required=False, + action="store_true", + help="print verbose information", + ) parser.set_defaults(verbose=False) args = parser.parse_args() @@ -135,7 +167,7 @@ def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_t assert num_global_tokens <= sequence_length data = np.zeros((batch_size, sequence_length), dtype=np.int32) temp = np.ones((batch_size, num_global_tokens), dtype=np.int32) - data[:temp.shape[0], :temp.shape[1]] = temp + data[: temp.shape[0], : temp.shape[1]] = temp else: data = np.zeros((batch_size, sequence_length), dtype=np.int32) @@ -147,17 +179,19 @@ def fake_global_mask_data(global_mask, batch_size, sequence_length, num_global_t return data -def fake_test_data(batch_size, - sequence_length, - test_cases, - dictionary_size, - verbose, - random_seed, - input_ids, - input_mask, - global_mask, - num_global_tokens, - random_mask_length=False): +def fake_test_data( + batch_size, + sequence_length, + test_cases, + dictionary_size, + verbose, + random_seed, + input_ids, + input_mask, + global_mask, + num_global_tokens, + random_mask_length=False, +): """ Generate fake input data for test. """ @@ -175,8 +209,9 @@ def fake_test_data(batch_size, inputs[input_mask.name] = fake_input_mask_data(input_mask, batch_size, sequence_length, random_mask_length) if global_mask: - inputs[global_mask.name] = fake_global_mask_data(global_mask, batch_size, sequence_length, - num_global_tokens) + inputs[global_mask.name] = fake_global_mask_data( + global_mask, batch_size, sequence_length, num_global_tokens + ) if verbose and len(all_inputs) == 0: print("Example inputs", inputs) @@ -185,30 +220,63 @@ def fake_test_data(batch_size, return all_inputs -def generate_test_data(batch_size, - sequence_length, - test_cases, - seed, - verbose, - input_ids, - input_mask, - global_mask, - num_global_tokens, - random_mask_length=False): +def generate_test_data( + batch_size, + sequence_length, + test_cases, + seed, + verbose, + input_ids, + input_mask, + global_mask, + num_global_tokens, + random_mask_length=False, +): dictionary_size = 10000 - all_inputs = fake_test_data(batch_size, sequence_length, test_cases, dictionary_size, verbose, seed, input_ids, - input_mask, global_mask, num_global_tokens, random_mask_length) + all_inputs = fake_test_data( + batch_size, + sequence_length, + test_cases, + dictionary_size, + verbose, + seed, + input_ids, + input_mask, + global_mask, + num_global_tokens, + random_mask_length, + ) if len(all_inputs) != test_cases: print("Failed to create test data for test.") return all_inputs -def create_longformer_test_data(model, output_dir, batch_size, sequence_length, test_cases, seed, verbose, - input_ids_name, input_mask_name, global_mask_name, num_global_tokens): +def create_longformer_test_data( + model, + output_dir, + batch_size, + sequence_length, + test_cases, + seed, + verbose, + input_ids_name, + input_mask_name, + global_mask_name, + num_global_tokens, +): input_ids, input_mask, global_mask = get_longformer_inputs(model, input_ids_name, input_mask_name, global_mask_name) - all_inputs = generate_test_data(batch_size, sequence_length, test_cases, seed, verbose, input_ids, input_mask, - global_mask, num_global_tokens) + all_inputs = generate_test_data( + batch_size, + sequence_length, + test_cases, + seed, + verbose, + input_ids, + input_mask, + global_mask, + num_global_tokens, + ) for i, inputs in enumerate(all_inputs): output_test_data(output_dir, i, inputs) @@ -221,7 +289,9 @@ def main(): if output_dir is None: # Default output directory is a sub-directory under the directory of model. output_dir = os.path.join( - Path(args.model).parent, "b{}_s{}_g{}".format(args.batch_size, args.sequence_length, args.global_tokens)) + Path(args.model).parent, + "b{}_s{}_g{}".format(args.batch_size, args.sequence_length, args.global_tokens), + ) if output_dir is not None: # create the output directory if not existed @@ -230,9 +300,19 @@ def main(): else: print("Directory existed. test data files will be overwritten.") - create_longformer_test_data(args.model, output_dir, args.batch_size, args.sequence_length, args.samples, args.seed, - args.verbose, args.input_ids_name, args.input_mask_name, args.global_mask_name, - args.global_tokens) + create_longformer_test_data( + args.model, + output_dir, + args.batch_size, + args.sequence_length, + args.samples, + args.seed, + args.verbose, + args.input_ids_name, + args.input_mask_name, + args.global_mask_name, + args.global_tokens, + ) print("Test data is saved to directory:", output_dir) diff --git a/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py b/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py index 7a7ea17f544b5..6ee547a4754a6 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py +++ b/onnxruntime/python/tools/transformers/models/longformer/longformer_helper.py @@ -5,23 +5,24 @@ # -------------------------------------------------------------------------- # This script helps creating dummy inputs for Longformer model. -import os import logging -import torch -import onnx +import os import random -import numpy -import time import re +import time from pathlib import Path -from typing import List, Dict, Tuple, Union +from typing import Dict, List, Tuple, Union + +import numpy +import onnx +import torch logger = logging.getLogger(__name__) PRETRAINED_LONGFORMER_MODELS = { "longformer-base-4096": "allenai/longformer-base-4096", "longformer-large-4096": "allenai/longformer-large-4096", - "longformer-random-tiny": "patrickvonplaten/longformer-random-tiny" # A tiny model for debugging + "longformer-random-tiny": "patrickvonplaten/longformer-random-tiny", # A tiny model for debugging } @@ -46,23 +47,27 @@ def get_ort_inputs(self) -> Dict: class LongformerHelper: - """ A helper class for Longformer model conversion, inference and verification. - """ + """A helper class for Longformer model conversion, inference and verification.""" + @staticmethod - def get_dummy_inputs(batch_size: int, - sequence_length: int, - num_global_tokens: int, - device: torch.device, - vocab_size: int = 100) -> LongformerInputs: - """ Create random inputs for Longformer model. + def get_dummy_inputs( + batch_size: int, + sequence_length: int, + num_global_tokens: int, + device: torch.device, + vocab_size: int = 100, + ) -> LongformerInputs: + """Create random inputs for Longformer model. Returns torch tensors of input_ids, attention_mask and global_attention_mask tensors. """ - input_ids = torch.randint(low=0, - high=vocab_size - 1, - size=(batch_size, sequence_length), - dtype=torch.long, - device=device) + input_ids = torch.randint( + low=0, + high=vocab_size - 1, + size=(batch_size, sequence_length), + dtype=torch.long, + device=device, + ) attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=device) global_attention_mask = torch.zeros(input_ids.shape, dtype=torch.long, device=device) global_token_index = list(range(num_global_tokens)) @@ -71,6 +76,8 @@ def get_dummy_inputs(batch_size: int, @staticmethod def get_output_shapes(batch_size: int, sequence_length: int, hidden_size: int) -> Dict[str, List[int]]: - """ Returns a dictionary with output name as key, and shape as value. - """ - return {"last_state": [batch_size, sequence_length, hidden_size], "pooler": [batch_size, sequence_length]} + """Returns a dictionary with output name as key, and shape as value.""" + return { + "last_state": [batch_size, sequence_length, hidden_size], + "pooler": [batch_size, sequence_length], + } diff --git a/onnxruntime/python/tools/transformers/models/longformer/torch_extensions/setup.py b/onnxruntime/python/tools/transformers/models/longformer/torch_extensions/setup.py index 37004c29dd241..2be48dd369be3 100644 --- a/onnxruntime/python/tools/transformers/models/longformer/torch_extensions/setup.py +++ b/onnxruntime/python/tools/transformers/models/longformer/torch_extensions/setup.py @@ -1,11 +1,15 @@ from setuptools import setup from torch.utils.cpp_extension import BuildExtension, CppExtension -setup(name='longformer_attention', - ext_modules=[ - CppExtension(name='longformer_attention', - sources=['longformer_attention.cpp'], - include_dirs=[], - extra_compile_args=['-g']) - ], - cmdclass={'build_ext': BuildExtension}) +setup( + name="longformer_attention", + ext_modules=[ + CppExtension( + name="longformer_attention", + sources=["longformer_attention.cpp"], + include_dirs=[], + extra_compile_args=["-g"], + ) + ], + cmdclass={"build_ext": BuildExtension}, +) diff --git a/onnxruntime/python/tools/transformers/models/t5/__init__.py b/onnxruntime/python/tools/transformers/models/t5/__init__.py index 7c2a88f4d9554..cc667396a2622 100644 --- a/onnxruntime/python/tools/transformers/models/t5/__init__.py +++ b/onnxruntime/python/tools/transformers/models/t5/__init__.py @@ -1,4 +1,4 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- diff --git a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py index efce064ac5bc0..72be075089440 100644 --- a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py @@ -4,82 +4,101 @@ # license information. # -------------------------------------------------------------------------- -import os -import sys import argparse +import copy import logging +import os +import sys + import torch -import copy from t5_helper import PRETRAINED_T5_MODELS, T5Helper -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) -from benchmark_helper import setup_logger, prepare_environment, create_onnxruntime_session, Precision +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) +from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger -logger = logging.getLogger('') +logger = logging.getLogger("") def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('-m', - '--model_name_or_path', - required=False, - default=PRETRAINED_T5_MODELS[0], - type=str, - help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_T5_MODELS)) - - parser.add_argument('--cache_dir', - required=False, - type=str, - default=os.path.join('.', 'cache_models'), - help='Directory to cache pre-trained models') - - parser.add_argument('--output', - required=False, - type=str, - default=os.path.join('.', 'onnx_models'), - help='Output directory') - - parser.add_argument('-o', - '--optimize_onnx', - required=False, - action='store_true', - help='Use optimizer.py to optimize onnx model') + parser.add_argument( + "-m", + "--model_name_or_path", + required=False, + default=PRETRAINED_T5_MODELS[0], + type=str, + help="Model path, or pretrained model name in the list: " + ", ".join(PRETRAINED_T5_MODELS), + ) + + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=os.path.join(".", "cache_models"), + help="Directory to cache pre-trained models", + ) + + parser.add_argument( + "--output", + required=False, + type=str, + default=os.path.join(".", "onnx_models"), + help="Output directory", + ) + + parser.add_argument( + "-o", + "--optimize_onnx", + required=False, + action="store_true", + help="Use optimizer.py to optimize onnx model", + ) parser.set_defaults(optimize_onnx=False) - parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference") + parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference") parser.set_defaults(use_gpu=False) - parser.add_argument("-p", - "--precision", - required=False, - type=Precision, - default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16], - help="Precision of model to run. fp32 for full precision, fp16 for half precision") + parser.add_argument( + "-p", + "--precision", + required=False, + type=Precision, + default=Precision.FLOAT32, + choices=[Precision.FLOAT32, Precision.FLOAT16], + help="Precision of model to run. fp32 for full precision, fp16 for half precision", + ) - parser.add_argument('--verbose', required=False, action='store_true') + parser.add_argument("--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) - parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true') + parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true") parser.set_defaults(use_external_data_format=False) parser.add_argument( - '-s', - '--use_decoder_start_token', + "-s", + "--use_decoder_start_token", required=False, - action='store_true', - help="Use config.decoder_start_token_id in decoding. Otherwise, add an extra graph input for decoder_input_ids." + action="store_true", + help="Use config.decoder_start_token_id in decoding. Otherwise, add an extra graph input for decoder_input_ids.", ) parser.set_defaults(use_decoder_start_token=False) - parser.add_argument('-w', '--overwrite', required=False, action='store_true', help="overwrite existing ONNX model") + parser.add_argument( + "-w", + "--overwrite", + required=False, + action="store_true", + help="overwrite existing ONNX model", + ) parser.set_defaults(overwrite=False) - parser.add_argument('--disable_auto_mixed_precision', - required=False, - action='store_true', - help="use pure fp16 instead of mixed precision") + parser.add_argument( + "--disable_auto_mixed_precision", + required=False, + action="store_true", + help="use pure fp16 instead of mixed precision", + ) parser.set_defaults(disable_auto_mixed_precision=False) args = parser.parse_args() @@ -87,18 +106,20 @@ def parse_arguments(): return args -def export_onnx_models(model_name_or_path, - cache_dir, - output_dir, - use_gpu, - use_external_data_format, - optimize_onnx, - precision, - verbose, - use_decoder_start_token: bool = True, - merge_encoder_and_decoder_init: bool = True, - overwrite: bool = False, - disable_auto_mixed_precision: bool = False): +def export_onnx_models( + model_name_or_path, + cache_dir, + output_dir, + use_gpu, + use_external_data_format, + optimize_onnx, + precision, + verbose, + use_decoder_start_token: bool = True, + merge_encoder_and_decoder_init: bool = True, + overwrite: bool = False, + disable_auto_mixed_precision: bool = False, +): device = torch.device("cuda:0" if use_gpu else "cpu") models = T5Helper.load_model(model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init) @@ -112,40 +133,48 @@ def export_onnx_models(model_name_or_path, model.to(device) filename_suffix = "_" + name - onnx_path = T5Helper.get_onnx_path(output_dir, - model_name_or_path, - suffix=filename_suffix, - new_folder=use_external_data_format) + onnx_path = T5Helper.get_onnx_path( + output_dir, + model_name_or_path, + suffix=filename_suffix, + new_folder=use_external_data_format, + ) if overwrite or not os.path.exists(onnx_path): logger.info(f"Exporting ONNX model to {onnx_path}") # We have to clone model before exporting onnx, otherwise verify_onnx will report large difference. cloned_model = copy.deepcopy(model).to(device) - T5Helper.export_onnx(cloned_model, - device, - onnx_path, - verbose, - use_external_data_format, - use_decoder_input_ids=not use_decoder_start_token) + T5Helper.export_onnx( + cloned_model, + device, + onnx_path, + verbose, + use_external_data_format, + use_decoder_input_ids=not use_decoder_start_token, + ) else: logger.info(f"Skip exporting: existed ONNX model {onnx_path}") # Optimize ONNX graph. Note that we have not implemented graph optimization for T5 yet. if optimize_onnx or precision != Precision.FLOAT32: - output_path = T5Helper.get_onnx_path(output_dir, - model_name_or_path, - suffix=filename_suffix + "_" + str(precision), - new_folder=use_external_data_format) + output_path = T5Helper.get_onnx_path( + output_dir, + model_name_or_path, + suffix=filename_suffix + "_" + str(precision), + new_folder=use_external_data_format, + ) if overwrite or not os.path.exists(output_path): logger.info(f"Optimizing model to {output_path}") - T5Helper.optimize_onnx(onnx_path, - output_path, - precision == Precision.FLOAT16, - config.num_heads, - config.hidden_size, - use_external_data_format, - auto_mixed_precision=not disable_auto_mixed_precision) + T5Helper.optimize_onnx( + onnx_path, + output_path, + precision == Precision.FLOAT16, + config.num_heads, + config.hidden_size, + use_external_data_format, + auto_mixed_precision=not disable_auto_mixed_precision, + ) else: logger.info(f"Skip optimizing: existed ONNX model {onnx_path}") else: @@ -154,11 +183,12 @@ def export_onnx_models(model_name_or_path, ort_session = create_onnxruntime_session( output_path, use_gpu=use_gpu, - provider=['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider']) + provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"], + ) max_diff = T5Helper.verify_onnx(model, ort_session, device) - logger.info(f'PyTorch and OnnxRuntime results max difference = {max_diff}') + logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}") if max_diff > 1e-4: - logger.warn(f'PyTorch and OnnxRuntime results are NOT close') + logger.warn(f"PyTorch and OnnxRuntime results are NOT close") output_paths.append(output_path) @@ -182,17 +212,27 @@ def main(): assert args.use_gpu, "fp16 requires --use_gpu" if args.optimize_onnx: - logger.warn(f'Graph optimization for T5 is not implemented yet.') + logger.warn(f"Graph optimization for T5 is not implemented yet.") with torch.no_grad(): merge_encoder_and_decoder_init = True # Merge encoder and decoder initialization into one model is recommended. - output_paths = export_onnx_models(args.model_name_or_path, cache_dir, output_dir, args.use_gpu, - args.use_external_data_format, args.optimize_onnx, args.precision, - args.verbose, args.use_decoder_start_token, merge_encoder_and_decoder_init, - args.overwrite, args.disable_auto_mixed_precision) + output_paths = export_onnx_models( + args.model_name_or_path, + cache_dir, + output_dir, + args.use_gpu, + args.use_external_data_format, + args.optimize_onnx, + args.precision, + args.verbose, + args.use_decoder_start_token, + merge_encoder_and_decoder_init, + args.overwrite, + args.disable_auto_mixed_precision, + ) logger.info(f"Done! Outputs: {output_paths}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/onnxruntime/python/tools/transformers/models/t5/past_helper.py b/onnxruntime/python/tools/transformers/models/t5/past_helper.py index 0a9eb37be9443..fe113491067fd 100644 --- a/onnxruntime/python/tools/transformers/models/t5/past_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/past_helper.py @@ -10,30 +10,42 @@ class PastKeyValuesHelper: - """ Helper functions to process past key values for encoder-decoder model""" + """Helper functions to process past key values for encoder-decoder model""" + @staticmethod def get_past_names(num_layers, present: bool = False): past_self_names = [] past_cross_names = [] for i in range(num_layers): - past_self_names.extend([f'present_key_self_{i}', f'present_value_self_{i}'] - if present else [f'past_key_self_{i}', f'past_value_self_{i}']) - past_cross_names.extend([f'present_key_cross_{i}', f'present_value_cross_{i}'] - if present else [f'past_key_cross_{i}', f'past_value_cross_{i}']) + past_self_names.extend( + [f"present_key_self_{i}", f"present_value_self_{i}"] + if present + else [f"past_key_self_{i}", f"past_value_self_{i}"] + ) + past_cross_names.extend( + [f"present_key_cross_{i}", f"present_value_cross_{i}"] + if present + else [f"past_key_cross_{i}", f"past_value_cross_{i}"] + ) return past_self_names + past_cross_names @staticmethod def group_by_self_or_cross(present_key_values): """Split present state from grouped by layer to grouped by self/cross attention. - Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ... - After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...) - + Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ... + After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...) + """ present_self = [] present_cross = [] for i, present_layer_i in enumerate(present_key_values): assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}" - present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i + ( + present_key_self, + present_value_self, + present_key_cross, + present_value_cross, + ) = present_layer_i present_self.extend([present_key_self, present_value_self]) present_cross.extend([present_key_cross, present_value_cross]) return present_self, present_cross @@ -41,9 +53,16 @@ def group_by_self_or_cross(present_key_values): @staticmethod def group_by_layer(past, num_layers): """Reorder past state from grouped by self/cross attention to grouped by layer. - Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ... - After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), + Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ... + After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), """ assert len(past) == 4 * num_layers - return tuple([past[2 * i], past[2 * i + 1], past[2 * num_layers + 2 * i], past[2 * num_layers + 2 * i + 1]] - for i in range(num_layers)) + return tuple( + [ + past[2 * i], + past[2 * i + 1], + past[2 * num_layers + 2 * i], + past[2 * num_layers + 2 * i + 1], + ] + for i in range(num_layers) + ) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py index b4c44fc10e8a5..0f98b9f7ce1c3 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py @@ -4,53 +4,71 @@ # license information. # -------------------------------------------------------------------------- +import logging +import os +import sys from pathlib import Path from typing import List, Union -import sys -import os -import logging + import numpy import torch +from past_helper import PastKeyValuesHelper +from t5_encoder import T5EncoderInputs from transformers import T5Config + from onnxruntime import InferenceSession -from t5_encoder import T5EncoderInputs -from past_helper import PastKeyValuesHelper -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) -from torch_onnx_export_helper import torch_onnx_export +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from io_binding_helper import TypeHelper +from torch_onnx_export_helper import torch_onnx_export logger = logging.getLogger(__name__) class T5DecoderInit(torch.nn.Module): - """ A T5 decoder with LM head to create initial past key values. - This model is only called once during starting decoding. + """A T5 decoder with LM head to create initial past key values. + This model is only called once during starting decoding. """ - def __init__(self, - decoder: torch.nn.Module, - lm_head: torch.nn.Module, - config: T5Config, - decoder_start_token_id: int = None): + def __init__( + self, + decoder: torch.nn.Module, + lm_head: torch.nn.Module, + config: T5Config, + decoder_start_token_id: int = None, + ): super().__init__() self.decoder = decoder self.lm_head = lm_head self.config = config - self.decoder_start_token_id = decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id - - def forward(self, decoder_input_ids: torch.Tensor, encoder_attention_mask: torch.Tensor, - encoder_hidden_states: torch.FloatTensor): + self.decoder_start_token_id = ( + decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id + ) + + def forward( + self, + decoder_input_ids: torch.Tensor, + encoder_attention_mask: torch.Tensor, + encoder_hidden_states: torch.FloatTensor, + ): if decoder_input_ids is None: batch_size = encoder_attention_mask.shape[0] - decoder_input_ids = torch.ones( - (batch_size, 1), dtype=torch.long, device=encoder_attention_mask.device) * self.decoder_start_token_id + decoder_input_ids = ( + torch.ones( + (batch_size, 1), + dtype=torch.long, + device=encoder_attention_mask.device, + ) + * self.decoder_start_token_id + ) - decoder_outputs = self.decoder(input_ids=decoder_input_ids, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=True, - return_dict=True) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + return_dict=True, + ) sequence_output = decoder_outputs.last_hidden_state present_key_values = decoder_outputs.past_key_values @@ -63,7 +81,7 @@ def forward(self, decoder_input_ids: torch.Tensor, encoder_attention_mask: torch class T5Decoder(torch.nn.Module): - """ A T5 decoder with LM head and past key values""" + """A T5 decoder with LM head and past key values""" def __init__(self, decoder, lm_head, config): super().__init__() @@ -75,12 +93,14 @@ def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_stat past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers) - decoder_outputs = self.decoder(input_ids=decoder_input_ids, - past_key_values=past_key_values, - encoder_hidden_states=encoder_hidden_states, - encoder_attention_mask=encoder_attention_mask, - use_cache=True, - return_dict=True) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + past_key_values=past_key_values, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + use_cache=True, + return_dict=True, + ) sequence_output = decoder_outputs.last_hidden_state present_key_values = decoder_outputs.past_key_values @@ -95,21 +115,28 @@ def forward(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_stat class T5DecoderInputs: - - def __init__(self, decoder_input_ids, encoder_attention_mask, encoder_hidden_states, past_key_values=None): + def __init__( + self, + decoder_input_ids, + encoder_attention_mask, + encoder_hidden_states, + past_key_values=None, + ): self.decoder_input_ids: torch.LongTensor = decoder_input_ids self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask self.encoder_hidden_states: Union[torch.FloatTensor, torch.HalfTensor] = encoder_hidden_states self.past_key_values: Union[List[torch.FloatTensor], List[torch.HalfTensor], None] = past_key_values @staticmethod - def create_dummy(config: T5Config, - batch_size: int, - encode_sequence_length: int, - past_decode_sequence_length: int, - device: torch.device, - float16: bool = False): # -> T5DecoderInputs: - """ Create dummy inputs for T5Decoder. + def create_dummy( + config: T5Config, + batch_size: int, + encode_sequence_length: int, + past_decode_sequence_length: int, + device: torch.device, + float16: bool = False, + ): # -> T5DecoderInputs: + """Create dummy inputs for T5Decoder. Args: decoder: decoder @@ -128,29 +155,37 @@ def create_dummy(config: T5Config, vocab_size: int = config.vocab_size sequence_length: int = 1 # fixed for decoding - decoder_input_ids = torch.randint(low=0, - high=vocab_size - 1, - size=(batch_size, sequence_length), - dtype=torch.int64, - device=device) + decoder_input_ids = torch.randint( + low=0, + high=vocab_size - 1, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=device, + ) encoder_inputs = T5EncoderInputs.create_dummy(batch_size, encode_sequence_length, vocab_size, device) float_type = torch.float16 if float16 else torch.float32 - encoder_hidden_state = torch.rand(batch_size, - encode_sequence_length, - hidden_size, - dtype=float_type, - device=device) + encoder_hidden_state = torch.rand( + batch_size, + encode_sequence_length, + hidden_size, + dtype=float_type, + device=device, + ) if past_decode_sequence_length > 0: self_attention_past_shape = [ - batch_size, num_attention_heads, past_decode_sequence_length, - int(hidden_size / num_attention_heads) + batch_size, + num_attention_heads, + past_decode_sequence_length, + int(hidden_size / num_attention_heads), ] cross_attention_past_shape = [ - batch_size, num_attention_heads, encode_sequence_length, - int(hidden_size / num_attention_heads) + batch_size, + num_attention_heads, + encode_sequence_length, + int(hidden_size / num_attention_heads), ] past = [] @@ -165,7 +200,11 @@ def create_dummy(config: T5Config, return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, encoder_hidden_state, past) def to_list(self) -> List: - input_list = [self.decoder_input_ids, self.encoder_attention_mask, self.encoder_hidden_states] + input_list = [ + self.decoder_input_ids, + self.encoder_attention_mask, + self.encoder_hidden_states, + ] if self.past_key_values: input_list.extend(self.past_key_values) return input_list @@ -173,18 +212,23 @@ def to_list(self) -> List: def to_fp32(self): encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32) past = [p.to(dtype=torch.float32) for p in self.past_key_values] - return T5DecoderInputs(self.decoder_input_ids.clone(), self.encoder_attention_mask.clone(), - encoder_hidden_state, past) + return T5DecoderInputs( + self.decoder_input_ids.clone(), + self.encoder_attention_mask.clone(), + encoder_hidden_state, + past, + ) class T5DecoderHelper: - @staticmethod - def export_onnx(decoder: Union[T5Decoder, T5DecoderInit], - device: torch.device, - onnx_model_path: str, - verbose: bool = True, - use_external_data_format: bool = False): + def export_onnx( + decoder: Union[T5Decoder, T5DecoderInit], + device: torch.device, + onnx_model_path: str, + verbose: bool = True, + use_external_data_format: bool = False, + ): """Export decoder to ONNX Args: @@ -196,18 +240,20 @@ def export_onnx(decoder: Union[T5Decoder, T5DecoderInit], """ assert isinstance(decoder, (T5Decoder, T5DecoderInit)) - inputs = T5DecoderInputs.create_dummy(decoder.config, - batch_size=2, - encode_sequence_length=3, - past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0, - device=device) + inputs = T5DecoderInputs.create_dummy( + decoder.config, + batch_size=2, + encode_sequence_length=3, + past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0, + device=device, + ) input_list = inputs.to_list() with torch.no_grad(): outputs = decoder(*input_list) past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False) present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True) - present_self_names = present_names[:2 * decoder.config.num_layers] + present_self_names = present_names[: 2 * decoder.config.num_layers] input_past_names = past_names if isinstance(decoder, T5Decoder) else [] output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names @@ -231,65 +277,63 @@ def export_onnx(decoder: Union[T5Decoder, T5DecoderInit], input_names.extend(input_past_names) dynamic_axes = { - 'input_ids': { - 0: 'batch_size', - #1: 'sequence_length' - }, - 'encoder_attention_mask': { - 0: 'batch_size', - 1: 'encode_sequence_length' - }, - 'encoder_hidden_states': { - 0: 'batch_size', - 1: 'encode_sequence_length' + "input_ids": { + 0: "batch_size", + # 1: 'sequence_length' }, + "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"}, "logits": { - 0: 'batch_size', - #1: 'sequence_length' - } + 0: "batch_size", + # 1: 'sequence_length' + }, } for name in input_past_names: dynamic_axes[name] = { - 0: 'batch_size', - 2: 'past_decode_sequence_length' if "self" in name else "encode_sequence_length" + 0: "batch_size", + 2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length", } for name in output_present_names: if "cross" in name: - dynamic_axes[name] = {0: 'batch_size', 2: "encode_sequence_length"} + dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"} else: # self attention past state if isinstance(decoder, T5Decoder): - dynamic_axes[name] = {0: 'batch_size', 2: 'past_decode_sequence_length + 1'} + dynamic_axes[name] = { + 0: "batch_size", + 2: "past_decode_sequence_length + 1", + } else: dynamic_axes[name] = { - 0: 'batch_size', - #2: 'sequence_length' + 0: "batch_size", + # 2: 'sequence_length' } Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export(decoder, - args=tuple(input_list), - f=onnx_model_path, - export_params=True, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=12, - do_constant_folding=True, - use_external_data_format=use_external_data_format, - verbose=verbose) + torch_onnx_export( + decoder, + args=tuple(input_list), + f=onnx_model_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=12, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + verbose=verbose, + ) @staticmethod def onnxruntime_inference(ort_session, inputs: T5DecoderInputs): - """ Run inference of ONNX model. - """ + """Run inference of ONNX model.""" logger.debug(f"start onnxruntime_inference") ort_inputs = { - 'input_ids': numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()), - 'encoder_attention_mask': numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()), - 'encoder_hidden_states': numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()) + "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()), + "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()), + "encoder_hidden_states": numpy.ascontiguousarray(inputs.encoder_hidden_states.cpu().numpy()), } if inputs.past_key_values: @@ -303,26 +347,33 @@ def onnxruntime_inference(ort_session, inputs: T5DecoderInputs): return ort_outputs @staticmethod - def verify_onnx(model: Union[T5Decoder, T5DecoderInit], - ort_session: InferenceSession, - device: torch.device, - max_cases=4): - """ Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good. - """ - float16: bool = (TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)") + def verify_onnx( + model: Union[T5Decoder, T5DecoderInit], + ort_session: InferenceSession, + device: torch.device, + max_cases=4, + ): + """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" + float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)" test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)] test_cases_max_diff = [] - for (batch_size, encode_sequence_length, past_decode_sequence_length) in test_cases[:max_cases]: + for ( + batch_size, + encode_sequence_length, + past_decode_sequence_length, + ) in test_cases[:max_cases]: if isinstance(model, T5DecoderInit): past_decode_sequence_length = 0 - inputs = T5DecoderInputs.create_dummy(model.config, - batch_size, - encode_sequence_length, - past_decode_sequence_length, - device=device, - float16=float16) + inputs = T5DecoderInputs.create_dummy( + model.config, + batch_size, + encode_sequence_length, + past_decode_sequence_length, + device=device, + float16=float16, + ) # We use fp32 PyTroch model as baseline even when ONNX model is fp16 input_list = inputs.to_fp32().to_list() @@ -345,7 +396,8 @@ def verify_onnx(model: Union[T5Decoder, T5DecoderInit], if isinstance(model, T5DecoderInit): for i in range(2 * model.config.num_layers): max_diff = numpy.amax( - numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])) + numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i]) + ) logger.debug(f"cross attention past state {i} max_diff={max_diff}") max_diff_all = max(max_diff_all, max_diff) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py index d7150c4046b09..72a6b8585afb3 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py @@ -4,25 +4,28 @@ # license information. # -------------------------------------------------------------------------- +import logging +import os import random import sys -import os from pathlib import Path from typing import List -import logging + import numpy import torch from transformers import T5Config + from onnxruntime import InferenceSession -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from torch_onnx_export_helper import torch_onnx_export logger = logging.getLogger(__name__) class T5Encoder(torch.nn.Module): - """ T5 encoder outputs only the last hidden state""" + """T5 encoder outputs only the last hidden state""" + def __init__(self, encoder, config: T5Config): super().__init__() self.encoder = encoder @@ -38,9 +41,10 @@ def __init__(self, input_ids, attention_mask): self.attention_mask: torch.LongTensor = attention_mask @staticmethod - def create_dummy(batch_size: int, sequence_length: int, vocab_size: int, - device: torch.device): # -> T5EncoderInputs - """ Create dummy inputs for T5 encoder. + def create_dummy( + batch_size: int, sequence_length: int, vocab_size: int, device: torch.device + ): # -> T5EncoderInputs + """Create dummy inputs for T5 encoder. Args: batch_size (int): batch size @@ -51,11 +55,13 @@ def create_dummy(batch_size: int, sequence_length: int, vocab_size: int, Returns: T5EncoderInputs: dummy inputs for encoder """ - input_ids = torch.randint(low=0, - high=vocab_size - 1, - size=(batch_size, sequence_length), - dtype=torch.int64, - device=device) + input_ids = torch.randint( + low=0, + high=vocab_size - 1, + size=(batch_size, sequence_length), + dtype=torch.int64, + device=device, + ) attention_mask = torch.ones([batch_size, sequence_length], dtype=torch.int64, device=device) if sequence_length >= 2: @@ -71,11 +77,13 @@ def to_list(self) -> List: class T5EncoderHelper: @staticmethod - def export_onnx(encoder: T5Encoder, - device: torch.device, - onnx_model_path: str, - verbose: bool = True, - use_external_data_format: bool = False): + def export_onnx( + encoder: T5Encoder, + device: torch.device, + onnx_model_path: str, + verbose: bool = True, + use_external_data_format: bool = False, + ): """Export encoder to ONNX Args: @@ -86,59 +94,51 @@ def export_onnx(encoder: T5Encoder, use_external_data_format (bool, optional): use external data format or not. Defaults to False. """ config = encoder.config - encoder_inputs = T5EncoderInputs.create_dummy(batch_size=2, - sequence_length=4, - vocab_size=config.vocab_size, - device=device) + encoder_inputs = T5EncoderInputs.create_dummy( + batch_size=2, sequence_length=4, vocab_size=config.vocab_size, device=device + ) with torch.no_grad(): outputs = encoder(encoder_inputs.input_ids, encoder_inputs.attention_mask) Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export(encoder, - args=tuple(encoder_inputs.to_list()), - f=onnx_model_path, - export_params=True, - input_names=['input_ids', 'attention_mask'], - output_names=['hidden_states'], - dynamic_axes={ - 'input_ids': { - 0: 'batch_size', - 1: 'sequence_length' - }, - 'attention_mask': { - 0: 'batch_size', - 1: 'sequence_length' - }, - 'hidden_states': { - 0: 'batch_size', - 1: 'sequence_length' - }, - }, - opset_version=12, - do_constant_folding=True, - use_external_data_format=use_external_data_format, - verbose=verbose) + torch_onnx_export( + encoder, + args=tuple(encoder_inputs.to_list()), + f=onnx_model_path, + export_params=True, + input_names=["input_ids", "attention_mask"], + output_names=["hidden_states"], + dynamic_axes={ + "input_ids": {0: "batch_size", 1: "sequence_length"}, + "attention_mask": {0: "batch_size", 1: "sequence_length"}, + "hidden_states": {0: "batch_size", 1: "sequence_length"}, + }, + opset_version=12, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + verbose=verbose, + ) @staticmethod def onnxruntime_inference(ort_session, inputs: T5EncoderInputs): - """ Run inference of ONNX model. - """ + """Run inference of ONNX model.""" ort_inputs = { - 'input_ids': numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()), - 'attention_mask': numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()) + "input_ids": numpy.ascontiguousarray(inputs.input_ids.cpu().numpy()), + "attention_mask": numpy.ascontiguousarray(inputs.attention_mask.cpu().numpy()), } return ort_session.run(None, ort_inputs) @staticmethod def verify_onnx(model: T5Encoder, ort_session: InferenceSession, device: torch.device): - """ Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good. - """ - inputs = T5EncoderInputs.create_dummy(batch_size=4, - sequence_length=11, - vocab_size=model.config.vocab_size, - device=device) + """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" + inputs = T5EncoderInputs.create_dummy( + batch_size=4, + sequence_length=11, + vocab_size=model.config.vocab_size, + device=device, + ) input_list = inputs.to_list() torch_outputs = model(*input_list) @@ -146,6 +146,6 @@ def verify_onnx(model: T5Encoder, ort_session: InferenceSession, device: torch.d max_diff = numpy.amax(numpy.abs(torch_outputs.cpu().numpy() - ort_outputs[0])) - logger.info(f'max_diff={max_diff}') + logger.info(f"max_diff={max_diff}") return max_diff diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py index f67c76fbd8463..fe7552809f4b9 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py @@ -4,46 +4,53 @@ # license information. # -------------------------------------------------------------------------- +import logging +import os +import sys from pathlib import Path from typing import List -import sys -import os -import logging + import numpy import torch +from past_helper import PastKeyValuesHelper +from t5_decoder import T5DecoderInit +from t5_encoder import T5Encoder, T5EncoderInputs from transformers import T5Config + from onnxruntime import InferenceSession -from t5_encoder import T5Encoder, T5EncoderInputs -from t5_decoder import T5DecoderInit -from past_helper import PastKeyValuesHelper -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from torch_onnx_export_helper import torch_onnx_export logger = logging.getLogger(__name__) class T5EncoderDecoderInit(torch.nn.Module): - """ A combination of T5Encoder and T5DecoderInit. - """ - def __init__(self, - encoder: torch.nn.Module, - decoder: torch.nn.Module, - lm_head: torch.nn.Module, - config: T5Config, - decoder_start_token_id: int = None): + """A combination of T5Encoder and T5DecoderInit.""" + + def __init__( + self, + encoder: torch.nn.Module, + decoder: torch.nn.Module, + lm_head: torch.nn.Module, + config: T5Config, + decoder_start_token_id: int = None, + ): super().__init__() self.config = config self.t5_encoder = T5Encoder(encoder, config) self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id) - def forward(self, - encoder_input_ids: torch.Tensor, - encoder_attention_mask: torch.Tensor, - decoder_input_ids: torch.Tensor = None): + def forward( + self, + encoder_input_ids: torch.Tensor, + encoder_attention_mask: torch.Tensor, + decoder_input_ids: torch.Tensor = None, + ): encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask) - lm_logits, past_self, past_cross = self.t5_decoder_init(decoder_input_ids, encoder_attention_mask, - encoder_hidden_states) + lm_logits, past_self, past_cross = self.t5_decoder_init( + decoder_input_ids, encoder_attention_mask, encoder_hidden_states + ) return lm_logits, encoder_hidden_states, past_self, past_cross @@ -54,14 +61,21 @@ def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids= self.decoder_input_ids: torch.LongTensor = decoder_input_ids @staticmethod - def create_dummy(config: T5Config, batch_size: int, encode_sequence_length: int, use_decoder_input_ids: int, - device: torch.device): # -> T5EncoderDecoderInitInputs: - encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(batch_size, encode_sequence_length, - config.vocab_size, device) + def create_dummy( + config: T5Config, + batch_size: int, + encode_sequence_length: int, + use_decoder_input_ids: int, + device: torch.device, + ): # -> T5EncoderDecoderInitInputs: + encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy( + batch_size, encode_sequence_length, config.vocab_size, device + ) decoder_input_ids = None if use_decoder_input_ids: - decoder_input_ids = torch.ones( - (batch_size, 1), dtype=torch.long, device=device) * config.decoder_start_token_id + decoder_input_ids = ( + torch.ones((batch_size, 1), dtype=torch.long, device=device) * config.decoder_start_token_id + ) return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids) @@ -74,12 +88,14 @@ def to_list(self) -> List: class T5EncoderDecoderInitHelper: @staticmethod - def export_onnx(model: T5EncoderDecoderInit, - device: torch.device, - onnx_model_path: str, - use_decoder_input_ids: bool = True, - verbose: bool = True, - use_external_data_format: bool = False): + def export_onnx( + model: T5EncoderDecoderInit, + device: torch.device, + onnx_model_path: str, + use_decoder_input_ids: bool = True, + verbose: bool = True, + use_external_data_format: bool = False, + ): """Export decoder to ONNX Args: @@ -91,11 +107,13 @@ def export_onnx(model: T5EncoderDecoderInit, """ assert isinstance(model, T5EncoderDecoderInit) - inputs = T5EncoderDecoderInitInputs.create_dummy(model.config, - batch_size=2, - encode_sequence_length=3, - use_decoder_input_ids=use_decoder_input_ids, - device=device) + inputs = T5EncoderDecoderInitInputs.create_dummy( + model.config, + batch_size=2, + encode_sequence_length=3, + use_decoder_input_ids=use_decoder_input_ids, + device=device, + ) input_list = inputs.to_list() outputs = model(*input_list) @@ -118,86 +136,94 @@ def export_onnx(model: T5EncoderDecoderInit, input_names = ["encoder_input_ids", "encoder_attention_mask"] # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2'. Use more friendly string here. - sequence_length = '1' + sequence_length = "1" num_heads = str(model.config.num_heads) hidden_size = str(model.config.d_model) head_size = str(model.config.d_model // model.config.num_heads) dynamic_axes = { - 'encoder_input_ids': { - 0: 'batch_size', - 1: 'encode_sequence_length' + "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"}, + "encoder_hidden_states": { + 0: "batch_size", + 1: "encode_sequence_length", + 2: hidden_size, }, - 'encoder_attention_mask': { - 0: 'batch_size', - 1: 'encode_sequence_length' - }, - 'encoder_hidden_states': { - 0: 'batch_size', - 1: 'encode_sequence_length', - 2: hidden_size - }, - "logits": { - 0: 'batch_size', - 1: sequence_length - } + "logits": {0: "batch_size", 1: sequence_length}, } if use_decoder_input_ids: input_names.append("decoder_input_ids") - dynamic_axes["decoder_input_ids"] = {0: 'batch_size', 1: sequence_length} + dynamic_axes["decoder_input_ids"] = {0: "batch_size", 1: sequence_length} for name in present_names: if "cross" in name: - dynamic_axes[name] = {0: 'batch_size', 1: num_heads, 2: 'encode_sequence_length', 3: head_size} + dynamic_axes[name] = { + 0: "batch_size", + 1: num_heads, + 2: "encode_sequence_length", + 3: head_size, + } else: # self attention past state - dynamic_axes[name] = {0: 'batch_size', 1: num_heads, 2: sequence_length, 3: head_size} + dynamic_axes[name] = { + 0: "batch_size", + 1: num_heads, + 2: sequence_length, + 3: head_size, + } Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export(model, - args=tuple(input_list), - f=onnx_model_path, - export_params=True, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=12, - do_constant_folding=True, - use_external_data_format=use_external_data_format, - verbose=verbose) + torch_onnx_export( + model, + args=tuple(input_list), + f=onnx_model_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=12, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + verbose=verbose, + ) @staticmethod def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs): - """ Run inference of ONNX model. - """ + """Run inference of ONNX model.""" logger.debug(f"start onnxruntime_inference") ort_inputs = { - 'encoder_input_ids': numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()), - 'encoder_attention_mask': numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()), + "encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()), + "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()), } if inputs.decoder_input_ids is not None: - ort_inputs['decoder_input_ids'] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()) + ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()) ort_outputs = ort_session.run(None, ort_inputs) return ort_outputs @staticmethod - def verify_onnx(model: T5EncoderDecoderInit, ort_session: InferenceSession, device: torch.device, max_cases=4): - """ Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good. - """ + def verify_onnx( + model: T5EncoderDecoderInit, + ort_session: InferenceSession, + device: torch.device, + max_cases=4, + ): + """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" ort_inputs = ort_session.get_inputs() use_decoder_input_ids = len(ort_inputs) == 3 test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)] test_cases_max_diff = [] for (batch_size, encode_sequence_length) in test_cases[:max_cases]: - inputs = T5EncoderDecoderInitInputs.create_dummy(model.config, - batch_size, - encode_sequence_length, - use_decoder_input_ids=use_decoder_input_ids, - device=device) + inputs = T5EncoderDecoderInitInputs.create_dummy( + model.config, + batch_size, + encode_sequence_length, + use_decoder_input_ids=use_decoder_input_ids, + device=device, + ) ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs) @@ -205,12 +231,12 @@ def verify_onnx(model: T5EncoderDecoderInit, ort_session: InferenceSession, devi input_list = inputs.to_list() torch_outputs = model(*input_list) - assert (torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape) + assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0])) logger.debug(f"logits max_diff={max_diff}") max_diff_all = max_diff - assert (torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape) + assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1])) logger.debug(f"encoder_hidden_states max_diff={max_diff}") max_diff_all = max(max_diff_all, max_diff) @@ -221,12 +247,14 @@ def verify_onnx(model: T5EncoderDecoderInit, ort_session: InferenceSession, devi for i in range(2 * model.config.num_layers): max_diff = numpy.amax( - numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * model.config.num_layers + i])) + numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * model.config.num_layers + i]) + ) logger.debug(f"cross attention past state {i} max_diff={max_diff}") max_diff_all = max(max_diff_all, max_diff) test_cases_max_diff.append(max_diff_all) logger.info( - f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}") + f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}" + ) return max(test_cases_max_diff) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py index 3d4d4ab4483df..389e6b33a311b 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -4,22 +4,24 @@ # license information. # -------------------------------------------------------------------------- +import logging import os import sys from pathlib import Path -from typing import Union, Dict, List -import logging +from typing import Dict, List, Union + import torch -from transformers import T5ForConditionalGeneration -from onnxruntime import InferenceSession +from t5_decoder import T5Decoder, T5DecoderHelper, T5DecoderInit from t5_encoder import T5Encoder, T5EncoderHelper -from t5_decoder import T5DecoderInit, T5Decoder, T5DecoderHelper from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper +from transformers import T5ForConditionalGeneration -sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..')) -from onnx_model import OnnxModel +from onnxruntime import InferenceSession + +sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) from float16 import float_to_float16_max_diff from fusion_utils import FusionUtils +from onnx_model import OnnxModel from optimizer import optimize_model logger = logging.getLogger(__name__) @@ -28,9 +30,13 @@ class T5Helper: - @staticmethod - def get_onnx_path(output_dir: str, model_name_or_path: str, suffix: str = "", new_folder: bool = False) -> str: + def get_onnx_path( + output_dir: str, + model_name_or_path: str, + suffix: str = "", + new_folder: bool = False, + ) -> str: """Build onnx path Args: @@ -46,7 +52,7 @@ def get_onnx_path(output_dir: str, model_name_or_path: str, suffix: str = "", ne if os.path.isdir(model_name_or_path): model_name = Path(model_name_or_path).parts[-1] else: - model_name.split('/')[-1] + model_name.split("/")[-1] model_name += suffix @@ -54,10 +60,12 @@ def get_onnx_path(output_dir: str, model_name_or_path: str, suffix: str = "", ne return os.path.join(dir, model_name + ".onnx") @staticmethod - def load_model(model_name_or_path: str, - cache_dir: str, - device: torch.device, - merge_encoder_and_decoder_init: bool = True) -> Dict[str, torch.nn.Module]: + def load_model( + model_name_or_path: str, + cache_dir: str, + device: torch.device, + merge_encoder_and_decoder_init: bool = True, + ) -> Dict[str, torch.nn.Module]: """Load model given a pretrained name or path, then build models for ONNX conversion. Args: @@ -75,38 +83,62 @@ def load_model(model_name_or_path: str, decoder.eval().to(device) if merge_encoder_and_decoder_init: - encoder_decoder_init = T5EncoderDecoderInit(model.encoder, - model.decoder, - model.lm_head, - model.config, - decoder_start_token_id=None) + encoder_decoder_init = T5EncoderDecoderInit( + model.encoder, + model.decoder, + model.lm_head, + model.config, + decoder_start_token_id=None, + ) return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder} else: encoder = T5Encoder(model.encoder, model.config) encoder.eval().to(device) decoder_init = T5DecoderInit(model.decoder, model.lm_head, model.config) decoder_init.eval().to(device) - return {"encoder": encoder, "decoder": decoder, "decoder_init": decoder_init} + return { + "encoder": encoder, + "decoder": decoder, + "decoder_init": decoder_init, + } @staticmethod - def export_onnx(model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], - device: torch.device, - onnx_model_path: str, - verbose: bool = True, - use_external_data_format: bool = False, - use_decoder_input_ids: bool = True): + def export_onnx( + model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], + device: torch.device, + onnx_model_path: str, + verbose: bool = True, + use_external_data_format: bool = False, + use_decoder_input_ids: bool = True, + ): if isinstance(model, T5Encoder): T5EncoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format) elif isinstance(model, T5EncoderDecoderInit): - T5EncoderDecoderInitHelper.export_onnx(model, device, onnx_model_path, use_decoder_input_ids, verbose, - use_external_data_format) + T5EncoderDecoderInitHelper.export_onnx( + model, + device, + onnx_model_path, + use_decoder_input_ids, + verbose, + use_external_data_format, + ) else: T5DecoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format) @staticmethod def auto_mixed_precision( - onnx_model: OnnxModel, - op_block_list: List[str] = ["Pow", "ReduceMean", "Add", "Sqrt", "Div", "Mul", "Softmax", "Relu"]): + onnx_model: OnnxModel, + op_block_list: List[str] = [ + "Pow", + "ReduceMean", + "Add", + "Sqrt", + "Div", + "Mul", + "Softmax", + "Relu", + ], + ): """Convert model to mixed precision. It detects whether original model has fp16 precision weights, and set parameters for float16 conversion automatically. Args: @@ -142,7 +174,7 @@ def auto_mixed_precision( # we can deduce that the weights are stored in float16 precision. max_diff = float_to_float16_max_diff(initializer) logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}") - is_weight_fp16_precision = (max_diff < 1E-6) + is_weight_fp16_precision = max_diff < 1e-6 else: logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}") @@ -157,7 +189,7 @@ def auto_mixed_precision( "keep_io_types": keep_io_types, "op_block_list": op_block_list, "node_block_list": node_block_list, - "force_fp16_initializers": is_weight_fp16_precision + "force_fp16_initializers": is_weight_fp16_precision, } logger.info(f"auto_mixed_precision parameters: {parameters}") @@ -170,22 +202,25 @@ def auto_mixed_precision( return parameters @staticmethod - def optimize_onnx(onnx_model_path: str, - optimized_model_path: str, - is_float16: bool, - num_attention_heads: int, - hidden_size: int, - use_external_data_format: bool = False, - auto_mixed_precision: bool = True): - """ Optimize ONNX model with an option to convert it to use mixed precision. - """ - m = optimize_model(onnx_model_path, - model_type='bert', - num_heads=num_attention_heads, - hidden_size=hidden_size, - opt_level=0, - optimization_options=None, - use_gpu=False) + def optimize_onnx( + onnx_model_path: str, + optimized_model_path: str, + is_float16: bool, + num_attention_heads: int, + hidden_size: int, + use_external_data_format: bool = False, + auto_mixed_precision: bool = True, + ): + """Optimize ONNX model with an option to convert it to use mixed precision.""" + m = optimize_model( + onnx_model_path, + model_type="bert", + num_heads=num_attention_heads, + hidden_size=hidden_size, + opt_level=0, + optimization_options=None, + use_gpu=False, + ) if is_float16: if auto_mixed_precision: T5Helper.auto_mixed_precision(m) @@ -195,10 +230,12 @@ def optimize_onnx(onnx_model_path: str, m.save_model_to_file(optimized_model_path, use_external_data_format) @staticmethod - def verify_onnx(model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], - ort_session: InferenceSession, device: torch.device): - """ Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good. - """ + def verify_onnx( + model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], + ort_session: InferenceSession, + device: torch.device, + ): + """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" if isinstance(model, T5Encoder): return T5EncoderHelper.verify_onnx(model, ort_session, device) elif isinstance(model, T5EncoderDecoderInit): diff --git a/onnxruntime/python/tools/transformers/onnx_exporter.py b/onnxruntime/python/tools/transformers/onnx_exporter.py index 41aaea056da6c..50673716a7213 100644 --- a/onnxruntime/python/tools/transformers/onnx_exporter.py +++ b/onnxruntime/python/tools/transformers/onnx_exporter.py @@ -5,23 +5,23 @@ # -------------------------------------------------------------------------- import logging -import numpy import os -import torch +import sys from pathlib import Path -from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig + +import numpy +import torch from affinity_helper import AffinitySetting -from benchmark_helper import create_onnxruntime_session, Precision, OptimizerInfo -from quantize_helper import QuantizeHelper +from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_session from huggingface_models import MODEL_CLASSES +from quantize_helper import QuantizeHelper from torch_onnx_export_helper import torch_onnx_export +from transformers import AutoConfig, AutoTokenizer, LxmertConfig, TransfoXLConfig -import sys - -sys.path.append(os.path.join(os.path.dirname(__file__), 'models', 'gpt2')) -from gpt2_helper import GPT2ModelNoPastState, PRETRAINED_GPT2_MODELS, TFGPT2ModelNoPastState +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) +from gpt2_helper import PRETRAINED_GPT2_MODELS, GPT2ModelNoPastState, TFGPT2ModelNoPastState -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" logger = logging.getLogger(__name__) @@ -36,7 +36,7 @@ def triu_onnx(x, diagonal=0, out=None): torch_triu = torch_func["triu"] template = torch_triu(torch.ones((1024, 1024), dtype=torch.uint8), diagonal) - mask = template[:x.size(0), :x.size(1)] + mask = template[: x.size(0), : x.size(1)] return torch.where(mask.bool(), x, torch.zeros_like(x)) @@ -51,25 +51,26 @@ def restore_torch_functions(): def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64): input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type) - inputs = {'input_ids': input_ids} + inputs = {"input_ids": input_ids} if "attention_mask" in input_names: attention_mask = numpy.ones([batch_size, sequence_length], dtype=data_type) - inputs['attention_mask'] = attention_mask + inputs["attention_mask"] = attention_mask if "token_type_ids" in input_names: segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type) - inputs['token_type_ids'] = segment_ids + inputs["token_type_ids"] = segment_ids if config.is_encoder_decoder: - inputs['decoder_input_ids'] = input_ids + inputs["decoder_input_ids"] = input_ids if isinstance(config, LxmertConfig): inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32) inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32) if isinstance(config, TransfoXLConfig): - inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros([config.hidden_size], - dtype=numpy.float32) + inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros( + [config.hidden_size], dtype=numpy.float32 + ) return inputs @@ -94,19 +95,26 @@ def update_flatten_list(inputs, res_list): def build_dynamic_axes(example_inputs, outputs_flatten): sequence_length = example_inputs["input_ids"].shape[-1] - dynamic_axes = {key: {0: 'batch_size', 1: 'seq_len'} for key in example_inputs.keys()} + dynamic_axes = {key: {0: "batch_size", 1: "seq_len"} for key in example_inputs.keys()} - output_names = ['output_' + str(i + 1) for i in range(len(outputs_flatten))] + output_names = ["output_" + str(i + 1) for i in range(len(outputs_flatten))] for i, output_name in enumerate(output_names): - dynamic_axes[output_name] = {0: 'batch_size'} + dynamic_axes[output_name] = {0: "batch_size"} dims = outputs_flatten[i].shape for j, dim in enumerate(dims): if dim == sequence_length: - dynamic_axes[output_name].update({j: 'seq_len'}) + dynamic_axes[output_name].update({j: "seq_len"}) return dynamic_axes, output_names -def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, fp16, output_names=None): +def validate_onnx_model( + onnx_model_path, + example_inputs, + example_outputs_flatten, + use_gpu, + fp16, + output_names=None, +): test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False) if test_session is None: logger.error(f"{onnx_model_path} is an invalid ONNX model") @@ -119,7 +127,8 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten example_ort_outputs = test_session.run(output_names, example_ort_inputs) if len(example_outputs_flatten) != len(example_ort_outputs): logger.error( - f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}") + f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}" + ) return False for i in range(len(example_outputs_flatten)): @@ -129,7 +138,12 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten rtol = 5e-02 if fp16 else 1e-4 atol = 1e-01 if fp16 else 1e-4 - if not numpy.allclose(example_ort_outputs[i], example_outputs_flatten[i].cpu().numpy(), rtol=rtol, atol=atol): + if not numpy.allclose( + example_ort_outputs[i], + example_outputs_flatten[i].cpu().numpy(), + rtol=rtol, + atol=atol, + ): logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}") return False @@ -137,10 +151,19 @@ def validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten return True -def get_onnx_file_path(onnx_dir: str, model_name: str, input_count: int, optimized_by_script: bool, use_gpu: bool, - precision: Precision, optimized_by_onnxruntime: bool, use_external_data: bool): +def get_onnx_file_path( + onnx_dir: str, + model_name: str, + input_count: int, + optimized_by_script: bool, + use_gpu: bool, + precision: Precision, + optimized_by_onnxruntime: bool, + use_external_data: bool, +): from re import sub - normalized_model_name = sub(r'[^a-zA-Z0-9_]', '_', model_name) + + normalized_model_name = sub(r"[^a-zA-Z0-9_]", "_", model_name) if not optimized_by_script: filename = f"{normalized_model_name}_{input_count}" @@ -176,27 +199,43 @@ def add_filename_suffix(file_path: str, suffix: str) -> str: def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics): if overwrite or not os.path.exists(ort_model_path): Path(ort_model_path).parent.mkdir(parents=True, exist_ok=True) - from optimizer import optimize_by_onnxruntime, get_fusion_statistics + from optimizer import get_fusion_statistics, optimize_by_onnxruntime + # Use onnxruntime to optimize model, which will be saved to *_ort.onnx - opt_model = optimize_by_onnxruntime(onnx_model_path, - use_gpu=use_gpu, - optimized_model_path=ort_model_path, - opt_level=99) + opt_model = optimize_by_onnxruntime( + onnx_model_path, + use_gpu=use_gpu, + optimized_model_path=ort_model_path, + opt_level=99, + ) model_fusion_statistics[ort_model_path] = get_fusion_statistics(ort_model_path) else: logger.info(f"Skip optimization since model existed: {ort_model_path}") -def optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model_type, num_attention_heads, hidden_size, - use_gpu, precision, use_raw_attention_mask, overwrite, model_fusion_statistics, - use_external_data_format, optimization_options=None): +def optimize_onnx_model( + model_name, + onnx_model_path, + optimized_model_path, + model_type, + num_attention_heads, + hidden_size, + use_gpu, + precision, + use_raw_attention_mask, + overwrite, + model_fusion_statistics, + use_external_data_format, + optimization_options=None, +): if overwrite or not os.path.exists(optimized_model_path): Path(optimized_model_path).parent.mkdir(parents=True, exist_ok=True) - from optimizer import optimize_model from fusion_options import FusionOptions + from optimizer import optimize_model + if optimization_options == None: - optimization_options = FusionOptions(model_type) + optimization_options = FusionOptions(model_type) optimization_options.use_raw_attention_mask(use_raw_attention_mask) if Precision.FLOAT16 == precision: optimization_options.enable_gelu_approximation = True @@ -206,15 +245,17 @@ def optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model # Use script to optimize model. # Use opt_level <= 1 for models to be converted to fp16, because some fused op (like FusedGemm) has only fp32 and no fp16. # It is better to be conservative so we use opt_level=0 here, in case MemcpyFromHost is added to the graph by OnnxRuntime. - opt_model = optimize_model(onnx_model_path, - model_type, - num_heads=num_attention_heads, - hidden_size=hidden_size, - opt_level=0, - optimization_options=optimization_options, - use_gpu=use_gpu, - only_onnxruntime=False) - if model_type == 'bert_keras' or model_type == "bert_tf": + opt_model = optimize_model( + onnx_model_path, + model_type, + num_heads=num_attention_heads, + hidden_size=hidden_size, + opt_level=0, + optimization_options=optimization_options, + use_gpu=use_gpu, + only_onnxruntime=False, + ) + if model_type == "bert_keras" or model_type == "bert_tf": opt_model.use_dynamic_axes() model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics() @@ -228,21 +269,22 @@ def optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model def modelclass_dispatcher(model_name, custom_model_class): - if (custom_model_class != None): - if (custom_model_class in MODEL_CLASSES): + if custom_model_class != None: + if custom_model_class in MODEL_CLASSES: return custom_model_class else: - raise Exception("Valid model class: " + ' '.join(MODEL_CLASSES)) + raise Exception("Valid model class: " + " ".join(MODEL_CLASSES)) if model_name in PRETRAINED_GPT2_MODELS: return "GPT2ModelNoPastState" import re - if (re.search('-squad$', model_name) != None): + + if re.search("-squad$", model_name) != None: return "AutoModelForQuestionAnswering" - elif (re.search('-mprc$', model_name) != None): + elif re.search("-mprc$", model_name) != None: return "AutoModelForSequenceClassification" - elif (re.search('gpt2', model_name) != None): + elif re.search("gpt2", model_name) != None: return "AutoModelWithLMHead" return "AutoModel" @@ -258,7 +300,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_ return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir) if is_tf_model: - model_class_name = 'TF' + model_class_name + model_class_name = "TF" + model_class_name transformers_module = __import__("transformers", fromlist=[model_class_name]) logger.info(f"Model class name: {model_class_name}") @@ -269,7 +311,7 @@ def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_ def load_pt_model(model_name, model_class, cache_dir, config_modifier): config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir) - if hasattr(config, 'return_dict'): + if hasattr(config, "return_dict"): config.return_dict = False config_modifier.modify(config) @@ -287,11 +329,13 @@ def load_tf_model(model_name, model_class, cache_dir, config_modifier): # Restore the affinity after model loading for expected ORT performance affi_helper = AffinitySetting() affi_helper.get_affinity() - model = load_pretrained_model(model_name, - config=config, - cache_dir=cache_dir, - custom_model_class=model_class, - is_tf_model=True) + model = load_pretrained_model( + model_name, + config=config, + cache_dir=cache_dir, + custom_model_class=model_class, + is_tf_model=True, + ) affi_helper.set_affinity() return config, model @@ -302,47 +346,84 @@ def load_pt_model_from_tf(model_name): # Note that we could get pt model from tf, but model source and its structure in this case is different from directly using # load_pt_model() and load_tf_model() even with the same name. Therefore it should not be used for comparing with them from convert_tf_models_to_pytorch import tf2pt_pipeline + config, model = tf2pt_pipeline(model_name) return config, model -def validate_and_optimize_onnx(model_name, - use_external_data_format, - model_type, - onnx_dir, - input_names, - use_gpu, - precision, - optimize_info, - validate_onnx, - use_raw_attention_mask, - overwrite, - config, - model_fusion_statistics, - onnx_model_path, - example_inputs, - example_outputs_flatten, - output_names, - fusion_options): +def validate_and_optimize_onnx( + model_name, + use_external_data_format, + model_type, + onnx_dir, + input_names, + use_gpu, + precision, + optimize_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + config, + model_fusion_statistics, + onnx_model_path, + example_inputs, + example_outputs_flatten, + output_names, + fusion_options, +): is_valid_onnx_model = True if validate_onnx: - is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, - False, output_names) + is_valid_onnx_model = validate_onnx_model( + onnx_model_path, + example_inputs, + example_outputs_flatten, + use_gpu, + False, + output_names, + ) if optimize_info == OptimizerInfo.NOOPT: return onnx_model_path, is_valid_onnx_model, config.vocab_size - if optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8: # Use script (optimizer.py) to optimize - optimized_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), True, use_gpu, precision, - False, use_external_data_format) - optimize_onnx_model(model_name, onnx_model_path, optimized_model_path, model_type, config.num_attention_heads, - config.hidden_size, use_gpu, precision, use_raw_attention_mask, overwrite, - model_fusion_statistics, use_external_data_format, fusion_options) + if ( + optimize_info == OptimizerInfo.BYSCRIPT or precision == Precision.FLOAT16 or precision == Precision.INT8 + ): # Use script (optimizer.py) to optimize + optimized_model_path = get_onnx_file_path( + onnx_dir, + model_name, + len(input_names), + True, + use_gpu, + precision, + False, + use_external_data_format, + ) + optimize_onnx_model( + model_name, + onnx_model_path, + optimized_model_path, + model_type, + config.num_attention_heads, + config.hidden_size, + use_gpu, + precision, + use_raw_attention_mask, + overwrite, + model_fusion_statistics, + use_external_data_format, + fusion_options, + ) onnx_model_path = optimized_model_path if validate_onnx: - is_valid_onnx_model = validate_onnx_model(onnx_model_path, example_inputs, example_outputs_flatten, use_gpu, - precision == Precision.FLOAT16, output_names) + is_valid_onnx_model = validate_onnx_model( + onnx_model_path, + example_inputs, + example_outputs_flatten, + use_gpu, + precision == Precision.FLOAT16, + output_names, + ) if precision == Precision.INT8: logger.info(f"Quantizing model: {onnx_model_path}") @@ -351,23 +432,46 @@ def validate_and_optimize_onnx(model_name, if optimize_info == OptimizerInfo.BYORT: # Use OnnxRuntime to optimize if is_valid_onnx_model: - ort_model_path = add_filename_suffix(onnx_model_path, '_ort') - optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics) + ort_model_path = add_filename_suffix(onnx_model_path, "_ort") + optimize_onnx_model_by_ort( + onnx_model_path, + ort_model_path, + use_gpu, + overwrite, + model_fusion_statistics, + ) return onnx_model_path, is_valid_onnx_model, config.vocab_size -def export_onnx_model_from_pt(model_name, opset_version, use_external_data_format, model_type, model_class, - config_modifier, cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info, - validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics, fusion_options): +def export_onnx_model_from_pt( + model_name, + opset_version, + use_external_data_format, + model_type, + model_class, + config_modifier, + cache_dir, + onnx_dir, + input_names, + use_gpu, + precision, + optimizer_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + model_fusion_statistics, + fusion_options, +): config, model = load_pt_model(model_name, model_class, cache_dir, config_modifier) # config, model = load_pt_model_from_tf(model_name) model.cpu() tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) - max_input_size = tokenizer.max_model_input_sizes[ - model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + max_input_size = ( + tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + ) example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt") @@ -381,8 +485,16 @@ def export_onnx_model_from_pt(model_name, opset_version, use_external_data_forma example_outputs_flatten = flatten(example_outputs) example_outputs_flatten = update_flatten_list(example_outputs_flatten, []) - onnx_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, False, - use_external_data_format) + onnx_model_path = get_onnx_file_path( + onnx_dir, + model_name, + len(input_names), + False, + use_gpu, + precision, + False, + use_external_data_format, + ) if overwrite or not os.path.exists(onnx_model_path): logger.info("Exporting ONNX model to {}".format(onnx_model_path)) @@ -391,57 +503,97 @@ def export_onnx_model_from_pt(model_name, opset_version, use_external_data_forma dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten) replace_torch_functions() - torch_onnx_export(model=model, - args=tuple(example_inputs.values()), - f=onnx_model_path, - input_names=list(example_inputs.keys()), - output_names=output_names, - dynamic_axes=dynamic_axes, - do_constant_folding=True, - opset_version=opset_version, - use_external_data_format=use_external_data_format) + torch_onnx_export( + model=model, + args=tuple(example_inputs.values()), + f=onnx_model_path, + input_names=list(example_inputs.keys()), + output_names=output_names, + dynamic_axes=dynamic_axes, + do_constant_folding=True, + opset_version=opset_version, + use_external_data_format=use_external_data_format, + ) restore_torch_functions() else: logger.info(f"Skip export since model existed: {onnx_model_path}") onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx( - model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimizer_info, - validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, - example_inputs, example_outputs_flatten, None, fusion_options) + model_name, + use_external_data_format, + model_type, + onnx_dir, + input_names, + use_gpu, + precision, + optimizer_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + config, + model_fusion_statistics, + onnx_model_path, + example_inputs, + example_outputs_flatten, + None, + fusion_options, + ) return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size -def export_onnx_model_from_tf(model_name, opset_version, use_external_data_format, model_type, model_class, - config_modifier, cache_dir, onnx_dir, input_names, use_gpu, precision, optimizer_info, - validate_onnx, use_raw_attention_mask, overwrite, model_fusion_statistics, fusion_options): +def export_onnx_model_from_tf( + model_name, + opset_version, + use_external_data_format, + model_type, + model_class, + config_modifier, + cache_dir, + onnx_dir, + input_names, + use_gpu, + precision, + optimizer_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + model_fusion_statistics, + fusion_options, +): # Use CPU to export import tensorflow as tf - tf.config.set_visible_devices([], 'GPU') + + tf.config.set_visible_devices([], "GPU") tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) # Fix "Using pad_token, but it is not set yet" error. if tokenizer.pad_token is None: - tokenizer.add_special_tokens({'pad_token': '[PAD]'}) - max_input_size = tokenizer.max_model_input_sizes[ - model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + max_input_size = ( + tokenizer.max_model_input_sizes[model_name] if model_name in tokenizer.max_model_input_sizes else 1024 + ) config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier) model.resize_token_embeddings(len(tokenizer)) - example_inputs = tokenizer.encode_plus("This is a sample input", - return_tensors="tf", - max_length=max_input_size, - padding="max_length", - truncation=True) + example_inputs = tokenizer.encode_plus( + "This is a sample input", + return_tensors="tf", + max_length=max_input_size, + padding="max_length", + truncation=True, + ) example_inputs = filter_inputs(example_inputs, input_names) if config.is_encoder_decoder: - example_inputs["decoder_input_ids"] = tokenizer.encode_plus("This is a sample input", - return_tensors="tf", - max_length=max_input_size, - padding="max_length", - truncation=True).input_ids + example_inputs["decoder_input_ids"] = tokenizer.encode_plus( + "This is a sample input", + return_tensors="tf", + max_length=max_input_size, + padding="max_length", + truncation=True, + ).input_ids if model_name == "unc-nlp/lxmert-base-uncased": example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim]) example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim]) @@ -463,10 +615,19 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma # Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs. from tensorflow.python.util import nest + example_outputs_flatten = nest.flatten(example_outputs) - onnx_model_path = get_onnx_file_path(onnx_dir, model_name, len(input_names), False, use_gpu, precision, False, - use_external_data_format) + onnx_model_path = get_onnx_file_path( + onnx_dir, + model_name, + len(input_names), + False, + use_gpu, + precision, + False, + use_external_data_format, + ) tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path if overwrite or not os.path.exists(tf_internal_model_path): @@ -474,20 +635,25 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma if not use_external_data_format: Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True) - import tf2onnx, zipfile + import zipfile + + import tf2onnx + tf2onnx.logging.set_level(tf2onnx.logging.ERROR) specs = [] for name, value in example_inputs.items(): dims = [None] * len(value.shape) specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name)) - _, _ = tf2onnx.convert.from_keras(model, - input_signature=tuple(specs), - opset=opset_version, - large_model=use_external_data_format, - output_path=tf_internal_model_path) + _, _ = tf2onnx.convert.from_keras( + model, + input_signature=tuple(specs), + opset=opset_version, + large_model=use_external_data_format, + output_path=tf_internal_model_path, + ) if use_external_data_format: # need to unpack the zip for run_onnxruntime() - with zipfile.ZipFile(tf_internal_model_path, 'r') as z: + with zipfile.ZipFile(tf_internal_model_path, "r") as z: z.extractall(os.path.dirname(tf_internal_model_path)) tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx") if os.path.exists(onnx_model_path): @@ -497,10 +663,32 @@ def export_onnx_model_from_tf(model_name, opset_version, use_external_data_forma else: logger.info(f"Skip export since model existed: {onnx_model_path}") - model_type = model_type + '_tf' - opt_onnx_model_file, onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx( - model_name, use_external_data_format, model_type, onnx_dir, input_names, use_gpu, precision, optimizer_info, - validate_onnx, use_raw_attention_mask, overwrite, config, model_fusion_statistics, onnx_model_path, - example_inputs, example_outputs_flatten, output_names, fusion_options) - - return opt_onnx_model_file, onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size + model_type = model_type + "_tf" + (opt_onnx_model_file, onnx_model_file, is_valid_onnx_model, vocab_size,) = validate_and_optimize_onnx( + model_name, + use_external_data_format, + model_type, + onnx_dir, + input_names, + use_gpu, + precision, + optimizer_info, + validate_onnx, + use_raw_attention_mask, + overwrite, + config, + model_fusion_statistics, + onnx_model_path, + example_inputs, + example_outputs_flatten, + output_names, + fusion_options, + ) + + return ( + opt_onnx_model_file, + onnx_model_file, + is_valid_onnx_model, + vocab_size, + max_input_size, + ) diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 70d5c58b93c26..882350169854d 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -1,16 +1,27 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- -from typing import List, Tuple, Dict import logging import os import sys +from collections import deque from pathlib import Path +from typing import Dict, List, Tuple + import numpy as np -from collections import deque -from onnx import onnx_pb, AttributeProto, ModelProto, TensorProto, NodeProto, numpy_helper, helper, external_data_helper, save_model +from onnx import ( + AttributeProto, + ModelProto, + NodeProto, + TensorProto, + external_data_helper, + helper, + numpy_helper, + onnx_pb, + save_model, +) from shape_infer_helper import SymbolicShapeInferenceHelper logger = logging.getLogger(__name__) @@ -76,11 +87,11 @@ def graphs(self): for node in graph.node: for attr in node.attribute: if attr.type == AttributeProto.AttributeType.GRAPH: - assert (isinstance(attr.g, onnx_pb.GraphProto)) + assert isinstance(attr.g, onnx_pb.GraphProto) graph_queue.append(attr.g) if attr.type == AttributeProto.AttributeType.GRAPHS: for g in attr.graphs: - assert (isinstance(g, onnx_pb.GraphProto)) + assert isinstance(g, onnx_pb.GraphProto) graph_queue.append(g) return self.all_graphs @@ -195,7 +206,7 @@ def get_nodes_by_op_type(self, op_type): return nodes def get_children(self, node, input_name_to_nodes=None): - if (input_name_to_nodes is None): + if input_name_to_nodes is None: input_name_to_nodes = self.input_name_to_nodes() children = [] @@ -229,7 +240,7 @@ def get_parent(self, node, i, output_name_to_node=None): return output_name_to_node[input] def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]): - ''' + """ Find parent node based on constraints on op_type. Args: @@ -241,7 +252,7 @@ def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude= Returns: parent: The matched parent node. None if not found. index: The input index of matched parent node. None if not found. - ''' + """ for i, input in enumerate(node.input): if input in output_name_to_node: parent = output_name_to_node[input] @@ -251,14 +262,16 @@ def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude= logger.debug(f"To find first {parent_op_type}, current {parent.op_type}") return None, None - def match_parent(self, - node, - parent_op_type, - input_index=None, - output_name_to_node=None, - exclude=[], - return_indice=None): - ''' + def match_parent( + self, + node, + parent_op_type, + input_index=None, + output_name_to_node=None, + exclude=[], + return_indice=None, + ): + """ Find parent node based on constraints on op_type and index. When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index. @@ -272,7 +285,7 @@ def match_parent(self, Returns: parent: The matched parent node. - ''' + """ assert node is not None assert input_index is None or input_index >= 0 @@ -307,13 +320,15 @@ def match_parent_paths(self, node, paths, output_name_to_node): return i, matched, return_indice return -1, None, None - def match_parent_path(self, - node, - parent_op_types, - parent_input_index, - output_name_to_node=None, - return_indice=None): - ''' + def match_parent_path( + self, + node, + parent_op_types, + parent_input_index, + output_name_to_node=None, + return_indice=None, + ): + """ Find a sequence of input edges based on constraints on parent op_type and index. When input_index is None, we will find the first parent node based on constraints, and return_indice will be appended the corresponding input index. @@ -326,8 +341,8 @@ def match_parent_path(self, Returns: parents: a list of matched parent node. - ''' - assert (len(parent_input_index) == len(parent_op_types)) + """ + assert len(parent_input_index) == len(parent_op_types) if output_name_to_node is None: output_name_to_node = self.output_name_to_node() @@ -335,15 +350,19 @@ def match_parent_path(self, current_node = node matched_parents = [] for i, op_type in enumerate(parent_op_types): - matched_parent = self.match_parent(current_node, - op_type, - parent_input_index[i], - output_name_to_node, - exclude=[], - return_indice=return_indice) + matched_parent = self.match_parent( + current_node, + op_type, + parent_input_index[i], + output_name_to_node, + exclude=[], + return_indice=return_indice, + ) if matched_parent is None: - logger.debug(f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", - stack_info=True) + logger.debug( + f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}", + stack_info=True, + ) return None matched_parents.append(matched_parent) @@ -385,10 +404,10 @@ def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, return None def get_constant_value(self, output_name): - for node in self.get_nodes_by_op_type('Constant'): + for node in self.get_nodes_by_op_type("Constant"): if node.output[0] == output_name: for att in node.attribute: - if att.name == 'value': + if att.name == "value": return numpy_helper.to_array(att.t) # Fall back to intializer since constant folding might have been @@ -455,13 +474,12 @@ def get_children_subgraph_nodes(self, root_node, stop_nodes, input_name_to_nodes return unique_nodes def tensor_shape_to_list(self, tensor_type): - """ Convert tensor shape to list - """ + """Convert tensor shape to list""" shape_list = [] for d in tensor_type.shape.dim: - if (d.HasField("dim_value")): + if d.HasField("dim_value"): shape_list.append(d.dim_value) # known dimension - elif (d.HasField("dim_param")): + elif d.HasField("dim_param"): shape_list.append(d.dim_param) # unknown dimension with symbolic name else: shape_list.append("?") # shall not happen @@ -494,7 +512,8 @@ def get_node_attribute(node: NodeProto, attribute_name: str): def convert_model_float32_to_float16(self, cast_input_output=True): logger.warning( - 'The function convert_model_float32_to_float16 is deprecated. Use convert_float_to_float16 instead!') + "The function convert_model_float32_to_float16 is deprecated. Use convert_float_to_float16 instead!" + ) self.convert_float_to_float16(use_symbolic_shape_infer=True, keep_io_types=cast_input_output) def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): @@ -505,7 +524,7 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): Args: use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference. Defaults to True. - keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names. + keep_io_types (Union[bool, List[str]], optional): It could be boolean or a list of float32 input/output names. If True, model inputs/outputs should be left as float32. Defaults to False. op_block_list (List[str], optional): List of operator types to leave as float32. Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST` as default. @@ -520,16 +539,17 @@ def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs): def float_to_float16_func(): # TODO: import from onnxconverter_common when it is stable - #try: + # try: # import onnxconverter_common as oc # from packaging.version import Version # if Version(oc.__version__) > Version("1.9.0"): # from onnxconverter_common.float16 import convert_float_to_float16 # return convert_float_to_float16 - #except ImportError: + # except ImportError: # pass from float16 import convert_float_to_float16 + return convert_float_to_float16 convert_float_to_float16 = float_to_float16_func() @@ -540,14 +560,21 @@ def float_to_float16_func(): shape_infer_helper = SymbolicShapeInferenceHelper(model) model = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False) - parameters = {'disable_shape_infer': use_symbolic_shape_infer} - parameters.update({ - key: kwargs[key] - for key in [ - 'keep_io_types', 'min_positive_val', 'max_finite_val', 'op_block_list', 'node_block_list', - 'force_fp16_initializers' - ] if key in kwargs - }) + parameters = {"disable_shape_infer": use_symbolic_shape_infer} + parameters.update( + { + key: kwargs[key] + for key in [ + "keep_io_types", + "min_positive_val", + "max_finite_val", + "op_block_list", + "node_block_list", + "force_fp16_initializers", + ] + if key in kwargs + } + ) fp16_model = convert_float_to_float16(model, **parameters) self.initialize(fp16_model) @@ -566,8 +593,11 @@ def float_to_float16_func(): # Remove the second cast node. for node in self.nodes(): - if node.op_type == "Cast" and OnnxModel.get_node_attribute(node, "to") == int(TensorProto.FLOAT) and \ - self.get_dtype(node.input[0]) == int(TensorProto.FLOAT): + if ( + node.op_type == "Cast" + and OnnxModel.get_node_attribute(node, "to") == int(TensorProto.FLOAT) + and self.get_dtype(node.input[0]) == int(TensorProto.FLOAT) + ): if self.find_graph_output(node.output[0]): self.replace_output_of_all_nodes(node.input[0], node.output[0]) @@ -605,7 +635,7 @@ def create_node_name(self, op_type, name_prefix=None): for node in self.nodes(): if node.name and node.name.startswith(prefix): try: - index = int(node.name[len(prefix):]) + index = int(node.name[len(prefix) :]) suffix = max(index + 1, suffix) except ValueError: continue @@ -678,7 +708,7 @@ def input_index(node_output, child_node): def remove_unused_constant(self): input_name_to_nodes = self.input_name_to_nodes() - #remove unused constant + # remove unused constant unused_nodes = [] nodes = self.nodes() for node in nodes: @@ -741,8 +771,11 @@ def prune_graph(self, outputs=None): self.model.graph.input.remove(input) if input_to_remove or output_to_remove or nodes_to_remove: - logger.info("Graph pruned: {} inputs, {} outputs and {} nodes are removed".format( - len(input_to_remove), len(output_to_remove), len(nodes_to_remove))) + logger.info( + "Graph pruned: {} inputs, {} outputs and {} nodes are removed".format( + len(input_to_remove), len(output_to_remove), len(nodes_to_remove) + ) + ) self.update_graph() @@ -751,7 +784,7 @@ def update_graph(self, verbose=False): remaining_input_names = [] for node in graph.node: - if node.op_type in ['Loop', 'Scan', 'If']: + if node.op_type in ["Loop", "Scan", "If"]: # TODO: handle inner graph logger.debug(f"Skip update_graph since graph has operator: {node.op_type}") return @@ -854,13 +887,13 @@ def graph_topological_sort(graph): end = end + 1 start = start + 1 - assert (end == len(graph.node)), "Graph is not a DAG" - graph.ClearField('node') + assert end == len(graph.node), "Graph is not a DAG" + graph.ClearField("node") graph.node.extend(sorted_nodes) def topological_sort(self): - #TODO: support graph_topological_sort() in subgraphs - #for graph in self.graphs(): + # TODO: support graph_topological_sort() in subgraphs + # for graph in self.graphs(): # self.graph_topological_sort(graph) OnnxModel.graph_topological_sort(self.model.graph) @@ -885,16 +918,19 @@ def save_model_to_file(self, output_path, use_external_data_format=False, all_te if all_tensors_to_one_file: if os.path.exists(location): logger.warning( - f"External data file ({location}) existed. Please remove the file and try again.") + f"External data file ({location}) existed. Please remove the file and try again." + ) else: if os.listdir(output_dir): logger.warning( f"Output directory ({output_dir}) for external data is not empty. Please try again with a new directory." ) - external_data_helper.convert_model_to_external_data(self.model, - all_tensors_to_one_file=all_tensors_to_one_file, - location=location) + external_data_helper.convert_model_to_external_data( + self.model, + all_tensors_to_one_file=all_tensors_to_one_file, + location=location, + ) save_model(self.model, output_path) logger.info(f"Model saved to {output_path}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_bart.py b/onnxruntime/python/tools/transformers/onnx_model_bart.py index 7ba3104c190d8..33db231c52332 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bart.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bart.py @@ -1,9 +1,10 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import logging -from fusion_attention import FusionAttention, AttentionMask + +from fusion_attention import AttentionMask, FusionAttention from fusion_reshape import FusionReshape from onnx import numpy_helper from onnx_model import OnnxModel @@ -16,18 +17,33 @@ class FusionBartEncoderAttention(FusionAttention): """ Fuse Bart Attention subgraph into one Attention node. """ - def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask): + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): super().__init__(model, hidden_size, num_heads, attention_mask) - def check_runtime_shape_path(self, reshape_qkv_2, reshape_qkv_1, reshape_q_2, reshape_k_2, reshape_v_2, root_input): - concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ['Concat'], [1]) + def check_runtime_shape_path( + self, + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) if concat_qkv_2_path is None: return False concat_qkv_2 = concat_qkv_2_path[0] - reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0]) - reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0]) - reshape_qkv_2_path_3 = self.model.match_parent_path(concat_qkv_2, ['Unsqueeze', 'Gather', 'Shape'], [2, 0, 0]) + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + reshape_qkv_2_path_3 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) if reshape_qkv_2_path_1 is None or reshape_qkv_2_path_2 is None or reshape_qkv_2_path_3 is None: return False @@ -38,16 +54,16 @@ def check_runtime_shape_path(self, reshape_qkv_2, reshape_qkv_1, reshape_q_2, re if shape_1.input[0] != root_input or shape_2.input[0] != root_input or shape_3.input[0] != root_input: return False - reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ['Concat', 'Unsqueeze', 'Gather'], [1, 0, 0]) - reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ['Concat', 'Unsqueeze', 'Gather'], [1, 2, 0]) + reshape_qkv_1_path_1 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(reshape_qkv_1, ["Concat", "Unsqueeze", "Gather"], [1, 2, 0]) if reshape_qkv_1_path_1 is None or reshape_qkv_1_path_2 is None: return False if reshape_qkv_1_path_1[-1].name != gather_1.name or reshape_qkv_1_path_2[-1].name != gather_2.name: return False - reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ['Concat', 'Unsqueeze', 'Mul'], [1, 0, 0]) - reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ['Concat', 'Unsqueeze', 'Mul'], [1, 0, 0]) - reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ['Concat', 'Unsqueeze', 'Mul'], [1, 0, 0]) + reshape_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) + reshape_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat", "Unsqueeze", "Mul"], [1, 0, 0]) if reshape_q_2_path is None or reshape_k_2_path is None or reshape_v_2_path is None: return False @@ -63,11 +79,20 @@ def check_runtime_shape_path(self, reshape_qkv_2, reshape_qkv_1, reshape_q_2, re def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # SkipLayerNormalization has two inputs, and one of them is the root input for attention. - qkv_nodes = self.model.match_parent_path(normalize_node, - ['Add', 'MatMul', 'Reshape', 'Transpose', 'Reshape', 'MatMul'], - [None, 1, 0, 0, 0, 0]) + qkv_nodes = self.model.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [None, 1, 0, 0, 0, 0], + ) if qkv_nodes is not None: - (add_out, matmul_out, reshape_qkv_2, transpose_qkv, reshape_qkv_1, matmul_qkv) = qkv_nodes + ( + add_out, + matmul_out, + reshape_qkv_2, + transpose_qkv, + reshape_qkv_1, + matmul_qkv, + ) = qkv_nodes else: return @@ -84,39 +109,53 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input = other_inputs[0] children = input_name_to_nodes[root_input] children_types = [child.op_type for child in children] - if children_types.count('MatMul') != 3: + if children_types.count("MatMul") != 3: return - v_nodes = self.model.match_parent_path(matmul_qkv, ['Reshape', 'Transpose', 'Reshape', 'Add', 'MatMul'], - [1, 0, 0, 0, None]) + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, None], + ) if v_nodes is None: logger.debug("fuse_attention: failed to match v path") return (reshape_v_2, transpose_v, reshape_v_1, add_v, matmul_v) = v_nodes - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'MatMul'], [0, 0]) + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0]) if qk_nodes is not None: _, matmul_qk = qk_nodes else: return - q_nodes = self.model.match_parent_path(matmul_qk, ['Reshape', 'Transpose', 'Reshape', 'Mul', 'Add', 'MatMul'], - [0, 0, 0, 0, 0, 1]) + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "Reshape", "Mul", "Add", "MatMul"], + [0, 0, 0, 0, 0, 1], + ) if q_nodes is not None: reshape_q_2, _, reshape_q_1, _, add_q, matmul_q = q_nodes else: return - k_nodes = self.model.match_parent_path(matmul_qk, - ['Transpose', 'Reshape', 'Transpose', 'Reshape', 'Add', 'MatMul'], - [1, 0, 0, 0, 0, 1]) + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Transpose", "Reshape", "Add", "MatMul"], + [1, 0, 0, 0, 0, 1], + ) if k_nodes is not None: _, reshape_k_2, _, reshape_k_1, add_k, matmul_k = k_nodes else: return - if not self.check_runtime_shape_path(reshape_qkv_2, reshape_qkv_1, reshape_q_2, reshape_k_2, reshape_v_2, - root_input): + if not self.check_runtime_shape_path( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + root_input, + ): return if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_v.input[0] == root_input: @@ -131,9 +170,20 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): logger.debug("fuse_attention: failed to detect num_heads or hidden_size") return - new_node = self.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, add_q, add_k, add_v, - num_heads, hidden_size, root_input, attention_last_node.output[0], - None) + new_node = self.create_attention_node( + mask_index, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + num_heads, + hidden_size, + root_input, + attention_last_node.output[0], + None, + ) if new_node is None: return @@ -160,11 +210,15 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): return concat_node = output_name_to_node[reshape_node.input[1]] - if concat_node.op_type != 'Concat' or len(concat_node.input) != 4: + if concat_node.op_type != "Concat" or len(concat_node.input) != 4: return - path0 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [0, 0, 0], - output_name_to_node) + path0 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Gather", "Shape"], + [0, 0, 0], + output_name_to_node, + ) if path0 is None: return @@ -175,8 +229,12 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): if gather_value == 0: shape.append(0) - path1 = self.model.match_parent_path(concat_node, ['Unsqueeze', 'Gather', 'Shape'], [1, 0, 0], - output_name_to_node) + path1 = self.model.match_parent_path( + concat_node, + ["Unsqueeze", "Gather", "Shape"], + [1, 0, 0], + output_name_to_node, + ) if path1 is None: input_1_proto = self.model.get_initializer(concat_node.input[1]) input_2_proto = self.model.get_initializer(concat_node.input[2]) @@ -196,7 +254,7 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): shape.extend(input_1) shape.extend(input_2) shape.extend(input_3) - gemm_path = self.model.match_parent_path(reshape_node, ['Add', 'MatMul'], [0, 1], output_name_to_node) + gemm_path = self.model.match_parent_path(reshape_node, ["Add", "MatMul"], [0, 1], output_name_to_node) if gemm_path is None: return @@ -228,8 +286,9 @@ def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node): shape.extend(input_2) shape.extend(input_3) - gemm_path = self.model.match_parent_path(reshape_node, ['Mul', 'Add', 'MatMul'], [0, 0, 1], - output_name_to_node) + gemm_path = self.model.match_parent_path( + reshape_node, ["Mul", "Add", "MatMul"], [0, 0, 1], output_name_to_node + ) if gemm_path is None: return diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 2efbecfa7b55d..2b1192445b689 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -1,31 +1,32 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- from logging import getLogger from typing import List -from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper -from onnx_model import OnnxModel -from fusion_reshape import FusionReshape -from fusion_shape import FusionShape -from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF -from fusion_skiplayernorm import FusionSkipLayerNormalization, FusionBiasSkipLayerNormalization + +from fusion_attention import AttentionMask, FusionAttention +from fusion_biasgelu import FusionBiasGelu from fusion_embedlayer import FusionEmbedLayerNormalization -from fusion_attention import FusionAttention, AttentionMask -from fusion_gelu import FusionGelu from fusion_fastgelu import FusionFastGelu -from fusion_biasgelu import FusionBiasGelu +from fusion_gelu import FusionGelu from fusion_gelu_approximation import FusionGeluApproximation -from fusion_utils import FusionUtils +from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF from fusion_options import FusionOptions +from fusion_reshape import FusionReshape +from fusion_shape import FusionShape +from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization +from fusion_utils import FusionUtils +from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper +from onnx_model import OnnxModel logger = getLogger(__name__) class BertOptimizationOptions(FusionOptions): - """ This class is deprecated - """ + """This class is deprecated""" + def __init__(self, model_type): logger.warning(f"BertOptimizationOptions is depreciated. Please use FusionOptions instead.") super().__init__(model_type) @@ -111,20 +112,22 @@ def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int] graph_inputs.append(bert_input) elif bert_input in output_name_to_node: parent = output_name_to_node[bert_input] - if parent.op_type == 'Cast' and self.find_graph_input(parent.input[0]) is not None: + if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None: if casted: graph_inputs.append(parent.input[0]) return graph_inputs def get_graph_inputs_from_fused_nodes(self, casted: bool): - inputs = self.get_graph_inputs_from_node_type('EmbedLayerNormalization', [0, 1, 7], casted) - inputs += self.get_graph_inputs_from_node_type('Attention', [3], casted) + inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted) + inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted) return inputs - def change_graph_input_type(self, - graph: GraphProto, - graph_input: ValueInfoProto, - new_type: int = TensorProto.INT32): + def change_graph_input_type( + self, + graph: GraphProto, + graph_input: ValueInfoProto, + new_type: int = TensorProto.INT32, + ): """Change graph input type, and add Cast node if needed. Args: @@ -151,16 +154,20 @@ def change_graph_input_type(self, nodes = input_name_to_nodes[graph_input.name] # For children that is not Cast node, insert a Cast node to convert int32 to original data type. - nodes_not_cast = [node for node in nodes if node.op_type != 'Cast'] + nodes_not_cast = [node for node in nodes if node.op_type != "Cast"] if nodes_not_cast: - node_name = self.create_node_name('Cast') - output_name = node_name + '_' + graph_input.name + node_name = self.create_node_name("Cast") + output_name = node_name + "_" + graph_input.name new_value_info = graph.value_info.add() new_value_info.CopyFrom(graph_input) new_value_info.name = output_name - new_cast_node = helper.make_node('Cast', [graph_input.name], [output_name], - to=int(graph_input.type.tensor_type.elem_type), - name=node_name) + new_cast_node = helper.make_node( + "Cast", + [graph_input.name], + [output_name], + to=int(graph_input.type.tensor_type.elem_type), + name=node_name, + ) graph.node.extend([new_cast_node]) for node in nodes_not_cast: @@ -168,7 +175,7 @@ def change_graph_input_type(self, # For children that is Cast node, no need to insert Cast. # When the children is Cast to int32, we can remove that Cast node since input type is int32 now. - nodes_cast = [node for node in nodes if node.op_type == 'Cast'] + nodes_cast = [node for node in nodes if node.op_type == "Cast"] for node in nodes_cast: if OnnxModel.get_node_attribute(node, "to") == int(new_type): self.replace_input_of_all_nodes(node.output[0], graph_input.name) @@ -181,8 +188,7 @@ def change_graph_input_type(self, return new_cast_node, nodes_to_remove def change_graph_inputs_to_int32(self): - """Change data type of all graph inputs to int32 type, and add Cast node if needed. - """ + """Change data type of all graph inputs to int32 type, and add Cast node if needed.""" graph = self.graph() add_cast_count = 0 remove_cast_count = 0 @@ -195,12 +201,13 @@ def change_graph_inputs_to_int32(self): f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes." ) - def use_dynamic_axes(self, dynamic_batch_dim='batch_size', dynamic_seq_len='max_seq_len'): + def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"): """ Update input and output shape to use dynamic axes. """ bert_graph_inputs = self.get_graph_inputs_from_fused_nodes( - casted=True) + self.get_graph_inputs_from_fused_nodes(casted=False) + casted=True + ) + self.get_graph_inputs_from_fused_nodes(casted=False) dynamic_batch_inputs = {} for input in self.model.graph.input: @@ -222,7 +229,7 @@ def preprocess(self): def adjust_reshape_and_expand(self): nodes_to_remove = [] for node in self.nodes(): - if node.op_type == 'Reshape': + if node.op_type == "Reshape": # Clean up unneccessary reshape nodes. # Find reshape nodes with no actually data in "shape" attribute and remove. reshape_shape = self.get_constant_value(node.input[1]) @@ -233,8 +240,12 @@ def adjust_reshape_and_expand(self): # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by # changing current reshape's input to output of slice. - reshape_path = self.match_parent_path(node, ['Expand', 'Expand', 'Reshape', 'Slice'], [0, 0, 0, 0], - self.output_name_to_node()) + reshape_path = self.match_parent_path( + node, + ["Expand", "Expand", "Reshape", "Slice"], + [0, 0, 0, 0], + self.output_name_to_node(), + ) if reshape_path is not None: expand_node = reshape_path[-3] expand_shape_value = self.get_constant_value(expand_node.input[1]) @@ -243,9 +254,13 @@ def adjust_reshape_and_expand(self): shape_value = self.get_constant_value(reshape_before_expand.input[1]) slice_node = reshape_path[-1] - if expand_shape_value is not None and shape_value is not None and len( - expand_shape_value) == 2 and len( - shape_value) == 1 and expand_shape_value[1] == shape_value[0]: + if ( + expand_shape_value is not None + and shape_value is not None + and len(expand_shape_value) == 2 + and len(shape_value) == 1 + and expand_shape_value[1] == shape_value[0] + ): node.input[0] = slice_node.output[0] if nodes_to_remove: @@ -268,27 +283,50 @@ def clean_graph(self): if node.op_type in op_input_id: i = op_input_id[node.op_type] parent_nodes = self.match_parent_path( - node, ['Cast', 'ConstantOfShape', 'Concat', 'Unsqueeze', 'Gather', 'Shape'], [i, 0, 0, 0, 0, 0], - output_name_to_node) + node, + [ + "Cast", + "ConstantOfShape", + "Concat", + "Unsqueeze", + "Gather", + "Shape", + ], + [i, 0, 0, 0, 0, 0], + output_name_to_node, + ) if parent_nodes is not None: - cast, constantOfShape, concat, unsqueeze, gather, shape = parent_nodes + ( + cast, + constantOfShape, + concat, + unsqueeze, + gather, + shape, + ) = parent_nodes if shape.input[0] == self.graph().input[0].name: constantOfShape.input[0] = shape.output[0] output_name_to_node = self.output_name_to_node() - if node.op_type == 'Attention': + if node.op_type == "Attention": # Before: # input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention # After: # remove this path, and remove the optional mask_index input of Attention node. - parent_nodes = self.match_parent_path(node, ['ReduceSum', 'Cast', 'ConstantOfShape', 'Shape'], - [3, 0, 0, 0], output_name_to_node) + parent_nodes = self.match_parent_path( + node, + ["ReduceSum", "Cast", "ConstantOfShape", "Shape"], + [3, 0, 0, 0], + output_name_to_node, + ) if parent_nodes is not None: if parent_nodes[-1].input[0] == self.graph().input[0].name: - attention_node = helper.make_node('Attention', - inputs=node.input[0:len(node.input) - 1], - outputs=node.output, - name=node.name + "_remove_mask") + attention_node = helper.make_node( + "Attention", + inputs=node.input[0 : len(node.input) - 1], + outputs=node.output, + name=node.name + "_remove_mask", + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)]) self.add_node(attention_node, self.get_graph_by_node(attention_node).name) @@ -341,7 +379,7 @@ def optimize(self, options: FusionOptions = None, add_dynamic_axes=False): # Fuse SkipLayerNormalization and Add Bias before it. self.fuse_add_bias_skip_layer_norm() - if (options is not None and options.enable_gelu_approximation): + if options is not None and options.enable_gelu_approximation: self.gelu_approximation() self.remove_unused_constant() @@ -358,8 +396,13 @@ def get_fused_operator_statistics(self): """ op_count = {} ops = [ - 'EmbedLayerNormalization', 'Attention', 'Gelu', 'FastGelu', 'BiasGelu', 'LayerNormalization', - 'SkipLayerNormalization' + "EmbedLayerNormalization", + "Attention", + "Gelu", + "FastGelu", + "BiasGelu", + "LayerNormalization", + "SkipLayerNormalization", ] for op in ops: nodes = self.get_nodes_by_op_type(op) @@ -372,10 +415,10 @@ def is_fully_optimized(self): Returns True when the model is fully optimized. """ op_count = self.get_fused_operator_statistics() - embed = op_count['EmbedLayerNormalization'] - attention = op_count['Attention'] - gelu = op_count['Gelu'] + op_count['BiasGelu'] + op_count['FastGelu'] - layer_norm = op_count['LayerNormalization'] + op_count['SkipLayerNormalization'] + embed = op_count["EmbedLayerNormalization"] + attention = op_count["Attention"] + gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] + layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention) if layer_norm == 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py index 53f2ca1f9d9ba..33bb1d66a7528 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_keras.py @@ -1,14 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- +import argparse import logging -import onnx import sys -import argparse -import numpy as np from collections import deque + +import numpy as np +import onnx from onnx import ModelProto, TensorProto, numpy_helper from onnx_model_bert_tf import BertOnnxModelTF @@ -20,18 +21,27 @@ def __init__(self, model, num_heads, hidden_size): super().__init__(model, num_heads, hidden_size) def match_mask_path(self, add_or_sub_before_softmax): - mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Reshape', 'Cast'], - [1, None, 1, 0]) + mask_nodes = self.match_parent_path( + add_or_sub_before_softmax, + ["Mul", "Sub", "Reshape", "Cast"], + [1, None, 1, 0], + ) if mask_nodes is not None: return mask_nodes - mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Slice', 'Unsqueeze'], - [1, 1, 1, 0, 0]) + mask_nodes = self.match_parent_path( + add_or_sub_before_softmax, + ["Mul", "Sub", "Cast", "Slice", "Unsqueeze"], + [1, 1, 1, 0, 0], + ) if mask_nodes is not None: return mask_nodes - mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], - [1, None, 1, 0, 0]) + mask_nodes = self.match_parent_path( + add_or_sub_before_softmax, + ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [1, None, 1, 0, 0], + ) return mask_nodes def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node): @@ -42,7 +52,7 @@ def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_nam root_node = output_name_to_node[root_input] if root_node == parent: continue - if root_node.op_type == 'Reshape' and root_node.input[0] == parent.output[0]: + if root_node.op_type == "Reshape" and root_node.input[0] == parent.output[0]: reshape_nodes.append(root_node) continue logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}") @@ -61,56 +71,89 @@ def fuse_attention(self): for normalize_node in skip_layer_norm_nodes: # SkipLayerNormalization has two inputs, and one of them is the root input for attention. parent = self.get_parent(normalize_node, 0) - if parent is None or parent.op_type not in ["SkipLayerNormalization", "EmbedLayerNormalization"]: - if parent.op_type == 'Add': + if parent is None or parent.op_type not in [ + "SkipLayerNormalization", + "EmbedLayerNormalization", + ]: + if parent.op_type == "Add": parent = self.get_parent(normalize_node, 1) - if parent is None or parent.op_type not in ["SkipLayerNormalization", "EmbedLayerNormalization"]: + if parent is None or parent.op_type not in [ + "SkipLayerNormalization", + "EmbedLayerNormalization", + ]: logger.debug( - "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None)) + "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None) + ) continue else: logger.debug( - "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None)) + "First input for skiplayernorm: {}".format(parent.op_type if parent is not None else None) + ) continue else: # TODO: shall we add back the checking of children op types. pass - qkv_nodes = self.match_parent_path(normalize_node, - ['Add', 'Reshape', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], - [None, 0, 0, 0, 0, 0]) + qkv_nodes = self.match_parent_path( + normalize_node, + ["Add", "Reshape", "MatMul", "Reshape", "Transpose", "MatMul"], + [None, 0, 0, 0, 0, 0], + ) if qkv_nodes is None: logger.debug("Failed to match qkv nodes") continue - (add, extra_reshape_0, matmul, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes + ( + add, + extra_reshape_0, + matmul, + reshape_qkv, + transpose_qkv, + matmul_qkv, + ) = qkv_nodes logger.debug("Matched qkv nodes") - v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'Reshape', 'MatMul'], - [1, 0, 0, 0, 0]) + v_nodes = self.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Add", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) if v_nodes is None: logger.debug("Failed to match v path") continue (transpose_v, reshape_v, add_v, extra_reshape_1, matmul_v) = v_nodes - qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Sub', 'MatMul'], [0, 0, 0]) + qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Sub", "MatMul"], [0, 0, 0]) if qk_nodes is not None: (softmax_qk, sub_qk, matmul_qk) = qk_nodes - q_nodes = self.match_parent_path(matmul_qk, ['Mul', 'Transpose', 'Reshape', 'Add', 'Reshape', 'MatMul'], - [0, None, 0, 0, 0, 0]) + q_nodes = self.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Add", "Reshape", "MatMul"], + [0, None, 0, 0, 0, 0], + ) if q_nodes is not None: - (mul_q, transpose_q, reshape_q, add_q, extra_reshape_2, matmul_q) = q_nodes + ( + mul_q, + transpose_q, + reshape_q, + add_q, + extra_reshape_2, + matmul_q, + ) = q_nodes else: - qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Mul', 'MatMul'], [0, 0, 0, None]) + qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, None]) if qk_nodes is None: - qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Div', 'MatMul'], [0, 0, 0, None]) + qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Div", "MatMul"], [0, 0, 0, None]) if qk_nodes is None: logger.debug("Failed to match qk path") continue (softmax_qk, add_qk, mul_qk, matmul_qk) = qk_nodes - q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'Reshape', 'MatMul'], - [0, 0, 0, 0, 0]) + q_nodes = self.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Add", "Reshape", "MatMul"], + [0, 0, 0, 0, 0], + ) if q_nodes is not None: (transpose_q, reshape_q, add_q, extra_reshape_2, matmul_q) = q_nodes @@ -118,8 +161,11 @@ def fuse_attention(self): logger.debug("Failed to match q path") continue - k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'Reshape', 'MatMul'], - [1, 0, 0, 0, 0]) + k_nodes = self.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Add", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) if k_nodes is None: logger.debug("Failed to match k path") continue @@ -133,15 +179,26 @@ def fuse_attention(self): logger.debug("Sub node expected to have an input with constant value 1.0.") continue - is_same_root, reshape_nodes = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, - output_name_to_node) + is_same_root, reshape_nodes = self.check_attention_input( + matmul_q, matmul_k, matmul_v, parent, output_name_to_node + ) if is_same_root: mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0]) logger.debug("Create an Attention node.") - attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_q, matmul_k, matmul_v, - add_q, add_k, add_v, self.num_heads, - self.hidden_size, parent.output[0], - reshape_qkv.output[0], None) + attention_node = self.attention_fusion.create_attention_node( + mask_index, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + self.num_heads, + self.hidden_size, + parent.output[0], + reshape_qkv.output[0], + None, + ) if attention_node is None: continue @@ -188,9 +245,9 @@ def skip_reshape(self): logger.info(f"Skip consequent Reshape count: {count}") def fuse_embedding(self, node, output_name_to_node): - assert node.op_type == 'LayerNormalization' + assert node.op_type == "LayerNormalization" logger.debug(f"start fusing embedding from node with output={node.output[0]}...") - word_embed_path = self.match_parent_path(node, ['Add', 'Add', 'Gather'], [0, 0, 0], output_name_to_node) + word_embed_path = self.match_parent_path(node, ["Add", "Add", "Gather"], [0, 0, 0], output_name_to_node) if word_embed_path is None: logger.debug("failed to match word_embed_path") return False @@ -219,11 +276,12 @@ def fuse_embedding(self, node, output_name_to_node): logger.info("Found position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape[1:])) position_embedding = "position_embedding" else: - logger.info("Failed to find position embedding. name:{}, shape:{}".format( - pos_initializer.name, temp.shape)) + logger.info( + "Failed to find position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape) + ) return False else: - pos_embed_path = self.match_parent_path(add_node, ['Gather', 'Slice'], [1, 1], output_name_to_node) + pos_embed_path = self.match_parent_path(add_node, ["Gather", "Slice"], [1, 1], output_name_to_node) if pos_embed_path is None: logger.debug("failed to match pos_embed_path") return False @@ -239,8 +297,9 @@ def fuse_embedding(self, node, output_name_to_node): logger.info("Found word embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape)) position_embedding = pos_initializer.name else: - logger.info("Failed to find position embedding. name:{}, shape:{}".format( - pos_initializer.name, temp.shape)) + logger.info( + "Failed to find position embedding. name:{}, shape:{}".format(pos_initializer.name, temp.shape) + ) return False gather = self.get_parent(skip_node, 1, output_name_to_node) @@ -258,8 +317,9 @@ def fuse_embedding(self, node, output_name_to_node): logger.info("Found segment embedding. name:{}, shape:{}".format(segment_initializer.name, temp.shape)) segment_embedding = segment_initializer.name else: - logger.info("Failed to find segment embedding. name:{}, shape:{}".format( - segment_initializer.name, temp.shape)) + logger.info( + "Failed to find segment embedding. name:{}, shape:{}".format(segment_initializer.name, temp.shape) + ) return False logger.info("Create Embedding node") @@ -273,7 +333,7 @@ def process_embedding(self): logger.info("start processing embedding layer...") output_name_to_node = self.output_name_to_node() for node in self.nodes(): - if node.op_type == 'LayerNormalization': + if node.op_type == "LayerNormalization": if self.fuse_embedding(node, output_name_to_node): return break @@ -281,8 +341,8 @@ def process_embedding(self): def fuse_mask(self): nodes_to_remove = [] for node in self.nodes(): - if node.op_type == 'Mul' and self.has_constant_input(node, -10000): - mask_path = self.match_parent_path(node, ['Sub', 'Cast', 'Slice', 'Unsqueeze'], [0, 1, 0, 0]) + if node.op_type == "Mul" and self.has_constant_input(node, -10000): + mask_path = self.match_parent_path(node, ["Sub", "Cast", "Slice", "Unsqueeze"], [0, 1, 0, 0]) if mask_path is None: continue sub_node, cast_node, slice_node, unsqueeze_node = mask_path @@ -292,24 +352,30 @@ def fuse_mask(self): print("Cast input {} is not mask input {}".format(unsqueeze_node.input[0], mask_input_name)) continue - unsqueeze_added_1 = onnx.helper.make_node('Unsqueeze', - inputs=[mask_input_name], - outputs=['mask_fuse_unsqueeze1_output'], - name='Mask_UnSqueeze_1', - axes=[1]) - - unsqueeze_added_2 = onnx.helper.make_node('Unsqueeze', - inputs=['mask_fuse_unsqueeze1_output'], - outputs=['mask_fuse_unsqueeze2_output'], - name='Mask_UnSqueeze_2', - axes=[2]) - - #self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output') - cast_node_2 = onnx.helper.make_node('Cast', - inputs=['mask_fuse_unsqueeze2_output'], - outputs=['mask_fuse_cast_output']) + unsqueeze_added_1 = onnx.helper.make_node( + "Unsqueeze", + inputs=[mask_input_name], + outputs=["mask_fuse_unsqueeze1_output"], + name="Mask_UnSqueeze_1", + axes=[1], + ) + + unsqueeze_added_2 = onnx.helper.make_node( + "Unsqueeze", + inputs=["mask_fuse_unsqueeze1_output"], + outputs=["mask_fuse_unsqueeze2_output"], + name="Mask_UnSqueeze_2", + axes=[2], + ) + + # self.replace_node_input(cast_node, cast_node.input[0], 'mask_fuse_unsqueeze2_output') + cast_node_2 = onnx.helper.make_node( + "Cast", + inputs=["mask_fuse_unsqueeze2_output"], + outputs=["mask_fuse_cast_output"], + ) cast_node_2.attribute.extend([onnx.helper.make_attribute("to", 1)]) - self.replace_node_input(sub_node, sub_node.input[1], 'mask_fuse_cast_output') + self.replace_node_input(sub_node, sub_node.input[1], "mask_fuse_cast_output") nodes_to_remove.extend([slice_node, unsqueeze_node, cast_node]) self.add_node(unsqueeze_added_1) @@ -330,12 +396,33 @@ def remove_extra_reshape(self): for skiplayernorm_node in skiplayernorm_nodes: path = self.match_parent_path( skiplayernorm_node, - ['Add', 'Reshape', 'MatMul', 'Reshape', 'Gelu', 'Add', 'Reshape', 'MatMul', 'SkipLayerNormalization'], - [0, 0, 0, 0, 0, 0, 0, 0, 0]) + [ + "Add", + "Reshape", + "MatMul", + "Reshape", + "Gelu", + "Add", + "Reshape", + "MatMul", + "SkipLayerNormalization", + ], + [0, 0, 0, 0, 0, 0, 0, 0, 0], + ) if path is None: continue - add_1, reshape_1, matmul_1, reshape_2, gelu, add_2, reshape_3, matmul_2, skiplayernorm = path + ( + add_1, + reshape_1, + matmul_1, + reshape_2, + gelu, + add_2, + reshape_3, + matmul_2, + skiplayernorm, + ) = path add_2.input[0] = matmul_2.output[0] self.remove_node(reshape_3) matmul_1.input[0] = gelu.output[0] @@ -352,12 +439,35 @@ def remove_extra_reshape_2(self): for skiplayernorm_node in skiplayernorm_nodes: path = self.match_parent_path( skiplayernorm_node, - ['Add', 'Reshape', 'MatMul', 'Reshape', 'Gelu', 'Add', 'Reshape', 'MatMul', 'Reshape', 'SkipLayerNormalization'], - [None, 0, 0, 0, 0, 0, 0, 0, 0, 0]) # yapf: disable + [ + "Add", + "Reshape", + "MatMul", + "Reshape", + "Gelu", + "Add", + "Reshape", + "MatMul", + "Reshape", + "SkipLayerNormalization", + ], + [None, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ) # yapf: disable if path is None: continue - add_1, reshape_1, matmul_1, reshape_2, gelu, add_2, reshape_3, matmul_2, reshape_4, skiplayernorm = path + ( + add_1, + reshape_1, + matmul_1, + reshape_2, + gelu, + add_2, + reshape_3, + matmul_2, + reshape_4, + skiplayernorm, + ) = path matmul_2.input[0] = skiplayernorm.output[0] self.remove_node(reshape_4) diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py index 580b106223839..7455777273846 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert_tf.py @@ -1,15 +1,16 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- +import argparse import logging -import onnx import sys -import argparse -import numpy as np from collections import deque -from onnx import ModelProto, TensorProto, numpy_helper, helper + +import numpy as np +import onnx +from onnx import ModelProto, TensorProto, helper, numpy_helper from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) @@ -22,7 +23,7 @@ def __init__(self, model, num_heads, hidden_size): def remove_identity(self): nodes_to_remove = [] for node in self.nodes(): - if node.op_type == 'Identity': + if node.op_type == "Identity": if not self.find_graph_output(node.output[0]): self.replace_input_of_all_nodes(node.output[0], node.input[0]) nodes_to_remove.append(node) @@ -30,18 +31,27 @@ def remove_identity(self): logger.info(f"Removed Identity count: {len(nodes_to_remove)}") def match_mask_path(self, add_or_sub_before_softmax): - mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Reshape', 'Cast'], - [1, None, 1, 0]) + mask_nodes = self.match_parent_path( + add_or_sub_before_softmax, + ["Mul", "Sub", "Reshape", "Cast"], + [1, None, 1, 0], + ) if mask_nodes is not None: return mask_nodes - mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Slice', 'Unsqueeze'], - [1, 0, 1, 0, 0]) + mask_nodes = self.match_parent_path( + add_or_sub_before_softmax, + ["Mul", "Sub", "Cast", "Slice", "Unsqueeze"], + [1, 0, 1, 0, 0], + ) if mask_nodes is not None: return mask_nodes - mask_nodes = self.match_parent_path(add_or_sub_before_softmax, ['Mul', 'Sub', 'Cast', 'Unsqueeze', 'Unsqueeze'], - [1, None, 1, 0, 0]) + mask_nodes = self.match_parent_path( + add_or_sub_before_softmax, + ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], + [1, None, 1, 0, 0], + ) return mask_nodes @@ -81,21 +91,47 @@ def find_segment_ids(self, segment_embedding, input_ids): # If the segment id candidate is the same as the input_ids, try to assign alternative segment ids and simplify the graph if needed. segment_ids = nodes[0].input[1] _, segment_id_path, _ = self.match_parent_paths( - nodes[0], [(["ConstantOfShape", "Cast", "Concat", "Slice", "Cast", "Shape"], [1, 0, 0, 0, 0, 0]), - (["ConstantOfShape", "Cast", "Concat", "Unsqueeze", "Squeeze", "Slice", "Cast", "Shape" - ], [1, 0, 0, 0, 0, 0, 0, 0])], None) + nodes[0], + [ + ( + ["ConstantOfShape", "Cast", "Concat", "Slice", "Cast", "Shape"], + [1, 0, 0, 0, 0, 0], + ), + ( + [ + "ConstantOfShape", + "Cast", + "Concat", + "Unsqueeze", + "Squeeze", + "Slice", + "Cast", + "Shape", + ], + [1, 0, 0, 0, 0, 0, 0, 0], + ), + ], + None, + ) if segment_id_path and input_ids and input_ids == segment_id_path[-1].input[0]: logger.debug("Simplify semgent id path...") constantofshape_node = segment_id_path[0] graph_name = self.get_graph_by_node(constantofshape_node).name - self.add_node(helper.make_node('Shape', inputs=[input_ids], outputs=["input_shape"]), graph_name) + self.add_node( + helper.make_node("Shape", inputs=[input_ids], outputs=["input_shape"]), + graph_name, + ) constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0]) self.add_node( - helper.make_node('ConstantOfShape', - inputs=["input_shape"], - outputs=["zeros_for_input_shape"], - value=constantofshape_value), graph_name) + helper.make_node( + "ConstantOfShape", + inputs=["input_shape"], + outputs=["zeros_for_input_shape"], + value=constantofshape_value, + ), + graph_name, + ) segment_ids = "zeros_for_input_shape" return segment_ids @@ -117,12 +153,22 @@ def find_input_ids(self, word_embedding): def find_mask_input(self, excluded_graph_inputs): for node in self.nodes(): - if node.op_type == 'Softmax': - mask_path = self.match_parent_path(node, ['Add', 'Mul', 'Sub', 'Cast', 'Slice', 'Unsqueeze'], - [0, 1, None, 1, 0, 0]) + if node.op_type == "Softmax": + mask_path = self.match_parent_path( + node, + ["Add", "Mul", "Sub", "Cast", "Slice", "Unsqueeze"], + [0, 1, None, 1, 0, 0], + ) if mask_path is None: continue - add_node, mul_node, sub_node, cast_node, slice_node, unsqueeze_node = mask_path + ( + add_node, + mul_node, + sub_node, + cast_node, + slice_node, + unsqueeze_node, + ) = mask_path if self.has_constant_input(mul_node, -10000) and self.has_constant_input(sub_node, 1): graph_inputs = self.get_graph_inputs(sub_node, recursive=True) inputs = [input for input in graph_inputs if input not in excluded_graph_inputs] @@ -134,24 +180,47 @@ def find_mask_input(self, excluded_graph_inputs): # Duplicated input found. Try to simplify the graph. path_to_be_simplified = self.match_parent_path( mask_path[-1], - ["ConstantOfShape", "Cast", "Concat", "Unsqueeze", "Squeeze", "Slice", "Cast", "Shape"], - [0, 0, 0, 0, 0, 0, 0, 0]) + [ + "ConstantOfShape", + "Cast", + "Concat", + "Unsqueeze", + "Squeeze", + "Slice", + "Cast", + "Shape", + ], + [0, 0, 0, 0, 0, 0, 0, 0], + ) duplicated_inputs = [input for input in graph_inputs if input in excluded_graph_inputs] # Simplify graph for dynamic axes. - if path_to_be_simplified and duplicated_inputs and len( - duplicated_inputs) == 1 and duplicated_inputs[0] == path_to_be_simplified[-1].input[0]: + if ( + path_to_be_simplified + and duplicated_inputs + and len(duplicated_inputs) == 1 + and duplicated_inputs[0] == path_to_be_simplified[-1].input[0] + ): logger.debug("Simplify semgent id path...") constantofshape_node = path_to_be_simplified[0] constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0]) graph_name = self.get_graph_by_node(constantofshape_node).name self.add_node( - helper.make_node('Shape', inputs=[duplicated_inputs[0]], outputs=["input_shape_for_mask"]), - graph_name) + helper.make_node( + "Shape", + inputs=[duplicated_inputs[0]], + outputs=["input_shape_for_mask"], + ), + graph_name, + ) self.add_node( - helper.make_node('ConstantOfShape', - inputs=["input_shape_for_mask"], - outputs=[unsqueeze_node.input[0]], - value=constantofshape_value), graph_name) + helper.make_node( + "ConstantOfShape", + inputs=["input_shape_for_mask"], + outputs=[unsqueeze_node.input[0]], + value=constantofshape_value, + ), + graph_name, + ) return unsqueeze_node.input[0] return None @@ -173,7 +242,7 @@ def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embe self.bert_inputs = [input_ids, segment_ids, mask_input] - mask_index = self.create_node_name('mask_index') + mask_index = self.create_node_name("mask_index") self.attention_mask.set_mask_indice(mask_input, mask_index) if self.find_graph_input(input_ids).type.tensor_type.elem_type != TensorProto.INT32: @@ -189,9 +258,9 @@ def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embe else: mask_input, mask_input_cast_node = self.utils.cast_input_to_int32(mask_input) - embed_output = self.create_node_name('embed_output') + embed_output = self.create_node_name("embed_output") embed_node = onnx.helper.make_node( - 'EmbedLayerNormalization', + "EmbedLayerNormalization", inputs=[ input_ids, segment_ids, @@ -200,10 +269,11 @@ def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embe segment_embedding, normalize_node.input[1], # gamma normalize_node.input[2], # beta - mask_input + mask_input, ], outputs=[embed_output, mask_index], - name="EmbedLayer") + name="EmbedLayer", + ) embed_node.domain = "com.microsoft" self.replace_input_of_all_nodes(normalize_node.output[0], embed_output) self.add_node(embed_node, self.get_graph_by_node(normalize_node).name) @@ -217,8 +287,12 @@ def process_embedding(self): layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization") for layer_norm_node in layer_norm_nodes: - pos_embed_path = self.match_parent_path(layer_norm_node, ['Add', 'Reshape', 'Slice'], [0, 1, 0], - output_name_to_node) + pos_embed_path = self.match_parent_path( + layer_norm_node, + ["Add", "Reshape", "Slice"], + [0, 1, 0], + output_name_to_node, + ) if pos_embed_path is None: continue @@ -240,7 +314,8 @@ def process_embedding(self): embeddings = self.get_2d_initializers_from_parent_subgraphs(first_parent) if len(embeddings) != 2: logger.warning( - "Failed to find two embeddings (word and segment) from Add node. Found {}".format(embeddings)) + "Failed to find two embeddings (word and segment) from Add node. Found {}".format(embeddings) + ) return word_embedding = None @@ -258,7 +333,12 @@ def process_embedding(self): return logger.info("Create Embedding node") - self.create_embedding_subgraph(layer_norm_node, word_embedding, segment_embedding, position_embedding) + self.create_embedding_subgraph( + layer_norm_node, + word_embedding, + segment_embedding, + position_embedding, + ) # Prune graph to remove those original embedding nodes. self.prune_graph() break @@ -291,51 +371,65 @@ def fuse_attention(self): for normalize_node in start_nodes: graph_name = self.get_graph_by_node(normalize_node).name # SkipLayerNormalization has two inputs, and one of them is the root input for attention. - if normalize_node.op_type == 'LayerNormalization': - add_before_layernorm = self.match_parent(normalize_node, 'Add', 0) + if normalize_node.op_type == "LayerNormalization": + add_before_layernorm = self.match_parent(normalize_node, "Add", 0) if add_before_layernorm is not None: normalize_node = add_before_layernorm else: continue parent = self.get_parent(normalize_node, 1) - if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]: + if parent is None or parent.op_type not in [ + "SkipLayerNormalization", + "LayerNormalization", + "Reshape", + ]: parent = self.get_parent(normalize_node, 0) - if parent is None or parent.op_type not in ["SkipLayerNormalization", "LayerNormalization", "Reshape"]: + if parent is None or parent.op_type not in [ + "SkipLayerNormalization", + "LayerNormalization", + "Reshape", + ]: logger.debug("Failed to match parent of normalize_node") continue - qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], - [0, 0, 0, 0, 0]) + qkv_nodes = self.match_parent_path( + normalize_node, + ["Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [0, 0, 0, 0, 0], + ) if qkv_nodes is None: - qkv_nodes = self.match_parent_path(normalize_node, ['MatMul', 'Reshape', 'Transpose', 'MatMul'], - [1, 0, 0, 0]) + qkv_nodes = self.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0], + ) if qkv_nodes is None: - qkv_nodes = self.match_parent_path(normalize_node, ['Add', 'Einsum', 'Einsum'], [0, 0, 0]) + qkv_nodes = self.match_parent_path(normalize_node, ["Add", "Einsum", "Einsum"], [0, 0, 0]) if qkv_nodes is None: logger.debug("Failed to match qkv nodes") continue matmul_qkv = qkv_nodes[-1] - v_nodes = self.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0]) + v_nodes = self.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0]) if v_nodes is None: - v_nodes = self.match_parent_path(matmul_qkv, ['Add', 'Einsum'], [1, 0]) + v_nodes = self.match_parent_path(matmul_qkv, ["Add", "Einsum"], [1, 0]) if v_nodes is None: logger.debug("Failed to match v path") continue add_v = v_nodes[-2] matmul_v = v_nodes[-1] - qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', "Mul", 'MatMul'], [0, 0, 0, 0]) + qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0]) if qk_nodes is None: - qk_nodes = self.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'Einsum'], [0, 0, 0]) + qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Einsum"], [0, 0, 0]) if qk_nodes is None: logger.debug("Failed to match qk_paths") continue matmul_qk = qk_nodes[-1] - q_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [0, 0, 0, 0]) + q_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0]) if q_nodes is None: - q_nodes = self.match_parent_path(matmul_qk, ['Add', 'Einsum'], [0, 0]) + q_nodes = self.match_parent_path(matmul_qk, ["Add", "Einsum"], [0, 0]) if q_nodes is None: logger.debug("Failed to match q path") continue @@ -343,9 +437,9 @@ def fuse_attention(self): add_q = q_nodes[-2] matmul_q = q_nodes[-1] - k_nodes = self.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Add', 'MatMul'], [1, 0, 0, 0]) + k_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0]) if k_nodes is None: - k_nodes = self.match_parent_path(matmul_qk, ['Mul', 'Add', 'Einsum'], [1, 0, 0]) + k_nodes = self.match_parent_path(matmul_qk, ["Mul", "Add", "Einsum"], [1, 0, 0]) if k_nodes is None: logger.debug("Failed to match k path") continue @@ -363,15 +457,23 @@ def fuse_attention(self): continue # add a squeeze node to convert a 3-d mask to 2-d - squeeze_node = self.match_parent_path(mask_nodes[-1], ['Squeeze'], [0]) or self.match_parent_path( - mask_nodes[-1], ['Expand'], [0]) + squeeze_node = self.match_parent_path(mask_nodes[-1], ["Squeeze"], [0]) or self.match_parent_path( + mask_nodes[-1], ["Expand"], [0] + ) squeeze_node_name = "Squeeze_3d_to_2d_mask" squeeze_output_name = squeeze_node_name + "_output" if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None: mask_input = mask_nodes[-1].input[1] self.add_node( - helper.make_node("Squeeze", [mask_input], [squeeze_output_name], squeeze_node_name, axes=[1]), - graph_name) + helper.make_node( + "Squeeze", + [mask_input], + [squeeze_output_name], + squeeze_node_name, + axes=[1], + ), + graph_name, + ) mask_nodes[-1].input[0] = squeeze_output_name is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node) @@ -380,37 +482,63 @@ def fuse_attention(self): logger.debug("Create an Attention node.") # For tf models, q and v are flipped. - attention_node = self.attention_fusion.create_attention_node(mask_index, matmul_k, matmul_q, matmul_v, - add_k, add_q, add_v, self.num_heads, - self.hidden_size, parent.output[0], - qkv_nodes[2].output[0], None) + attention_node = self.attention_fusion.create_attention_node( + mask_index, + matmul_k, + matmul_q, + matmul_v, + add_k, + add_q, + add_v, + self.num_heads, + self.hidden_size, + parent.output[0], + qkv_nodes[2].output[0], + None, + ) if attention_node is None: continue - if qkv_nodes[1].op_type == 'Einsum': + if qkv_nodes[1].op_type == "Einsum": # add reshape before einsum - tensor = helper.make_tensor(name=qkv_nodes[1].name + "_newshape", - data_type=TensorProto.INT64, - dims=[4], - vals=np.int64( - [[0, 0, self.num_heads, - int(self.hidden_size / self.num_heads)]]).tobytes(), - raw=True) + tensor = helper.make_tensor( + name=qkv_nodes[1].name + "_newshape", + data_type=TensorProto.INT64, + dims=[4], + vals=np.int64( + [ + [ + 0, + 0, + self.num_heads, + int(self.hidden_size / self.num_heads), + ] + ] + ).tobytes(), + raw=True, + ) self.add_initializer(tensor, graph_name) - reshape_ = helper.make_node("Reshape", - inputs=[attention_node.output[0], qkv_nodes[1].name + "_newshape"], - outputs=[qkv_nodes[1].name + "_reshape_output"], - name=qkv_nodes[1].name + "_reshape") + reshape_ = helper.make_node( + "Reshape", + inputs=[ + attention_node.output[0], + qkv_nodes[1].name + "_newshape", + ], + outputs=[qkv_nodes[1].name + "_reshape_output"], + name=qkv_nodes[1].name + "_reshape", + ) qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output" self.add_node(reshape_, graph_name) - if parent.op_type == 'Reshape': + if parent.op_type == "Reshape": # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1] - tensor = helper.make_tensor(name=parent.name + "_modified", - data_type=TensorProto.INT64, - dims=[3], - vals=np.int64([[1, -1, hidden_size]]).tobytes(), - raw=True) + tensor = helper.make_tensor( + name=parent.name + "_modified", + data_type=TensorProto.INT64, + dims=[3], + vals=np.int64([[1, -1, hidden_size]]).tobytes(), + raw=True, + ) self.add_initializer(tensor, graph_name) parent.input[1] = parent.name + "_modified" @@ -450,7 +578,7 @@ def skip_reshape(self): def remove_reshape_before_first_attention(self): attention_nodes = self.get_nodes_by_op_type("Attention") for attention_node in attention_nodes: - path = self.match_parent_path(attention_node, ['Reshape', 'EmbedLayerNormalization'], [0, 0]) + path = self.match_parent_path(attention_node, ["Reshape", "EmbedLayerNormalization"], [0, 0]) if path is None: continue logger.info("Remove Reshape before first Attention node.") diff --git a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py index 3c71d74f6a70a..4f922820bbfa9 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py +++ b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py @@ -1,13 +1,14 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import logging + import onnx -from onnx_model_bert import BertOnnxModel -from fusion_gpt_attention_no_past import FusionGptAttentionNoPast from fusion_gpt_attention import FusionGptAttention from fusion_gpt_attention_megatron import FusionGptAttentionMegatron +from fusion_gpt_attention_no_past import FusionGptAttentionNoPast +from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) @@ -37,31 +38,38 @@ def postprocess(self): reshape_count = 0 for gemm_node in self.get_nodes_by_op_type("Gemm"): - reshape_after_gemm = self.find_first_child_by_type(gemm_node, - 'Reshape', - input_name_to_nodes, - recursive=False) + reshape_after_gemm = self.find_first_child_by_type( + gemm_node, "Reshape", input_name_to_nodes, recursive=False + ) return_indice = [] - nodes = self.match_parent_path(gemm_node, ['Reshape', 'FastGelu'], [0, 0], output_name_to_node) + nodes = self.match_parent_path(gemm_node, ["Reshape", "FastGelu"], [0, 0], output_name_to_node) if nodes is None: - nodes = self.match_parent_path(gemm_node, ['Reshape', 'LayerNormalization'], [0, 0], - output_name_to_node) + nodes = self.match_parent_path( + gemm_node, + ["Reshape", "LayerNormalization"], + [0, 0], + output_name_to_node, + ) if nodes is None: continue (reshape_before_gemm, root_node) = nodes - matmul_node_name = self.create_node_name('MatMul', 'FullyConnect_MatMul') - matmul_node = onnx.helper.make_node('MatMul', - inputs=[matmul_node_name + "_input", gemm_node.input[1]], - outputs=[matmul_node_name + "_output"], - name=matmul_node_name) - - add_node_name = self.create_node_name('Add', 'FullyConnect_Add') - add_node = onnx.helper.make_node('Add', - inputs=[matmul_node_name + "_output", gemm_node.input[2]], - outputs=[add_node_name + "_output"], - name=add_node_name) + matmul_node_name = self.create_node_name("MatMul", "FullyConnect_MatMul") + matmul_node = onnx.helper.make_node( + "MatMul", + inputs=[matmul_node_name + "_input", gemm_node.input[1]], + outputs=[matmul_node_name + "_output"], + name=matmul_node_name, + ) + + add_node_name = self.create_node_name("Add", "FullyConnect_Add") + add_node = onnx.helper.make_node( + "Add", + inputs=[matmul_node_name + "_output", gemm_node.input[2]], + outputs=[add_node_name + "_output"], + name=add_node_name, + ) self.replace_input_of_all_nodes(reshape_after_gemm.output[0], add_node_name + "_output") diff --git a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py index c99817c410c3e..dc8f6810914a7 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_tnlr.py +++ b/onnxruntime/python/tools/transformers/onnx_model_tnlr.py @@ -1,14 +1,15 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import logging -from fusion_attention import FusionAttention, AttentionMask +from typing import Union + +from fusion_attention import AttentionMask, FusionAttention from fusion_utils import NumpyHelper -from onnx import helper, numpy_helper, TensorProto, NodeProto +from onnx import NodeProto, TensorProto, helper, numpy_helper from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel -from typing import Union logger = logging.getLogger(__name__) @@ -18,11 +19,27 @@ class FusionTnlrAttention(FusionAttention): Fuse TNLR Attention subgraph into one Attention node. TNLR Attention has extra addtion after qk nodes and adopts [S, B, NH] as I/O shape. """ - def __init__(self, model: OnnxModel, hidden_size: int, num_heads: int, attention_mask: AttentionMask): + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + attention_mask: AttentionMask, + ): super().__init__(model, hidden_size, num_heads, attention_mask) - def create_attention_node(self, mask_index: str, matmul: NodeProto, add: NodeProto, num_heads: int, - hidden_size: int, input: str, output: str, add_qk_str: str) -> Union[NodeProto, None]: + def create_attention_node( + self, + mask_index: str, + matmul: NodeProto, + add: NodeProto, + num_heads: int, + hidden_size: int, + input: str, + output: str, + add_qk_str: str, + ) -> Union[NodeProto, None]: assert num_heads > 0 if hidden_size > 0 and (hidden_size % num_heads) != 0: @@ -38,27 +55,35 @@ def create_attention_node(self, mask_index: str, matmul: NodeProto, add: NodePro qkv_weight = NumpyHelper.to_array(weight) qkv_bias = NumpyHelper.to_array(bias) - attention_node_name = self.model.create_node_name('Attention') + attention_node_name = self.model.create_node_name("Attention") - weight = helper.make_tensor(name=attention_node_name + '_qkv_weight', - data_type=TensorProto.FLOAT, - dims=[hidden_size, 3 * hidden_size], - vals=qkv_weight.flatten().tolist()) + weight = helper.make_tensor( + name=attention_node_name + "_qkv_weight", + data_type=TensorProto.FLOAT, + dims=[hidden_size, 3 * hidden_size], + vals=qkv_weight.flatten().tolist(), + ) # Sometimes weights and bias are stored in fp16 if weight.data_type == 10: weight.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(weight).astype(np.float16), weight.name)) self.model.add_initializer(weight, self.this_graph_name) - bias = helper.make_tensor(name=attention_node_name + '_qkv_bias', - data_type=TensorProto.FLOAT, - dims=[3 * hidden_size], - vals=qkv_bias.flatten().tolist()) + bias = helper.make_tensor( + name=attention_node_name + "_qkv_bias", + data_type=TensorProto.FLOAT, + dims=[3 * hidden_size], + vals=qkv_bias.flatten().tolist(), + ) if bias.data_type == 10: bias.CopyFrom(numpy_helper.from_array(NumpyHelper.to_array(bias).astype(np.float16), bias.name)) self.model.add_initializer(bias, self.this_graph_name) - attention_inputs = [input, attention_node_name + '_qkv_weight', attention_node_name + '_qkv_bias'] + attention_inputs = [ + input, + attention_node_name + "_qkv_weight", + attention_node_name + "_qkv_bias", + ] if mask_index is not None: attention_inputs.append(mask_index) else: @@ -68,10 +93,12 @@ def create_attention_node(self, mask_index: str, matmul: NodeProto, add: NodePro attention_inputs.append("") attention_inputs.append(add_qk_str) - attention_node = helper.make_node('Attention', - inputs=attention_inputs, - outputs=[output], - name=attention_node_name) + attention_node = helper.make_node( + "Attention", + inputs=attention_inputs, + outputs=[output], + name=attention_node_name, + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)]) @@ -81,13 +108,15 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern start_node = normalize_node - if normalize_node.op_type != 'SkipLayerNormalization': + if normalize_node.op_type != "SkipLayerNormalization": return # SkipLayerNormalization has two inputs, and one of them is the root input for attention. - qkv_nodes = self.model.match_parent_path(start_node, - ['Where', 'Add', 'MatMul', 'Reshape', 'Transpose', 'MatMul'], - [1, 1, 1, 0, 0, 0]) + qkv_nodes = self.model.match_parent_path( + start_node, + ["Where", "Add", "MatMul", "Reshape", "Transpose", "MatMul"], + [1, 1, 1, 0, 0, 0], + ) if qkv_nodes is not None: (_, _, matmul_below, reshape_qkv, transpose_qkv, matmul_qkv) = qkv_nodes else: @@ -106,35 +135,44 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): root_input = other_inputs[0] - v_nodes = self.model.match_parent_path(matmul_qkv, ['Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'], - [1, 0, 0, 0, 1]) + v_nodes = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "Slice", "Add", "MatMul"], + [1, 0, 0, 0, 1], + ) if v_nodes is None: return (_, _, _, add, matmul) = v_nodes - upper_nodes = self.model.match_parent_path(matmul, ['Transpose'], [0]) + upper_nodes = self.model.match_parent_path(matmul, ["Transpose"], [0]) transpose = upper_nodes[0] - qk_nodes = self.model.match_parent_path(matmul_qkv, ['Softmax', 'Add', 'MatMul'], [0, 0, 0]) + qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0]) if qk_nodes is None: return (_, add_qk, matmul_qk) = qk_nodes - q_nodes = self.model.match_parent_path(matmul_qk, ['Mul', 'Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'], - [0, 0, 0, 0, 0, 1]) + q_nodes = self.model.match_parent_path( + matmul_qk, + ["Mul", "Transpose", "Reshape", "Slice", "Add", "MatMul"], + [0, 0, 0, 0, 0, 1], + ) if q_nodes is None: return add = q_nodes[-2] matmul = q_nodes[-1] - k_nodes = self.model.match_parent_path(matmul_qk, ['Transpose', 'Reshape', 'Slice', 'Add', 'MatMul'], - [1, 0, 0, 0, 1]) + k_nodes = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Reshape", "Slice", "Add", "MatMul"], + [1, 0, 0, 0, 1], + ) if k_nodes is None: return add = k_nodes[-2] matmul = k_nodes[-1] - extra_add_qk_nodes = self.model.match_parent_path(add_qk, ['Reshape', 'Where'], [1, 0]) + extra_add_qk_nodes = self.model.match_parent_path(add_qk, ["Reshape", "Where"], [1, 0]) if extra_add_qk_nodes is None: return @@ -143,8 +181,16 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): attention_last_node = reshape_qkv # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads # the input_hidden_size represents the input hidden size, this is used as needed but hidden sizes for Q, K are extracted appropriately - new_node = self.create_attention_node(mask_index, matmul, add, self.num_heads, self.hidden_size, root_input, - attention_last_node.output[0], extra_add_qk_nodes[0].input[0]) + new_node = self.create_attention_node( + mask_index, + matmul, + add, + self.num_heads, + self.hidden_size, + root_input, + attention_last_node.output[0], + extra_add_qk_nodes[0].input[0], + ) if new_node is None: return @@ -152,9 +198,13 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[new_node.name] = self.this_graph_name # Add a transpose node after the attention node - back_transpose = helper.make_node("Transpose", ["back_transpose_in_" + new_node.name], [new_node.output[0]], - "back_transpose_" + new_node.name, - perm=[1, 0, 2]) + back_transpose = helper.make_node( + "Transpose", + ["back_transpose_in_" + new_node.name], + [new_node.output[0]], + "back_transpose_" + new_node.name, + perm=[1, 0, 2], + ) self.model.add_node(back_transpose, self.this_graph_name) new_node.input[0] = transpose.input[0] new_node.output[0] = "back_transpose_in_" + new_node.name @@ -166,7 +216,7 @@ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): self.nodes_to_remove.extend(v_nodes) # Use prune graph to remove mask nodes since they are shared by all attention nodes. - #self.nodes_to_remove.extend(mask_nodes) + # self.nodes_to_remove.extend(mask_nodes) self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index d865381e0055f..1db539cae8e0c 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # Convert Bert ONNX model converted from TensorFlow or exported from PyTorch to use Attention, Gelu, # SkipLayerNormalization and EmbedLayerNormalization ops to optimize @@ -17,19 +17,20 @@ # (2) Change input data type from int64 to int32. # (3) Some model cannot be handled by OnnxRuntime, and you can modify this script to get optimized model. +import argparse import logging -import coloredlogs import os -import argparse from typing import Dict, Optional -from onnx import load_model, ModelProto + +import coloredlogs +from fusion_options import FusionOptions +from onnx import ModelProto, load_model from onnx_model_bart import BartOnnxModel from onnx_model_bert import BertOnnxModel -from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_bert_keras import BertOnnxModelKeras +from onnx_model_bert_tf import BertOnnxModelTF from onnx_model_gpt2 import Gpt2OnnxModel from onnx_model_tnlr import TnlrOnnxModel -from fusion_options import FusionOptions logger = logging.getLogger(__name__) @@ -40,16 +41,22 @@ "bert_tf": (BertOnnxModelTF, "tf2onnx", 0), "bert_keras": (BertOnnxModelKeras, "keras2onnx", 0), "gpt2": (Gpt2OnnxModel, "pytorch", 1), - "gpt2_tf": (Gpt2OnnxModel, 'tf2onnx', 0), # might add a class for GPT2OnnxModel for TF later. + "gpt2_tf": ( + Gpt2OnnxModel, + "tf2onnx", + 0, + ), # might add a class for GPT2OnnxModel for TF later. "tnlr": (TnlrOnnxModel, "pytorch", 1), } -def optimize_by_onnxruntime(onnx_model_path: str, - use_gpu: bool = False, - optimized_model_path: Optional[str] = None, - opt_level: Optional[int] = 99, - disabled_optimizers=[]) -> str: +def optimize_by_onnxruntime( + onnx_model_path: str, + use_gpu: bool = False, + optimized_model_path: Optional[str] = None, + opt_level: Optional[int] = 99, + disabled_optimizers=[], +) -> str: """ Use onnxruntime to optimize model. @@ -65,7 +72,7 @@ def optimize_by_onnxruntime(onnx_model_path: str, assert opt_level in [1, 2, 99] import onnxruntime - if use_gpu and 'CUDAExecutionProvider' not in onnxruntime.get_available_providers(): + if use_gpu and "CUDAExecutionProvider" not in onnxruntime.get_available_providers(): logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path @@ -78,7 +85,7 @@ def optimize_by_onnxruntime(onnx_model_path: str, sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL if optimized_model_path is None: - path_prefix = onnx_model_path[:-5] #remove .onnx suffix + path_prefix = onnx_model_path[:-5] # remove .onnx suffix optimized_model_path = "{}_o{}_{}.onnx".format(path_prefix, opt_level, "gpu" if use_gpu else "cpu") sess_options.optimized_model_filepath = optimized_model_path @@ -88,27 +95,28 @@ def optimize_by_onnxruntime(onnx_model_path: str, kwargs["disabled_optimizers"] = disabled_optimizers if not use_gpu: - session = onnxruntime.InferenceSession(onnx_model_path, - sess_options, - providers=['CPUExecutionProvider'], - **kwargs) + session = onnxruntime.InferenceSession( + onnx_model_path, sess_options, providers=["CPUExecutionProvider"], **kwargs + ) else: - session = onnxruntime.InferenceSession(onnx_model_path, sess_options, - providers=['CUDAExecutionProvider'], - **kwargs) - assert 'CUDAExecutionProvider' in session.get_providers() # Make sure there is GPU + session = onnxruntime.InferenceSession( + onnx_model_path, sess_options, providers=["CUDAExecutionProvider"], **kwargs + ) + assert "CUDAExecutionProvider" in session.get_providers() # Make sure there is GPU assert os.path.exists(optimized_model_path) and os.path.isfile(optimized_model_path) logger.debug("Save optimized model by onnxruntime to {}".format(optimized_model_path)) return optimized_model_path -def optimize_by_fusion(model: ModelProto, - model_type: str = 'bert', - num_heads: int = 0, - hidden_size: int = 0, - optimization_options: Optional[FusionOptions] = None): - """ Optimize Model by graph fusion logic. +def optimize_by_fusion( + model: ModelProto, + model_type: str = "bert", + num_heads: int = 0, + hidden_size: int = 0, + optimization_options: Optional[FusionOptions] = None, +): + """Optimize Model by graph fusion logic. Note that ONNXRuntime graph optimizations (like constant folding) will not be applied. So it is better to enable constant folding during exporting ONNX model, or run optimize_by_onnxruntime on the model first like optimize_model. @@ -148,20 +156,23 @@ def optimize_by_fusion(model: ModelProto, optimizer.model.producer_name = "onnxruntime.transformers" from onnxruntime import __version__ as onnxruntime_version + optimizer.model.producer_version = onnxruntime_version return optimizer -def optimize_model(input: str, - model_type: str = 'bert', - num_heads: int = 0, - hidden_size: int = 0, - optimization_options: Optional[FusionOptions] = None, - opt_level: int = None, - use_gpu: bool = False, - only_onnxruntime: bool = False): - """ Optimize Model by OnnxRuntime and/or python fusion logic. +def optimize_model( + input: str, + model_type: str = "bert", + num_heads: int = 0, + hidden_size: int = 0, + optimization_options: Optional[FusionOptions] = None, + opt_level: int = None, + use_gpu: bool = False, + only_onnxruntime: bool = False, +): + """Optimize Model by OnnxRuntime and/or python fusion logic. ONNX Runtime has graph optimizations (https://onnxruntime.ai/docs/resources/graph-optimizations.html). However, the coverage is limited. We also have graph fusions that implemented in Python to improve the coverage. @@ -210,14 +221,22 @@ def optimize_model(input: str, temp_model_path = None if opt_level > 1: # Disable some optimizers that might cause failure in symbolic shape inference or attention fusion. - disabled_optimizers = [] if only_onnxruntime else [ - 'MatMulScaleFusion', 'MatMulAddFusion' - 'SimplifiedLayerNormFusion', 'GemmActivationFusion', 'BiasSoftmaxFusion' - ] - temp_model_path = optimize_by_onnxruntime(input, - use_gpu=use_gpu, - opt_level=opt_level, - disabled_optimizers=disabled_optimizers) + disabled_optimizers = ( + [] + if only_onnxruntime + else [ + "MatMulScaleFusion", + "MatMulAddFusion" "SimplifiedLayerNormFusion", + "GemmActivationFusion", + "BiasSoftmaxFusion", + ] + ) + temp_model_path = optimize_by_onnxruntime( + input, + use_gpu=use_gpu, + opt_level=opt_level, + disabled_optimizers=disabled_optimizers, + ) elif opt_level == 1: # basic optimizations (like constant folding and cast elimation) are not specified to exection provider. # CPU provider is used here so that there is no extra node for GPU memory copy. @@ -258,88 +277,89 @@ def get_fusion_statistics(optimized_model_path: str) -> Dict[str, int]: def _parse_arguments(): parser = argparse.ArgumentParser( - description= - 'Graph optimization tool for ONNX Runtime. It transforms ONNX graph to use optimized operators for Transformer models.' + description="Graph optimization tool for ONNX Runtime. It transforms ONNX graph to use optimized operators for Transformer models." ) - parser.add_argument('--input', required=True, type=str, help="input onnx model path") + parser.add_argument("--input", required=True, type=str, help="input onnx model path") - parser.add_argument('--output', required=True, type=str, help="optimized onnx model path") + parser.add_argument("--output", required=True, type=str, help="optimized onnx model path") - parser.add_argument('--model_type', - required=False, - type=str.lower, - default="bert", - choices=list(MODEL_TYPES.keys()), - help="Model type selected in the list: " + ", ".join(MODEL_TYPES.keys())) + parser.add_argument( + "--model_type", + required=False, + type=str.lower, + default="bert", + choices=list(MODEL_TYPES.keys()), + help="Model type selected in the list: " + ", ".join(MODEL_TYPES.keys()), + ) parser.add_argument( - '--num_heads', + "--num_heads", required=False, type=int, default=0, - help= - "number of attention heads like 12 for bert-base and 16 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly." + help="number of attention heads like 12 for bert-base and 16 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.", ) parser.add_argument( - '--hidden_size', + "--hidden_size", required=False, type=int, default=0, - help= - "hidden size like 768 for bert-base and 1024 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly." + help="hidden size like 768 for bert-base and 1024 for bert-large. Default is 0 to detect automatically for BERT. For other model type, this parameter need specify correctly.", ) parser.add_argument( - '--input_int32', + "--input_int32", required=False, - action='store_true', - help= - "Use int32 (instead of int64) inputs. It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT." + action="store_true", + help="Use int32 (instead of int64) inputs. It could avoid unnecessary data cast when EmbedLayerNormalization is fused for BERT.", ) parser.set_defaults(input_int32=False) parser.add_argument( - '--float16', + "--float16", required=False, - action='store_true', - help= - "Convert all weights and nodes in float32 to float16. It has potential loss in precision compared to mixed precision conversion (see convert_float_to_float16)." + action="store_true", + help="Convert all weights and nodes in float32 to float16. It has potential loss in precision compared to mixed precision conversion (see convert_float_to_float16).", ) parser.set_defaults(float16=False) FusionOptions.add_arguments(parser) - parser.add_argument('--verbose', required=False, action='store_true', help="show debug information.") + parser.add_argument("--verbose", required=False, action="store_true", help="show debug information.") parser.set_defaults(verbose=False) parser.add_argument( - '--use_gpu', + "--use_gpu", required=False, - action='store_true', - help="Use GPU for inference. Set this flag if your model is intended for GPU when opt_level > 1.") + action="store_true", + help="Use GPU for inference. Set this flag if your model is intended for GPU when opt_level > 1.", + ) parser.set_defaults(use_gpu=False) - parser.add_argument('--only_onnxruntime', - required=False, - action='store_true', - help="optimized by onnxruntime only, and no graph fusion in Python") + parser.add_argument( + "--only_onnxruntime", + required=False, + action="store_true", + help="optimized by onnxruntime only, and no graph fusion in Python", + ) parser.set_defaults(only_onnxruntime=False) parser.add_argument( - '--opt_level', + "--opt_level", required=False, type=int, choices=[0, 1, 2, 99], default=None, - help= - "onnxruntime optimization level. 0 will disable onnxruntime graph optimization. The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. Level 2 and 99 are intended for --only_onnxruntime." + help="onnxruntime optimization level. 0 will disable onnxruntime graph optimization. The recommended value is 1. When opt_level > 1 is used, optimized model for GPU might not run in CPU. Level 2 and 99 are intended for --only_onnxruntime.", ) - parser.add_argument('--use_external_data_format', - required=False, - action='store_true', - help="use external data format to store large model (>2GB)") + parser.add_argument( + "--use_external_data_format", + required=False, + action="store_true", + help="use external data format to store large model (>2GB)", + ) parser.set_defaults(use_external_data_format=False) args = parser.parse_args() @@ -349,9 +369,12 @@ def _parse_arguments(): def _setup_logger(verbose): if verbose: - coloredlogs.install(level='DEBUG', fmt='[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s') + coloredlogs.install( + level="DEBUG", + fmt="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s", + ) else: - coloredlogs.install(fmt='%(funcName)20s: %(message)s') + coloredlogs.install(fmt="%(funcName)20s: %(message)s") def main(): @@ -366,14 +389,16 @@ def main(): optimization_options = FusionOptions.parse(args) - optimizer = optimize_model(args.input, - args.model_type, - args.num_heads, - args.hidden_size, - opt_level=args.opt_level, - optimization_options=optimization_options, - use_gpu=args.use_gpu, - only_onnxruntime=args.only_onnxruntime) + optimizer = optimize_model( + args.input, + args.model_type, + args.num_heads, + args.hidden_size, + opt_level=args.opt_level, + optimization_options=optimization_options, + use_gpu=args.use_gpu, + only_onnxruntime=args.only_onnxruntime, + ) if args.float16: optimizer.convert_float_to_float16(keep_io_types=True) diff --git a/onnxruntime/python/tools/transformers/profiler.py b/onnxruntime/python/tools/transformers/profiler.py index ddd00e94ebb7a..9f41654af3533 100644 --- a/onnxruntime/python/tools/transformers/profiler.py +++ b/onnxruntime/python/tools/transformers/profiler.py @@ -1,9 +1,11 @@ -import os import argparse import json -import psutil +import os + import numpy +import psutil from onnx import TensorProto + """ This profiler tool could run a transformer model and print out the kernel time spent on each Node of the model. Example of profiling of longformer model: @@ -12,100 +14,144 @@ python profiler.py --input profile_2021-10-25_12-02-41.json """ -NODES_TYPE_CONTAINING_SUBGRAPH = ['Scan', 'Loop', 'If'] +NODES_TYPE_CONTAINING_SUBGRAPH = ["Scan", "Loop", "If"] def parse_arguments(argv=None): parser = argparse.ArgumentParser() - parser.add_argument('-i', '--input', required=False, - type=str, - help="Set the input file for reading the profile results") + parser.add_argument( + "-i", + "--input", + required=False, + type=str, + help="Set the input file for reading the profile results", + ) - parser.add_argument('-m', '--model', required=False, type=str, help="onnx model path to run profiling. Required when --input is not specified.") + parser.add_argument( + "-m", + "--model", + required=False, + type=str, + help="onnx model path to run profiling. Required when --input is not specified.", + ) - parser.add_argument('-b', '--batch_size', required=False, type=int, default=1, help="batch size of input") + parser.add_argument( + "-b", + "--batch_size", + required=False, + type=int, + default=1, + help="batch size of input", + ) - parser.add_argument('-s', - '--sequence_length', - required=False, - type=int, - default=32, - help="sequence length of input") + parser.add_argument( + "-s", + "--sequence_length", + required=False, + type=int, + default=32, + help="sequence length of input", + ) - parser.add_argument('--past_sequence_length', - required=False, - type=int, - default=1, - help="past sequence length for gpt2") + parser.add_argument( + "--past_sequence_length", + required=False, + type=int, + default=1, + help="past sequence length for gpt2", + ) - parser.add_argument('--global_length', - required=False, - type=int, - default=1, - help="number of global tokens for longformer") + parser.add_argument( + "--global_length", + required=False, + type=int, + default=1, + help="number of global tokens for longformer", + ) parser.add_argument( - '--samples', + "--samples", required=False, type=int, default=1000, - help="number of samples to test. Set it large enough to reduce the variance of performance result.") + help="number of samples to test. Set it large enough to reduce the variance of performance result.", + ) parser.add_argument( - '--threshold', + "--threshold", required=False, type=float, default=0.01, - help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.") - - parser.add_argument("--thread_num", required=False, type=int, default=-1, help="number of threads to use") - - parser.add_argument('--input_ids_name', - required=False, - type=str, - default=None, - help="input name for input IDs, for bert") - parser.add_argument('--segment_ids_name', - required=False, - type=str, - default=None, - help="input name for segment IDs, for bert") - parser.add_argument('--input_mask_name', - required=False, - type=str, - default=None, - help="input name for attention mask, for bert") - - parser.add_argument('--dummy_inputs', - required=False, - default='default', - choices=['bert', 'gpt2', 'longformer', 'default'], - help="Type of model inputs. The default will create dummy inputs with ones.") - - parser.add_argument('-g', '--use_gpu', required=False, action='store_true', help="use GPU") + help="Threshold of run time ratio among all nodes. Nodes with larger ratio will show in top expensive nodes.", + ) + + parser.add_argument( + "--thread_num", + required=False, + type=int, + default=-1, + help="number of threads to use", + ) + + parser.add_argument( + "--input_ids_name", + required=False, + type=str, + default=None, + help="input name for input IDs, for bert", + ) + parser.add_argument( + "--segment_ids_name", + required=False, + type=str, + default=None, + help="input name for segment IDs, for bert", + ) + parser.add_argument( + "--input_mask_name", + required=False, + type=str, + default=None, + help="input name for attention mask, for bert", + ) + + parser.add_argument( + "--dummy_inputs", + required=False, + default="default", + choices=["bert", "gpt2", "longformer", "default"], + help="Type of model inputs. The default will create dummy inputs with ones.", + ) + + parser.add_argument("-g", "--use_gpu", required=False, action="store_true", help="use GPU") parser.set_defaults(use_gpu=False) - parser.add_argument('--provider', - required=False, - type=str, - default='cuda', - help="Execution provider to use") + parser.add_argument( + "--provider", + required=False, + type=str, + default="cuda", + help="Execution provider to use", + ) parser.add_argument( - '--basic_optimization', + "--basic_optimization", required=False, - action='store_true', - help="Enable only basic graph optimizations. By default, all optimizations are enabled in OnnxRuntime") + action="store_true", + help="Enable only basic graph optimizations. By default, all optimizations are enabled in OnnxRuntime", + ) parser.set_defaults(basic_optimization=False) - parser.add_argument('--kernel_time_only', - required=False, - action='store_true', - help="Only include the kernel time and no fence time") + parser.add_argument( + "--kernel_time_only", + required=False, + action="store_true", + help="Only include the kernel time and no fence time", + ) parser.set_defaults(kernel_time_only=False) - parser.add_argument('-v', '--verbose', required=False, action='store_true') + parser.add_argument("-v", "--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) return parser.parse_args(argv) @@ -114,12 +160,14 @@ def parse_arguments(argv=None): def run_profile(onnx_model_path, use_gpu, provider, basic_optimization, thread_num, all_inputs): from benchmark_helper import create_onnxruntime_session - session = create_onnxruntime_session(onnx_model_path, - use_gpu, - provider, - enable_all_optimization=not basic_optimization, - num_threads=thread_num, - enable_profiling=True) + session = create_onnxruntime_session( + onnx_model_path, + use_gpu, + provider, + enable_all_optimization=not basic_optimization, + num_threads=thread_num, + enable_profiling=True, + ) for inputs in all_inputs: _ = session.run(None, inputs) @@ -199,7 +247,6 @@ def parse_kernel_results(sess_time, threshold=0): avg_time = duration / float(calls) lines.append(f"{duration:10d}\t{ratio * 100.0:5.2f}\t{calls:5d}\t{avg_time:8.1f}\t{kernel_name}") - # Group by operator op_time = {} for kernel_name, op_name in kernel_name_to_op_name.items(): @@ -237,8 +284,9 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): total = 0 for item in sess_time: if item["cat"] == "Node" and "dur" in item and "args" in item and "op_name" in item["args"]: - node_name = item["name"].replace("_kernel_time", "").replace("_fence_before", - "").replace("_fence_after", "") + node_name = ( + item["name"].replace("_kernel_time", "").replace("_fence_before", "").replace("_fence_after", "") + ) if "provider" in item["args"]: if item["args"]["provider"] == "CPUExecutionProvider": @@ -271,8 +319,9 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): # Output items in the original order. lines = [ - "\nNodes in the original order:", "-" * 64, - "Total(μs)\tTime%\tAcc %\tAvg(μs)\tCalls\tProvider\tNode" + "\nNodes in the original order:", + "-" * 64, + "Total(μs)\tTime%\tAcc %\tAvg(μs)\tCalls\tProvider\tNode", ] before_percentage = 0.0 for node_name in node_name_list: @@ -285,7 +334,6 @@ def parse_node_results(sess_time, kernel_time_only=False, threshold=0): lines.append( f"{duration:10d}\t{percentage:5.2f}\t{before_percentage:5.2f}\t{avg_time:8.1f}\t{calls:5d}\t{provider:8s}\t{node_name}" ) - # Output items with run time ratio > thresholds, and sorted by duration in the descending order. lines.append(f"\nTop expensive nodes with Time% >= {threshold*100:.2f}:") @@ -383,26 +431,30 @@ def group_node_results(sess_time, kernel_time_only, use_gpu): time_ratio = total_time / (total_kernel_time + total_fence_time) kernel_calls = op_kernel_records[op_name] avg_kernel_time = kernel_time / kernel_calls - lines.append(f"{total_time:10d}\t{time_ratio * 100.0:5.2f}\t{kernel_time:11d}\t{kernel_time_ratio * 100.0:5.2f}\t{kernel_calls:5d}\t{avg_kernel_time:14.1f}\t{fence_time:10d}\t{op_name}") + lines.append( + f"{total_time:10d}\t{time_ratio * 100.0:5.2f}\t{kernel_time:11d}\t{kernel_time_ratio * 100.0:5.2f}\t{kernel_calls:5d}\t{avg_kernel_time:14.1f}\t{fence_time:10d}\t{op_name}" + ) lines += ["", "Grouped by provider + operator"] lines.append("-" * 64) lines.append("Kernel(μs)\tProvider%\tCalls\tAvgKernel(μs)\tProvider\tOperator") for key, kernel_time in sorted(provider_op_kernel_time.items(), key=lambda x: x[1], reverse=True): - parts = key.split(':') + parts = key.split(":") provider = parts[0] - op_name = parts[1] - short_ep = provider.replace("ExecutionProvider", "") + op_name = parts[1] + short_ep = provider.replace("ExecutionProvider", "") calls = provider_op_kernel_records[key] avg_kernel_time = kernel_time / calls provider_time_ratio = kernel_time / provider_kernel_time[provider] - lines.append(f"{kernel_time:10d}\t{provider_time_ratio * 100.0:9.2f}\t{calls:5d}\t{avg_kernel_time:14.1f}\t{short_ep:8s}\t{op_name}") + lines.append( + f"{kernel_time:10d}\t{provider_time_ratio * 100.0:9.2f}\t{calls:5d}\t{avg_kernel_time:14.1f}\t{short_ep:8s}\t{op_name}" + ) return lines def get_dim_from_type_proto(dim): - return getattr(dim, dim.WhichOneof('value')) if type(dim.WhichOneof('value')) == str else None + return getattr(dim, dim.WhichOneof("value")) if type(dim.WhichOneof("value")) == str else None def get_shape_from_type_proto(type_proto): @@ -439,8 +491,11 @@ def create_dummy_inputs(onnx_model, batch_size, sequence_length, samples): elem_type = graph_input.type.tensor_type.elem_type assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64] - data_type = numpy.float32 if elem_type == TensorProto.FLOAT else ( - numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32) + data_type = ( + numpy.float32 + if elem_type == TensorProto.FLOAT + else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32) + ) data = numpy.ones(shape, dtype=data_type) dummy_inputs[graph_input.name] = data @@ -448,13 +503,15 @@ def create_dummy_inputs(onnx_model, batch_size, sequence_length, samples): return all_inputs -def create_bert_inputs(onnx_model, - batch_size, - sequence_length, - samples, - input_ids_name=None, - segment_ids_name=None, - input_mask_name=None): +def create_bert_inputs( + onnx_model, + batch_size, + sequence_length, + samples, + input_ids_name=None, + segment_ids_name=None, + input_mask_name=None, +): """Create dummy inputs for BERT model. Args: @@ -470,16 +527,19 @@ def create_bert_inputs(onnx_model, List[Dict]: list of inputs """ from bert_test_data import find_bert_inputs, generate_test_data + input_ids, segment_ids, input_mask = find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name) - all_inputs = generate_test_data(batch_size, - sequence_length, - test_cases=samples, - seed=123, - verbose=False, - input_ids=input_ids, - segment_ids=segment_ids, - input_mask=input_mask, - random_mask_length=False) + all_inputs = generate_test_data( + batch_size, + sequence_length, + test_cases=samples, + seed=123, + verbose=False, + input_ids=input_ids, + segment_ids=segment_ids, + input_mask=input_mask, + random_mask_length=False, + ) return all_inputs @@ -502,10 +562,10 @@ def create_gpt2_inputs(onnx_model, batch_size, sequence_length, past_sequence_le """ # The symbolic names shall be same as those used in Gpt2Helper.export_onnx(...) function. symbols = { - 'batch_size': batch_size, - 'seq_len': sequence_length, - 'past_seq_len': past_sequence_length, - 'total_seq_len': sequence_length + past_sequence_length + "batch_size": batch_size, + "seq_len": sequence_length, + "past_seq_len": past_sequence_length, + "total_seq_len": sequence_length + past_sequence_length, } dummy_inputs = {} @@ -520,8 +580,11 @@ def create_gpt2_inputs(onnx_model, batch_size, sequence_length, past_sequence_le elem_type = graph_input.type.tensor_type.elem_type assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64] - data_type = numpy.float32 if elem_type == TensorProto.FLOAT else ( - numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32) + data_type = ( + numpy.float32 + if elem_type == TensorProto.FLOAT + else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32) + ) data = numpy.ones(shape, dtype=data_type) dummy_inputs[graph_input.name] = data @@ -545,7 +608,7 @@ def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_len Returns: List[Dict]: list of inputs """ - symbols = {'batch_size': batch_size, 'sequence_length': sequence_length} + symbols = {"batch_size": batch_size, "sequence_length": sequence_length} dummy_inputs = {} for graph_input in onnx_model.get_graph_inputs_excluding_initializers(): @@ -559,8 +622,11 @@ def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_len elem_type = graph_input.type.tensor_type.elem_type assert elem_type in [TensorProto.FLOAT, TensorProto.INT32, TensorProto.INT64] - data_type = numpy.float32 if elem_type == TensorProto.FLOAT else ( - numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32) + data_type = ( + numpy.float32 + if elem_type == TensorProto.FLOAT + else (numpy.int64 if elem_type == TensorProto.INT64 else numpy.int32) + ) if "global" in graph_input.name: data = numpy.zeros(shape, dtype=data_type) @@ -572,6 +638,7 @@ def create_longformer_inputs(onnx_model, batch_size, sequence_length, global_len all_inputs = [dummy_inputs for _ in range(samples)] return all_inputs + def process_results(profile_file, args): profile_records = load_profile_json(profile_file) @@ -583,6 +650,7 @@ def process_results(profile_file, args): return lines + def run(args): num_threads = args.thread_num if args.thread_num > 0 else psutil.cpu_count(logical=False) @@ -592,31 +660,57 @@ def run(args): from onnx import load from onnx_model import OnnxModel + onnx_model = OnnxModel(load(args.model)) all_inputs = None - if args.dummy_inputs == 'bert': - all_inputs = create_bert_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples, - args.input_ids_name, args.segment_ids_name, args.input_mask_name) - elif args.dummy_inputs == 'gpt2': - all_inputs = create_gpt2_inputs(onnx_model, args.batch_size, args.sequence_length, args.past_sequence_length, - args.samples) - elif args.dummy_inputs == 'longformer': - all_inputs = create_longformer_inputs(onnx_model, args.batch_size, args.sequence_length, args.global_length, - args.samples) + if args.dummy_inputs == "bert": + all_inputs = create_bert_inputs( + onnx_model, + args.batch_size, + args.sequence_length, + args.samples, + args.input_ids_name, + args.segment_ids_name, + args.input_mask_name, + ) + elif args.dummy_inputs == "gpt2": + all_inputs = create_gpt2_inputs( + onnx_model, + args.batch_size, + args.sequence_length, + args.past_sequence_length, + args.samples, + ) + elif args.dummy_inputs == "longformer": + all_inputs = create_longformer_inputs( + onnx_model, + args.batch_size, + args.sequence_length, + args.global_length, + args.samples, + ) else: # default all_inputs = create_dummy_inputs(onnx_model, args.batch_size, args.sequence_length, args.samples) - profile_file = run_profile(args.model, args.use_gpu, args.provider, args.basic_optimization, args.thread_num, all_inputs) + profile_file = run_profile( + args.model, + args.use_gpu, + args.provider, + args.basic_optimization, + args.thread_num, + all_inputs, + ) return profile_file -if __name__ == '__main__': +if __name__ == "__main__": arguments = parse_arguments() print("Arguments", arguments) from benchmark_helper import setup_logger + setup_logger(arguments.verbose) if not arguments.input: diff --git a/onnxruntime/python/tools/transformers/quantize_helper.py b/onnxruntime/python/tools/transformers/quantize_helper.py index f7227a990b208..d7e9eb9718a9e 100644 --- a/onnxruntime/python/tools/transformers/quantize_helper.py +++ b/onnxruntime/python/tools/transformers/quantize_helper.py @@ -5,9 +5,10 @@ # -------------------------------------------------------------------------- import logging -import torch -import onnx import os + +import onnx +import torch from transformers.modeling_utils import Conv1D logger = logging.getLogger(__name__) @@ -22,9 +23,9 @@ def _conv1d_to_linear(module): def conv1d_to_linear(model): - '''in-place + """in-place This is for Dynamic Quantization, as Conv1D is not recognized by PyTorch, convert it to nn.Linear - ''' + """ logger.debug("replace Conv1D with Linear") for name in list(model._modules): module = model._modules[name] @@ -38,33 +39,37 @@ def conv1d_to_linear(model): def _get_size_of_pytorch_model(model): torch.save(model.state_dict(), "temp.p") size = os.path.getsize("temp.p") / (1024 * 1024) - os.remove('temp.p') + os.remove("temp.p") return size class QuantizeHelper: @staticmethod def quantize_torch_model(model, dtype=torch.qint8): - ''' + """ Usage: model = quantize_model(model) TODO: mix of in-place and return, but results are different - ''' + """ conv1d_to_linear(model) quantized_model = torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype) - logger.info(f'Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}') - logger.info(f'Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}') + logger.info(f"Size of full precision Torch model(MB):{_get_size_of_pytorch_model(model)}") + logger.info(f"Size of quantized Torch model(MB):{_get_size_of_pytorch_model(quantized_model)}") return quantized_model @staticmethod def quantize_onnx_model(onnx_model_path, quantized_model_path, use_external_data_format=False): - from onnxruntime.quantization import quantize_dynamic from pathlib import Path + + from onnxruntime.quantization import quantize_dynamic + Path(quantized_model_path).parent.mkdir(parents=True, exist_ok=True) - logger.info(f'Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path)/(1024*1024)}') - quantize_dynamic(onnx_model_path, - quantized_model_path, - use_external_data_format = use_external_data_format) + logger.info(f"Size of full precision ONNX model(MB):{os.path.getsize(onnx_model_path)/(1024*1024)}") + quantize_dynamic( + onnx_model_path, + quantized_model_path, + use_external_data_format=use_external_data_format, + ) logger.info(f"quantized model saved to:{quantized_model_path}") - #TODO: inlcude external data in total model size. - logger.info(f'Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path)/(1024*1024)}') + # TODO: inlcude external data in total model size. + logger.info(f"Size of quantized ONNX model(MB):{os.path.getsize(quantized_model_path)/(1024*1024)}") diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index d3a3939563eca..bbb84a6bf1aa8 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -1,24 +1,32 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import os import sys + import onnx # In ORT Package the symbolic_shape_infer.py is in ../tools file_path = os.path.dirname(__file__) if os.path.exists(os.path.join(file_path, "../tools/symbolic_shape_infer.py")): - sys.path.append(os.path.join(file_path, '../tools')) + sys.path.append(os.path.join(file_path, "../tools")) else: - sys.path.append(os.path.join(file_path, '..')) + sys.path.append(os.path.join(file_path, "..")) from symbolic_shape_infer import SymbolicShapeInference, get_shape_from_type_proto, sympy class SymbolicShapeInferenceHelper(SymbolicShapeInference): - def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_output_rank=False): + def __init__( + self, + model, + verbose=0, + int_max=2**31 - 1, + auto_merge=True, + guess_output_rank=False, + ): super().__init__(int_max, auto_merge, guess_output_rank, verbose) self.model_ = onnx.ModelProto() self.model_.CopyFrom(model) @@ -46,8 +54,16 @@ def _preprocess(self, in_mp): self.initializers_ = dict([(i.name, i) for i in self.out_mp_.graph.initializer]) self.known_vi_ = dict([(i.name, i) for i in list(self.out_mp_.graph.input)]) self.known_vi_.update( - dict([(i.name, onnx.helper.make_tensor_value_info(i.name, i.data_type, list(i.dims))) - for i in self.out_mp_.graph.initializer])) + dict( + [ + ( + i.name, + onnx.helper.make_tensor_value_info(i.name, i.data_type, list(i.dims)), + ) + for i in self.out_mp_.graph.initializer + ] + ) + ) # Override _get_sympy_shape() in symbolic_shape_infer.py to ensure shape inference by giving the actual value of dynamic axis def _get_sympy_shape(self, node, idx): @@ -66,7 +82,7 @@ def _get_sympy_shape(self, node, idx): return sympy_shape def get_edge_shape(self, edge): - assert (self.all_shapes_inferred_ == True) + assert self.all_shapes_inferred_ == True if edge not in self.known_vi_: print("Cannot retrive the shape of " + str(edge)) return None @@ -79,7 +95,7 @@ def get_edge_shape(self, edge): return shape def compare_shape(self, edge, edge_other): - assert (self.all_shapes_inferred_ == True) + assert self.all_shapes_inferred_ == True shape = self.get_edge_shape(edge) shape_other = self.get_edge_shape(edge_other) if shape is None or shape_other is None: diff --git a/onnxruntime/python/tools/transformers/shape_optimizer.py b/onnxruntime/python/tools/transformers/shape_optimizer.py index a58b8d49ba594..7174af0ac9ba0 100644 --- a/onnxruntime/python/tools/transformers/shape_optimizer.py +++ b/onnxruntime/python/tools/transformers/shape_optimizer.py @@ -1,32 +1,34 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # This tool is not used directly in bert optimization. It could assist developing the optimization script on the following senarios: # (1) It could simplify graph by removing many sub-graphs related to reshape. # (2) It could reduce extra inputs and outputs to fit other tools. The script compare_bert_results.py or bert_perf_test.py requires 3 inputs. -import sys import argparse -import numpy as np -from collections import deque -from typing import List -import onnx +import logging +import os import re +import sys import tempfile -import os -import logging +from collections import deque from datetime import datetime from pathlib import Path +from typing import List + +import numpy as np +import onnx from onnx import ModelProto, TensorProto, numpy_helper -import onnxruntime from onnx_model import OnnxModel +import onnxruntime + logger = logging.getLogger(__name__) -CONSTANT_SHAPE_NAME_PREFIX = 'constant_shape_opt__' -RESHAPE_INPUT_SHAPE_PREFIX = 'reshape_input_shape__' +CONSTANT_SHAPE_NAME_PREFIX = "constant_shape_opt__" +RESHAPE_INPUT_SHAPE_PREFIX = "reshape_input_shape__" class BertOnnxModelShapeOptimizer(OnnxModel): @@ -34,6 +36,7 @@ class BertOnnxModelShapeOptimizer(OnnxModel): This optimizer will replace Shape output or the shape input of Reshape node by initializer. Currently, it requires model inputs to have static shape. """ + def __init__(self, onnx_model): super().__init__(onnx_model.model) @@ -42,11 +45,13 @@ def add_shape_initializer(self, shape): Add an initializer for constant shape. """ shape_value = np.asarray(shape, dtype=np.int64) - constant_shape_name = self.create_node_name('Constant', CONSTANT_SHAPE_NAME_PREFIX) - tensor = onnx.helper.make_tensor(name=constant_shape_name, - data_type=TensorProto.INT64, - dims=shape_value.shape, - vals=shape_value) + constant_shape_name = self.create_node_name("Constant", CONSTANT_SHAPE_NAME_PREFIX) + tensor = onnx.helper.make_tensor( + name=constant_shape_name, + data_type=TensorProto.INT64, + dims=shape_value.shape, + vals=shape_value, + ) self.add_initializer(tensor) return tensor @@ -58,7 +63,7 @@ def get_shape_outputs(self): outputs = [] for node in self.model.graph.node: - if node.op_type == 'Shape': + if node.op_type == "Shape": if node.output[0] in input_name_to_nodes: outputs.append(node.output[0]) @@ -72,7 +77,7 @@ def get_reshape_shape_inputs(self): shape_inputs = [] for node in self.model.graph.node: - if node.op_type == 'Reshape': + if node.op_type == "Reshape": shape_inputs.append(node.input[1]) return shape_inputs @@ -85,10 +90,10 @@ def add_shape_for_reshape_input(self): output_names = [] nodes_to_add = [] for node in self.model.graph.node: - if node.op_type == 'Reshape': + if node.op_type == "Reshape": input = node.input[0] - output_name = self.create_node_name('Reshape_Input', RESHAPE_INPUT_SHAPE_PREFIX) - shape_node = onnx.helper.make_node('Shape', inputs=[input], outputs=[output_name]) + output_name = self.create_node_name("Reshape_Input", RESHAPE_INPUT_SHAPE_PREFIX) + shape_node = onnx.helper.make_node("Shape", inputs=[input], outputs=[output_name]) nodes_to_add.append(shape_node) output_names.append(output_name) @@ -125,21 +130,25 @@ def use_static_input(self, inputs, batch_size=1, max_seq_len=128): dim_proto = input.type.tensor_type.shape.dim[0] dim_proto.dim_value = batch_size dim_proto = input.type.tensor_type.shape.dim[1] - if dim_proto.HasField('dim_param'): + if dim_proto.HasField("dim_param"): dim_proto.dim_value = max_seq_len - elif dim_proto.HasField('dim_value') and dim_proto.dim_value != max_seq_len: + elif dim_proto.HasField("dim_value") and dim_proto.dim_value != max_seq_len: raise ValueError( - 'Unable to set dimension value to {} for axis {} of {}. Contradicts existing dimension value {}.' - .format(max_seq_len, 1, input.name, dim_proto.dim_value)) - - def create_dummy_inputs(self, - input_ids, - segment_ids, - input_mask, - batch_size, - sequence_length, - elem_type, - dictionary_size=8): + "Unable to set dimension value to {} for axis {} of {}. Contradicts existing dimension value {}.".format( + max_seq_len, 1, input.name, dim_proto.dim_value + ) + ) + + def create_dummy_inputs( + self, + input_ids, + segment_ids, + input_mask, + batch_size, + sequence_length, + elem_type, + dictionary_size=8, + ): """ Create dummy data for model inputs. If the model has more than 3 inputs, please update this function accordingly before running the tool. """ @@ -151,11 +160,11 @@ def create_dummy_inputs(self, input_3 = np.zeros((batch_size, sequence_length), dtype=np.int32) # Here we assume that 3 inputs have same data type - if elem_type == 1: #float32 + if elem_type == 1: # float32 input_1 = np.float32(input_1) input_2 = np.float32(input_2) input_3 = np.float32(input_3) - elif elem_type == 7: #int64 + elif elem_type == 7: # int64 input_1 = np.int64(input_1) input_2 = np.int64(input_2) input_3 = np.int64(input_3) @@ -163,8 +172,19 @@ def create_dummy_inputs(self, inputs = {input_ids: input_1, input_mask: input_2, segment_ids: input_3} return inputs - def shape_optimization(self, temp_model_path, input_ids, segment_ids, input_mask, output_names, batch_size, - sequence_length, enable_shape_opt, enable_reshape_opt, verbose): + def shape_optimization( + self, + temp_model_path, + input_ids, + segment_ids, + input_mask, + output_names, + batch_size, + sequence_length, + enable_shape_opt, + enable_reshape_opt, + verbose, + ): self.bert_inputs = [input_ids, segment_ids, input_mask] extra_outputs = [] @@ -189,9 +209,11 @@ def shape_optimization(self, temp_model_path, input_ids, segment_ids, input_mask out.write(self.model.SerializeToString()) sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - session = onnxruntime.InferenceSession(temp_model_path, - sess_options, - providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) + session = onnxruntime.InferenceSession( + temp_model_path, + sess_options, + providers=["CUDAExecutionProvider", "CPUExecutionProvider"], + ) elem_type = 7 for input in self.model.graph.input: @@ -258,21 +280,23 @@ def validate_outputs(self, output_names: List[str]): if name not in valid_names: raise Exception("Output {} does not exist in the graph outputs: {}".format(name, valid_names)) - def optimize(self, - output_path: str, - input_ids: str, - segment_ids: str, - input_mask: str, - enable_shape_opt: bool, - enable_reshape_opt: bool, - output_names: List[str] = None, - batch_size=1, - sequence_length=128, - verbose=False): + def optimize( + self, + output_path: str, + input_ids: str, + segment_ids: str, + input_mask: str, + enable_shape_opt: bool, + enable_reshape_opt: bool, + output_names: List[str] = None, + batch_size=1, + sequence_length=128, + verbose=False, + ): # Skip if shape optimization has been done before. for tensor in self.model.graph.initializer: if tensor.name.startswith(CONSTANT_SHAPE_NAME_PREFIX): - logger.info('Skip shape optimization since it has been done before') + logger.info("Skip shape optimization since it has been done before") return self.validate_input(input_ids) @@ -287,18 +311,28 @@ def optimize(self, if enable_shape_opt or enable_reshape_opt: if len(self.get_graph_inputs_excluding_initializers()) != 3: - logger.info('Skip shape optimization since graph input number is not 3') + logger.info("Skip shape optimization since graph input number is not 3") return with tempfile.TemporaryDirectory() as temp_dir: - temp_file_name = 'temp_{}.onnx'.format(datetime.now().strftime("%m_%d-%H_%M_%S")) + temp_file_name = "temp_{}.onnx".format(datetime.now().strftime("%m_%d-%H_%M_%S")) dir = "." if verbose else temp_dir temp_file = os.path.join(dir, temp_file_name) - self.shape_optimization(temp_file, input_ids, segment_ids, input_mask, remaining_outputs, batch_size, - sequence_length, enable_shape_opt, enable_reshape_opt, verbose) + self.shape_optimization( + temp_file, + input_ids, + segment_ids, + input_mask, + remaining_outputs, + batch_size, + sequence_length, + enable_shape_opt, + enable_reshape_opt, + verbose, + ) logger.debug(f"Temp model with additional outputs: {temp_file}") logger.warning( - f'Shape optimization is done. The optimized model might only work for input with batch_size={batch_size} sequence_length={sequence_length}' + f"Shape optimization is done. The optimized model might only work for input with batch_size={batch_size} sequence_length={sequence_length}" ) if output_path is not None: @@ -308,19 +342,19 @@ def optimize(self, def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, type=str) - parser.add_argument('--output', required=True, type=str) - parser.add_argument('--input_ids', required=True, type=str) - parser.add_argument('--segment_ids', required=True, type=str) - parser.add_argument('--input_mask', required=True, type=str) - parser.add_argument('--output_names', required=False, type=str, default=None) - parser.add_argument('--batch_size', required=False, type=int, default=1) - parser.add_argument('--sequence_length', required=False, type=int, default=128) - parser.add_argument('--enable_shape_opt', required=False, action='store_true') + parser.add_argument("--input", required=True, type=str) + parser.add_argument("--output", required=True, type=str) + parser.add_argument("--input_ids", required=True, type=str) + parser.add_argument("--segment_ids", required=True, type=str) + parser.add_argument("--input_mask", required=True, type=str) + parser.add_argument("--output_names", required=False, type=str, default=None) + parser.add_argument("--batch_size", required=False, type=int, default=1) + parser.add_argument("--sequence_length", required=False, type=int, default=128) + parser.add_argument("--enable_shape_opt", required=False, action="store_true") parser.set_defaults(enable_shape_opt=False) - parser.add_argument('--enable_reshape_opt', required=False, action='store_true') + parser.add_argument("--enable_reshape_opt", required=False, action="store_true") parser.set_defaults(enable_reshape_opt=False) - parser.add_argument('--verbose', required=False, action='store_true') + parser.add_argument("--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) args = parser.parse_args() return args @@ -329,10 +363,10 @@ def parse_arguments(): def setup_logging(verbose): log_handler = logging.StreamHandler(sys.stdout) if verbose: - log_handler.setFormatter(logging.Formatter('[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s')) + log_handler.setFormatter(logging.Formatter("[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")) logging_level = logging.DEBUG else: - log_handler.setFormatter(logging.Formatter('%(filename)20s: %(message)s')) + log_handler.setFormatter(logging.Formatter("%(filename)20s: %(message)s")) logging_level = logging.INFO log_handler.setLevel(logging_level) logger.addHandler(log_handler) @@ -343,7 +377,7 @@ def main(): args = parse_arguments() setup_logging(args.verbose) - output_names = None if args.output_names is None else args.output_names.split(';') + output_names = None if args.output_names is None else args.output_names.split(";") model = ModelProto() with open(args.input, "rb") as input_file: @@ -352,8 +386,18 @@ def main(): optimizer = BertOnnxModelShapeOptimizer(onnx_model) - optimizer.optimize(args.output, args.input_ids, args.segment_ids, args.input_mask, args.enable_shape_opt, - args.enable_reshape_opt, output_names, args.batch_size, args.sequence_length, args.verbose) + optimizer.optimize( + args.output, + args.input_ids, + args.segment_ids, + args.input_mask, + args.enable_shape_opt, + args.enable_reshape_opt, + output_names, + args.batch_size, + args.sequence_length, + args.verbose, + ) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py index 0912ee396f20e..119455684cea1 100644 --- a/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py +++ b/onnxruntime/python/tools/transformers/torch_onnx_export_helper.py @@ -1,33 +1,36 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import torch + TrainingMode = torch.onnx.TrainingMode from packaging.version import Version + def torch_onnx_export( - model, - args, - f, - export_params=True, - verbose=False, - training=TrainingMode.EVAL, - input_names=None, - output_names=None, - operator_export_type=None, - opset_version=None, - _retain_param_name=None, - do_constant_folding=True, - example_outputs=None, - strip_doc_string=None, - dynamic_axes=None, - keep_initializers_as_inputs=None, - custom_opsets=None, - enable_onnx_checker=None, - use_external_data_format=None, - export_modules_as_functions=False): + model, + args, + f, + export_params=True, + verbose=False, + training=TrainingMode.EVAL, + input_names=None, + output_names=None, + operator_export_type=None, + opset_version=None, + _retain_param_name=None, + do_constant_folding=True, + example_outputs=None, + strip_doc_string=None, + dynamic_axes=None, + keep_initializers_as_inputs=None, + custom_opsets=None, + enable_onnx_checker=None, + use_external_data_format=None, + export_modules_as_functions=False, +): if Version(torch.__version__) >= Version("1.11.0"): torch.onnx.export( model=model, @@ -44,7 +47,8 @@ def torch_onnx_export( dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs, custom_opsets=custom_opsets, - export_modules_as_functions=export_modules_as_functions) + export_modules_as_functions=export_modules_as_functions, + ) else: torch.onnx.export( model=model, @@ -65,4 +69,5 @@ def torch_onnx_export( keep_initializers_as_inputs=keep_initializers_as_inputs, custom_opsets=custom_opsets, enable_onnx_checker=enable_onnx_checker, - use_external_data_format=use_external_data_format) + use_external_data_format=use_external_data_format, + ) diff --git a/onnxruntime/test/contrib_ops/attention_lstm_data_gen.py b/onnxruntime/test/contrib_ops/attention_lstm_data_gen.py index bff426146f48b..424fadd3ac9c6 100644 --- a/onnxruntime/test/contrib_ops/attention_lstm_data_gen.py +++ b/onnxruntime/test/contrib_ops/attention_lstm_data_gen.py @@ -14,94 +14,355 @@ aw_attn_size: int = 2 am_context_size: int = memDepth -root_variable_scope = 'LstmAttention' +root_variable_scope = "LstmAttention" with tf.variable_scope(root_variable_scope): - query = tf.get_variable("input", - initializer=tf.constant([ - 0.25, -1.5, 1.0, 0.25, -0.5, -1.5, 0.1, 1.5, 0.25, 0.0, 0.0, 0.0, 0.1, -0.125, 0.25, - -0.5, 0.25, 0.1, 1.0, 0.5, -1.5, 0.0, 0.0, 0.0 - ], - shape=[batchSize, queryMaxStep, queryDepth])) + query = tf.get_variable( + "input", + initializer=tf.constant( + [ + 0.25, + -1.5, + 1.0, + 0.25, + -0.5, + -1.5, + 0.1, + 1.5, + 0.25, + 0.0, + 0.0, + 0.0, + 0.1, + -0.125, + 0.25, + -0.5, + 0.25, + 0.1, + 1.0, + 0.5, + -1.5, + 0.0, + 0.0, + 0.0, + ], + shape=[batchSize, queryMaxStep, queryDepth], + ), + ) - querySeqLen = tf.Variable(tf.constant([queryMaxStep - 1, queryMaxStep - 2], tf.int32), name="query_seq_len") + querySeqLen = tf.Variable( + tf.constant([queryMaxStep - 1, queryMaxStep - 2], tf.int32), + name="query_seq_len", + ) memory = tf.get_variable( "memory", - initializer=tf.constant([ - 0.1, -0.25, 1.0, 1.0, -1.0, -1.5, 1.0, 0.25, -0.125, 0.1, -0.25, 0.5, -0.25, -1.25, 0.25, -1.0, 1.5, -1.250 - ], - shape=[batchSize, memMaxStep, memDepth])) + initializer=tf.constant( + [ + 0.1, + -0.25, + 1.0, + 1.0, + -1.0, + -1.5, + 1.0, + 0.25, + -0.125, + 0.1, + -0.25, + 0.5, + -0.25, + -1.25, + 0.25, + -1.0, + 1.5, + -1.250, + ], + shape=[batchSize, memMaxStep, memDepth], + ), + ) memSeqLen = tf.Variable(tf.constant([memMaxStep, memMaxStep - 1], dtype=tf.int32), name="mem_seq_len") with tf.variable_scope("fwBahdanau"): - fw_mem_layer_weights = tf.get_variable("memory_layer/kernel", - initializer=tf.constant([4.0, 2.0, 0.5, -8.0, -2.0, -2.0], - shape=[memDepth, am_attn_size])) + fw_mem_layer_weights = tf.get_variable( + "memory_layer/kernel", + initializer=tf.constant([4.0, 2.0, 0.5, -8.0, -2.0, -2.0], shape=[memDepth, am_attn_size]), + ) fw_query_layer_weights = tf.get_variable( "bidirectional_rnn/fw/attention_wrapper/bahdanau_attention/query_layer/kernel", - initializer=tf.constant([-0.125, -0.25, 0.1, -0.125, -0.5, 1.5], shape=[cell_hidden_size, am_attn_size])) + initializer=tf.constant( + [-0.125, -0.25, 0.1, -0.125, -0.5, 1.5], + shape=[cell_hidden_size, am_attn_size], + ), + ) - fw_aw_attn_weights = tf.get_variable("bidirectional_rnn/fw/attention_wrapper/attention_layer/kernel", - initializer=tf.constant( - [1.5, 1.0, 0.1, -0.25, 0.1, 1.0, -0.25, -0.125, -1.5, -1.5, -0.25, 1.5], - shape=[am_context_size + cell_hidden_size, aw_attn_size])) + fw_aw_attn_weights = tf.get_variable( + "bidirectional_rnn/fw/attention_wrapper/attention_layer/kernel", + initializer=tf.constant( + [1.5, 1.0, 0.1, -0.25, 0.1, 1.0, -0.25, -0.125, -1.5, -1.5, -0.25, 1.5], + shape=[am_context_size + cell_hidden_size, aw_attn_size], + ), + ) - fw_am_attention_v = tf.get_variable("bidirectional_rnn/fw/attention_wrapper/bahdanau_attention/attention_v", - initializer=tf.constant([-0.25, 0.1], shape=[am_attn_size])) + fw_am_attention_v = tf.get_variable( + "bidirectional_rnn/fw/attention_wrapper/bahdanau_attention/attention_v", + initializer=tf.constant([-0.25, 0.1], shape=[am_attn_size]), + ) fw_lstm_cell_kernel = tf.get_variable( "bidirectional_rnn/fw/attention_wrapper/lstm_cell/kernel", - initializer=tf.constant([ - -1.0, -1.5, -0.5, -1.5, 0.1, -0.5, 0.5, -1.5, -0.25, 1.0, -0.125, -0.25, -1.0, -0.5, 0.25, -0.125, -0.25, - -1.0, 1.5, 1.0, -1.5, 0.25, 0.5, 0.5, 1.5, -0.5, -1.0, -0.5, 0.1, 1.0, 0.1, -0.5, -0.125, -1.5, 0.1, 1.5, - 1.0, -0.5, -0.5, -1.5, -0.125, -0.125, 0.25, -0.25, -0.25, 0.1, -0.5, -0.25, 0.25, -0.5, 0.1, -0.5, -0.25, - 0.25, 0.1, 0.5, -1.5, -0.125, 1.5, 0.5, -1.5, 1.0, 0.1, -0.5, -1.5, 0.5, -1.0, 0.25, -0.25, 1.0, 0.25, 0.5, - -0.125, 0.1, -1.0, -1.0, 0.1, 1.5, -1.5, 0.1, 1.5, 0.5, 0.25, 1.0, 1.0, -1.5, -0.25, 0.5, -0.25, 1.0, -1.0, - 0.25, -0.5, 0.5, -1.5, 0.5 - ], - shape=[aw_attn_size + queryDepth + cell_hidden_size, 4 * cell_hidden_size])) - - fw_lstm_cell_bias = tf.get_variable("bidirectional_rnn/fw/attention_wrapper/lstm_cell/bias", - initializer=tf.constant( - [0.25, -0.25, 0.1, 1.0, 1.5, -1.5, 1.5, -1.0, -0.25, 1.0, -0.25, 1.0], - shape=[4 * cell_hidden_size])) + initializer=tf.constant( + [ + -1.0, + -1.5, + -0.5, + -1.5, + 0.1, + -0.5, + 0.5, + -1.5, + -0.25, + 1.0, + -0.125, + -0.25, + -1.0, + -0.5, + 0.25, + -0.125, + -0.25, + -1.0, + 1.5, + 1.0, + -1.5, + 0.25, + 0.5, + 0.5, + 1.5, + -0.5, + -1.0, + -0.5, + 0.1, + 1.0, + 0.1, + -0.5, + -0.125, + -1.5, + 0.1, + 1.5, + 1.0, + -0.5, + -0.5, + -1.5, + -0.125, + -0.125, + 0.25, + -0.25, + -0.25, + 0.1, + -0.5, + -0.25, + 0.25, + -0.5, + 0.1, + -0.5, + -0.25, + 0.25, + 0.1, + 0.5, + -1.5, + -0.125, + 1.5, + 0.5, + -1.5, + 1.0, + 0.1, + -0.5, + -1.5, + 0.5, + -1.0, + 0.25, + -0.25, + 1.0, + 0.25, + 0.5, + -0.125, + 0.1, + -1.0, + -1.0, + 0.1, + 1.5, + -1.5, + 0.1, + 1.5, + 0.5, + 0.25, + 1.0, + 1.0, + -1.5, + -0.25, + 0.5, + -0.25, + 1.0, + -1.0, + 0.25, + -0.5, + 0.5, + -1.5, + 0.5, + ], + shape=[aw_attn_size + queryDepth + cell_hidden_size, 4 * cell_hidden_size], + ), + ) + + fw_lstm_cell_bias = tf.get_variable( + "bidirectional_rnn/fw/attention_wrapper/lstm_cell/bias", + initializer=tf.constant( + [0.25, -0.25, 0.1, 1.0, 1.5, -1.5, 1.5, -1.0, -0.25, 1.0, -0.25, 1.0], + shape=[4 * cell_hidden_size], + ), + ) with tf.variable_scope("bwBahdanau"): - fw_mem_layer_weights = tf.get_variable("memory_layer/kernel", - initializer=tf.constant([4.0, 2.0, 0.5, -8.0, -2.0, -2.0], - shape=[memDepth, am_attn_size])) + fw_mem_layer_weights = tf.get_variable( + "memory_layer/kernel", + initializer=tf.constant([4.0, 2.0, 0.5, -8.0, -2.0, -2.0], shape=[memDepth, am_attn_size]), + ) bw_query_layer_weights = tf.get_variable( "bidirectional_rnn/bw/attention_wrapper/bahdanau_attention/query_layer/kernel", - initializer=tf.constant([-0.125, -0.25, 0.1, -0.125, -0.5, 1.5], shape=[cell_hidden_size, am_attn_size])) + initializer=tf.constant( + [-0.125, -0.25, 0.1, -0.125, -0.5, 1.5], + shape=[cell_hidden_size, am_attn_size], + ), + ) - bw_aw_attn_weights = tf.get_variable("bidirectional_rnn/bw/attention_wrapper/attention_layer/kernel", - initializer=tf.constant( - [1.5, 1.0, 0.1, -0.25, 0.1, 1.0, -0.25, -0.125, -1.5, -1.5, -0.25, 1.5], - shape=[am_context_size + cell_hidden_size, aw_attn_size])) + bw_aw_attn_weights = tf.get_variable( + "bidirectional_rnn/bw/attention_wrapper/attention_layer/kernel", + initializer=tf.constant( + [1.5, 1.0, 0.1, -0.25, 0.1, 1.0, -0.25, -0.125, -1.5, -1.5, -0.25, 1.5], + shape=[am_context_size + cell_hidden_size, aw_attn_size], + ), + ) - bw_am_attention_v = tf.get_variable("bidirectional_rnn/bw/attention_wrapper/bahdanau_attention/attention_v", - initializer=tf.constant([-0.25, 0.1], shape=[am_attn_size])) + bw_am_attention_v = tf.get_variable( + "bidirectional_rnn/bw/attention_wrapper/bahdanau_attention/attention_v", + initializer=tf.constant([-0.25, 0.1], shape=[am_attn_size]), + ) bw_lstm_cell_kernel = tf.get_variable( "bidirectional_rnn/bw/attention_wrapper/lstm_cell/kernel", - initializer=tf.constant([ - -1.0, -1.5, -0.5, -1.5, 0.1, -0.5, 0.5, -1.5, -0.25, 1.0, -0.125, -0.25, -1.0, -0.5, 0.25, -0.125, -0.25, - -1.0, 1.5, 1.0, -1.5, 0.25, 0.5, 0.5, 1.5, -0.5, -1.0, -0.5, 0.1, 1.0, 0.1, -0.5, -0.125, -1.5, 0.1, 1.5, - 1.0, -0.5, -0.5, -1.5, -0.125, -0.125, 0.25, -0.25, -0.25, 0.1, -0.5, -0.25, 0.25, -0.5, 0.1, -0.5, -0.25, - 0.25, 0.1, 0.5, -1.5, -0.125, 1.5, 0.5, -1.5, 1.0, 0.1, -0.5, -1.5, 0.5, -1.0, 0.25, -0.25, 1.0, 0.25, 0.5, - -0.125, 0.1, -1.0, -1.0, 0.1, 1.5, -1.5, 0.1, 1.5, 0.5, 0.25, 1.0, 1.0, -1.5, -0.25, 0.5, -0.25, 1.0, -1.0, - 0.25, -0.5, 0.5, -1.5, 0.5 - ], - shape=[aw_attn_size + queryDepth + cell_hidden_size, 4 * cell_hidden_size])) - - bw_lstm_cell_bias = tf.get_variable("bidirectional_rnn/bw/attention_wrapper/lstm_cell/bias", - initializer=tf.constant( - [0.25, -0.25, 0.1, 1.0, 1.5, -1.5, 1.5, -1.0, -0.25, 1.0, -0.25, 1.0], - shape=[4 * cell_hidden_size])) + initializer=tf.constant( + [ + -1.0, + -1.5, + -0.5, + -1.5, + 0.1, + -0.5, + 0.5, + -1.5, + -0.25, + 1.0, + -0.125, + -0.25, + -1.0, + -0.5, + 0.25, + -0.125, + -0.25, + -1.0, + 1.5, + 1.0, + -1.5, + 0.25, + 0.5, + 0.5, + 1.5, + -0.5, + -1.0, + -0.5, + 0.1, + 1.0, + 0.1, + -0.5, + -0.125, + -1.5, + 0.1, + 1.5, + 1.0, + -0.5, + -0.5, + -1.5, + -0.125, + -0.125, + 0.25, + -0.25, + -0.25, + 0.1, + -0.5, + -0.25, + 0.25, + -0.5, + 0.1, + -0.5, + -0.25, + 0.25, + 0.1, + 0.5, + -1.5, + -0.125, + 1.5, + 0.5, + -1.5, + 1.0, + 0.1, + -0.5, + -1.5, + 0.5, + -1.0, + 0.25, + -0.25, + 1.0, + 0.25, + 0.5, + -0.125, + 0.1, + -1.0, + -1.0, + 0.1, + 1.5, + -1.5, + 0.1, + 1.5, + 0.5, + 0.25, + 1.0, + 1.0, + -1.5, + -0.25, + 0.5, + -0.25, + 1.0, + -1.0, + 0.25, + -0.5, + 0.5, + -1.5, + 0.5, + ], + shape=[aw_attn_size + queryDepth + cell_hidden_size, 4 * cell_hidden_size], + ), + ) + + bw_lstm_cell_bias = tf.get_variable( + "bidirectional_rnn/bw/attention_wrapper/lstm_cell/bias", + initializer=tf.constant( + [0.25, -0.25, 0.1, 1.0, 1.5, -1.5, 1.5, -1.0, -0.25, 1.0, -0.25, 1.0], + shape=[4 * cell_hidden_size], + ), + ) reuse = tf.AUTO_REUSE # tf.AUTO_REUSE or TRUE with tf.variable_scope(root_variable_scope, reuse=reuse): @@ -113,20 +374,16 @@ fw_cell = tf.contrib.rnn.LSTMCell(num_units=cell_hidden_size, forget_bias=0.0) bw_cell = tf.contrib.rnn.LSTMCell(num_units=cell_hidden_size, forget_bias=0.0) - fw_attn_wrapper = tf.contrib.seq2seq.AttentionWrapper(fw_cell, - fw_am, - attention_layer_size=aw_attn_size, - output_attention=False) - bw_attn_wrapper = tf.contrib.seq2seq.AttentionWrapper(bw_cell, - bw_am, - attention_layer_size=aw_attn_size, - output_attention=False) - - outputs, states = tf.nn.bidirectional_dynamic_rnn(fw_attn_wrapper, - bw_attn_wrapper, - query, - querySeqLen, - dtype=tf.float32) + fw_attn_wrapper = tf.contrib.seq2seq.AttentionWrapper( + fw_cell, fw_am, attention_layer_size=aw_attn_size, output_attention=False + ) + bw_attn_wrapper = tf.contrib.seq2seq.AttentionWrapper( + bw_cell, bw_am, attention_layer_size=aw_attn_size, output_attention=False + ) + + outputs, states = tf.nn.bidirectional_dynamic_rnn( + fw_attn_wrapper, bw_attn_wrapper, query, querySeqLen, dtype=tf.float32 + ) tensors = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=root_variable_scope) @@ -143,18 +400,25 @@ cell = fw_state.cell_state.c attention = fw_state.attention alignments = fw_state.alignments - sess.run(tf.Print(cell, [cell], '====FinalState(fw)', summarize=10000)) - sess.run(tf.Print(alignments, [alignments], '====Final Alignment(fw)', summarize=10000)) - sess.run(tf.Print(attention, [attention], '====Final Attention Context(fw)', summarize=10000)) + sess.run(tf.Print(cell, [cell], "====FinalState(fw)", summarize=10000)) + sess.run(tf.Print(alignments, [alignments], "====Final Alignment(fw)", summarize=10000)) + sess.run(tf.Print(attention, [attention], "====Final Attention Context(fw)", summarize=10000)) bw_state = states[1] # output_state_bw cell = bw_state.cell_state.c attention = bw_state.attention alignments = bw_state.alignments - sess.run(tf.Print(cell, [cell], '====FinalState(bw)', summarize=10000)) - sess.run(tf.Print(alignments, [alignments], '====Final Alignment(bw)', summarize=10000)) - sess.run(tf.Print(attention, [attention], '====Final Attention Context(bw)', summarize=10000)) + sess.run(tf.Print(cell, [cell], "====FinalState(bw)", summarize=10000)) + sess.run(tf.Print(alignments, [alignments], "====Final Alignment(bw)", summarize=10000)) + sess.run(tf.Print(attention, [attention], "====Final Attention Context(bw)", summarize=10000)) for t in tensors: - shape_str = '[' + ','.join(list(map(lambda x: str(x.__int__()), t.get_shape()))) + ']' - sess.run(tf.Print(t, [tf.reshape(t, [-1])], '\t'.join([t.name, shape_str, '']), summarize=10000)) + shape_str = "[" + ",".join(list(map(lambda x: str(x.__int__()), t.get_shape()))) + "]" + sess.run( + tf.Print( + t, + [tf.reshape(t, [-1])], + "\t".join([t.name, shape_str, ""]), + summarize=10000, + ) + ) diff --git a/onnxruntime/test/onnx/gen_test_models.py b/onnxruntime/test/onnx/gen_test_models.py index 31effae815466..509c27ec4efea 100644 --- a/onnxruntime/test/onnx/gen_test_models.py +++ b/onnxruntime/test/onnx/gen_test_models.py @@ -1,15 +1,14 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import onnx -import numpy as np -import os import argparse +import os from datetime import date -from onnx import numpy_helper -from onnx import helper -from onnx import utils -from onnx import AttributeProto, TensorProto, GraphProto + +import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, TensorProto, helper, numpy_helper, utils + def parse_arguments(): parser = argparse.ArgumentParser() @@ -30,6 +29,7 @@ def write_tensor(f, c, input_name=None): body = tensor.SerializeToString() f.write(body) + def infer_shapes(model_def): onnx.checker.check_model(model_def) onnx.helper.strip_doc_string(model_def) @@ -37,6 +37,7 @@ def infer_shapes(model_def): onnx.checker.check_model(final_model) return final_model + def generate_abs_op_test(type, X, top_test_folder): for is_raw in [True, False]: if is_raw: @@ -46,24 +47,24 @@ def generate_abs_op_test(type, X, top_test_folder): data_dir = os.path.join(test_folder, "test_data_0") os.makedirs(data_dir, exist_ok=True) # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info('Y', type, X.shape) - X_INFO = helper.make_tensor_value_info('X', type, X.shape) + Y = helper.make_tensor_value_info("Y", type, X.shape) + X_INFO = helper.make_tensor_value_info("X", type, X.shape) if is_raw: - tensor_x = onnx.helper.make_tensor(name='X', data_type=type, dims=X.shape, vals=X.tobytes(), raw=True) + tensor_x = onnx.helper.make_tensor(name="X", data_type=type, dims=X.shape, vals=X.tobytes(), raw=True) else: - tensor_x = onnx.helper.make_tensor(name='X', data_type=type, dims=X.shape, vals=X.ravel(), raw=False) + tensor_x = onnx.helper.make_tensor(name="X", data_type=type, dims=X.shape, vals=X.ravel(), raw=False) # Create a node (NodeProto) - node_def = helper.make_node('Abs', inputs=['X'], outputs=['Y']) + node_def = helper.make_node("Abs", inputs=["X"], outputs=["Y"]) # Create the graph (GraphProto) - graph_def = helper.make_graph([node_def], 'test-model', [X_INFO], [Y], [tensor_x]) + graph_def = helper.make_graph([node_def], "test-model", [X_INFO], [Y], [tensor_x]) # Create the model (ModelProto) - model_def = helper.make_model(graph_def, producer_name='onnx-example') - #final_model = infer_shapes(model_def) + model_def = helper.make_model(graph_def, producer_name="onnx-example") + # final_model = infer_shapes(model_def) final_model = model_def if is_raw: onnx.external_data_helper.convert_model_to_external_data(final_model, True) - onnx.save(final_model, os.path.join(test_folder, 'model.onnx')) + onnx.save(final_model, os.path.join(test_folder, "model.onnx")) expected_output_array = np.abs(X) expected_output_tensor = numpy_helper.from_array(expected_output_array) with open(os.path.join(data_dir, "output_0.pb"), "wb") as f: @@ -74,18 +75,18 @@ def generate_size_op_test(type, X, test_folder): data_dir = os.path.join(test_folder, "test_data_0") os.makedirs(data_dir, exist_ok=True) # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info('Y', TensorProto.INT64, []) - X_INFO = helper.make_tensor_value_info('X', type, X.shape) - tensor_x = onnx.helper.make_tensor(name='X', data_type=type, dims=X.shape, vals=X.ravel(), raw=False) + Y = helper.make_tensor_value_info("Y", TensorProto.INT64, []) + X_INFO = helper.make_tensor_value_info("X", type, X.shape) + tensor_x = onnx.helper.make_tensor(name="X", data_type=type, dims=X.shape, vals=X.ravel(), raw=False) # Create a node (NodeProto) - node_def = helper.make_node('Size', inputs=['X'], outputs=['Y']) + node_def = helper.make_node("Size", inputs=["X"], outputs=["Y"]) # Create the graph (GraphProto) - graph_def = helper.make_graph([node_def], 'test-model', [X_INFO], [Y], [tensor_x]) + graph_def = helper.make_graph([node_def], "test-model", [X_INFO], [Y], [tensor_x]) # Create the model (ModelProto) - model_def = helper.make_model(graph_def, producer_name='onnx-example') + model_def = helper.make_model(graph_def, producer_name="onnx-example") final_model = infer_shapes(model_def) - onnx.save(final_model, os.path.join(test_folder, 'model.onnx')) + onnx.save(final_model, os.path.join(test_folder, "model.onnx")) expected_output_array = np.int64(X.size) expected_output_tensor = numpy_helper.from_array(expected_output_array) with open(os.path.join(data_dir, "output_0.pb"), "wb") as f: @@ -97,18 +98,18 @@ def generate_reducesum_op_test(X, test_folder): data_dir = os.path.join(test_folder, "test_data_0") os.makedirs(data_dir, exist_ok=True) # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info('Y', type, []) - X_INFO = helper.make_tensor_value_info('X', type, X.shape) - tensor_x = onnx.helper.make_tensor(name='X', data_type=type, dims=X.shape, vals=X.ravel(), raw=False) + Y = helper.make_tensor_value_info("Y", type, []) + X_INFO = helper.make_tensor_value_info("X", type, X.shape) + tensor_x = onnx.helper.make_tensor(name="X", data_type=type, dims=X.shape, vals=X.ravel(), raw=False) # Create a node (NodeProto) - node_def = helper.make_node('ReduceSum', inputs=['X'], outputs=['Y'], keepdims=0) + node_def = helper.make_node("ReduceSum", inputs=["X"], outputs=["Y"], keepdims=0) # Create the graph (GraphProto) - graph_def = helper.make_graph([node_def], 'test-model', [X_INFO], [Y], [tensor_x]) + graph_def = helper.make_graph([node_def], "test-model", [X_INFO], [Y], [tensor_x]) # Create the model (ModelProto) - model_def = helper.make_model(graph_def, producer_name='onnx-example') + model_def = helper.make_model(graph_def, producer_name="onnx-example") final_model = infer_shapes(model_def) - onnx.save(final_model, os.path.join(test_folder, 'model.onnx')) + onnx.save(final_model, os.path.join(test_folder, "model.onnx")) expected_output_array = np.sum(X) expected_output_tensor = numpy_helper.from_array(expected_output_array) with open(os.path.join(data_dir, "output_0.pb"), "wb") as f: @@ -116,39 +117,78 @@ def generate_reducesum_op_test(X, test_folder): def test_abs(output_dir): - generate_abs_op_test(TensorProto.FLOAT, - np.random.randn(3, 4, 5).astype(np.float32), os.path.join(output_dir, 'test_abs_float')) - generate_abs_op_test(TensorProto.DOUBLE, - np.random.randn(3, 4, 5).astype(np.float64), os.path.join(output_dir, 'test_abs_double')) - generate_abs_op_test(TensorProto.INT8, np.int8([-127, -4, 0, 3, 127]), os.path.join(output_dir, 'test_abs_int8')) - generate_abs_op_test(TensorProto.UINT8, np.uint8([0, 1, 20, 255]), os.path.join(output_dir, 'test_abs_uint8')) - generate_abs_op_test(TensorProto.INT16, np.int16([-32767, -4, 0, 3, 32767]), - os.path.join(output_dir, 'test_abs_int16')) - generate_abs_op_test(TensorProto.UINT16, np.uint16([-32767, -4, 0, 3, 32767]), - os.path.join(output_dir, 'test_abs_uint16')) - generate_abs_op_test(TensorProto.INT32, np.int32([-2147483647, -4, 0, 3, 2147483647]), - os.path.join(output_dir, 'test_abs_int32')) - generate_abs_op_test(TensorProto.UINT32, np.uint32([0, 1, 20, 4294967295]), - os.path.join(output_dir, 'test_abs_uint32')) + generate_abs_op_test( + TensorProto.FLOAT, + np.random.randn(3, 4, 5).astype(np.float32), + os.path.join(output_dir, "test_abs_float"), + ) + generate_abs_op_test( + TensorProto.DOUBLE, + np.random.randn(3, 4, 5).astype(np.float64), + os.path.join(output_dir, "test_abs_double"), + ) + generate_abs_op_test( + TensorProto.INT8, + np.int8([-127, -4, 0, 3, 127]), + os.path.join(output_dir, "test_abs_int8"), + ) + generate_abs_op_test( + TensorProto.UINT8, + np.uint8([0, 1, 20, 255]), + os.path.join(output_dir, "test_abs_uint8"), + ) + generate_abs_op_test( + TensorProto.INT16, + np.int16([-32767, -4, 0, 3, 32767]), + os.path.join(output_dir, "test_abs_int16"), + ) + generate_abs_op_test( + TensorProto.UINT16, + np.uint16([-32767, -4, 0, 3, 32767]), + os.path.join(output_dir, "test_abs_uint16"), + ) + generate_abs_op_test( + TensorProto.INT32, + np.int32([-2147483647, -4, 0, 3, 2147483647]), + os.path.join(output_dir, "test_abs_int32"), + ) + generate_abs_op_test( + TensorProto.UINT32, + np.uint32([0, 1, 20, 4294967295]), + os.path.join(output_dir, "test_abs_uint32"), + ) number_info = np.iinfo(np.int64) - generate_abs_op_test(TensorProto.INT64, np.int64([-number_info.max, -4, 0, 3, number_info.max]), - os.path.join(output_dir, 'test_abs_int64')) + generate_abs_op_test( + TensorProto.INT64, + np.int64([-number_info.max, -4, 0, 3, number_info.max]), + os.path.join(output_dir, "test_abs_int64"), + ) number_info = np.iinfo(np.uint64) - generate_abs_op_test(TensorProto.UINT64, np.uint64([0, 1, 20, number_info.max]), - os.path.join(output_dir, 'test_abs_uint64')) + generate_abs_op_test( + TensorProto.UINT64, + np.uint64([0, 1, 20, number_info.max]), + os.path.join(output_dir, "test_abs_uint64"), + ) def test_reducesum(output_dir): generate_reducesum_op_test( - np.random.randn(3, 4, 5).astype(np.float32), os.path.join(output_dir, 'test_reducesum_random')) + np.random.randn(3, 4, 5).astype(np.float32), + os.path.join(output_dir, "test_reducesum_random"), + ) def test_size(output_dir): - generate_size_op_test(TensorProto.FLOAT, - np.random.randn(100, 3000, 10).astype(np.float32), - os.path.join(output_dir, 'test_size_float')) - generate_size_op_test(TensorProto.STRING, np.array(['abc', 'xy'], dtype=np.bytes_), - os.path.join(output_dir, 'test_size_string')) + generate_size_op_test( + TensorProto.FLOAT, + np.random.randn(100, 3000, 10).astype(np.float32), + os.path.join(output_dir, "test_size_float"), + ) + generate_size_op_test( + TensorProto.STRING, + np.array(["abc", "xy"], dtype=np.bytes_), + os.path.join(output_dir, "test_size_string"), + ) args = parse_arguments() @@ -159,4 +199,3 @@ def test_size(output_dir): test_abs(args.output_dir) test_size(args.output_dir) test_reducesum(args.output_dir) - diff --git a/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py b/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py index 2568cdb557bfc..235b4111bbcb0 100644 --- a/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py +++ b/onnxruntime/test/providers/cpu/reduction/reduction_test_cases_generator.py @@ -2,6 +2,7 @@ # Licensed under the MIT License. import os + import numpy as np @@ -41,25 +42,25 @@ def TestReduction(op, data, axes, keepdims): def PrintResult(op, axes, keepdims, res): - print(" {\"%s\"," % op) + print(' {"%s",' % op) print("OpAttributesResult(") print(" // ReductionAttribute") print(" {") print(" // axes_") - print("{", end='') - print(*axes, sep=", ", end='') if axes else print("") + print("{", end="") + print(*axes, sep=", ", end="") if axes else print("") print("},") print(" // keep_dims_") print(keepdims, ",") print("},") print(" // expected dims") - print("{", end='') - print(*res.shape, sep=", ", end='') + print("{", end="") + print(*res.shape, sep=", ", end="") print("},") print(" // expected values") - print("{", end='') + print("{", end="") for i in range(0, res.size): print("%5.6ff," % res.item(i)) @@ -69,20 +70,20 @@ def PrintResult(op, axes, keepdims, res): def PrintDisableOptimizations(): print("// Optimizations are disabled in this file to improve build throughput") print("#if defined(_MSC_VER) || defined(__INTEL_COMPILER)") - print("#pragma optimize (\"\", off)") + print('#pragma optimize ("", off)') print("#elif defined(__GNUC__)") print("#if defined(__clang__)") print("\t#pragma clang optimize off") print("#else") print("\t#pragma GCC push_options") - print("\t#pragma GCC optimize (\"O0\")") + print('\t#pragma GCC optimize ("O0")') print("#endif") print("#endif") def PrintReenableOptimizations(): print("#if defined(_MSC_VER) || defined(__INTEL_COMPILER)") - print("\t#pragma optimize (\"\", on)") + print('\t#pragma optimize ("", on)') print("#elif defined(__GNUC__)") print("#if defined(__clang__)") print("\t#pragma clang optimize on") @@ -94,14 +95,35 @@ def PrintReenableOptimizations(): if __name__ == "__main__": from itertools import product + input_shape = [2, 3, 2, 2, 3] np.random.seed(0) input_data = np.random.uniform(size=input_shape) - axes_options = [(-1, 3), (2, 3), (2, 1, 4), (0, -2, -3), (0, 2, 3), (0,), (2,), (4,), None] + axes_options = [ + (-1, 3), + (2, 3), + (2, 1, 4), + (0, -2, -3), + (0, 2, 3), + (0,), + (2,), + (4,), + None, + ] keepdims_options = [0, 1] ops = [ - "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", - "ReduceSum", "ReduceSumSquare", "ArgMax", "ArgMin" + "ReduceL1", + "ReduceL2", + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + "ReduceSumSquare", + "ArgMax", + "ArgMin", ] print("// Please don't manually edit this file. Generated from reduction_test_cases_generator.py") PrintDisableOptimizations() @@ -109,11 +131,13 @@ def PrintReenableOptimizations(): print("// input_data") print("{") for i in range(0, input_data.size): - print("%5.6ff," % input_data.item(i),) + print( + "%5.6ff," % input_data.item(i), + ) print("},") print("// input_dims") - print("{", end='') - print(*input_shape, sep=", ", end='') + print("{", end="") + print(*input_shape, sep=", ", end="") print("},") print(" // map_op_attribute_expected") @@ -122,7 +146,7 @@ def PrintReenableOptimizations(): for config in product(axes_options, keepdims_options, ops): axes, keepdims, op = config - #ArgMax and ArgMin only take single axis (default 0) + # ArgMax and ArgMin only take single axis (default 0) skip = False if op == "ArgMax" or op == "ArgMin": skip = axes is not None and len(axes) > 1 diff --git a/onnxruntime/test/providers/cpu/rnn/GRU.py b/onnxruntime/test/providers/cpu/rnn/GRU.py index 74ad0f20b59b9..3fee29e9928f0 100644 --- a/onnxruntime/test/providers/cpu/rnn/GRU.py +++ b/onnxruntime/test/providers/cpu/rnn/GRU.py @@ -4,11 +4,11 @@ import numpy as np DebugOutput = False -np.set_printoptions(suppress=True) #, precision=16, floatmode='maxprec') +np.set_printoptions(suppress=True) # , precision=16, floatmode='maxprec') def print_with_shape(name, a, force_output=False): - if (force_output or DebugOutput): + if force_output or DebugOutput: print(name + " [shape: ", a.shape, "]\n", a) @@ -18,35 +18,40 @@ def print_results(Y): print("*************************") -class GRU_Helper(): - +class GRU_Helper: def __init__(self, **params): # Match the ONNXRuntime/CNTK behavior # If False use the python from the ONNX spec self.match_onnxruntime = True - required_inputs = ['X', 'W', 'R'] + required_inputs = ["X", "W", "R"] for i in required_inputs: assert i in params, "Missing Required Input: {0}".format(i) - num_directions = params['W'].shape[0] - sequence_length = params['X'].shape[0] - - hidden_size = params['R'].shape[-1] - batch_size = params['X'].shape[1] - - X = params['X'] - W = params['W'] - R = params['R'] - B = params['B'] if 'B' in params else np.zeros(num_directions * 6 * - hidden_size).reshape(num_directions, 6 * hidden_size) - H_0 = params['initial_h'] if 'initial_h' in params else np.zeros( - (num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) - LBR = params['linear_before_reset'] if 'linear_before_reset' in params else 0 - self.direction = params['direction'] if 'direction' in params else 'forward' - - if (num_directions == 1): - if (self.direction == 'forward'): + num_directions = params["W"].shape[0] + sequence_length = params["X"].shape[0] + + hidden_size = params["R"].shape[-1] + batch_size = params["X"].shape[1] + + X = params["X"] + W = params["W"] + R = params["R"] + B = ( + params["B"] + if "B" in params + else np.zeros(num_directions * 6 * hidden_size).reshape(num_directions, 6 * hidden_size) + ) + H_0 = ( + params["initial_h"] + if "initial_h" in params + else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) + ) + LBR = params["linear_before_reset"] if "linear_before_reset" in params else 0 + self.direction = params["direction"] if "direction" in params else "forward" + + if num_directions == 1: + if self.direction == "forward": self.one = OneDirectionGRU(X, W, R, B, H_0, LBR) else: # flip input so we process in reverse @@ -66,7 +71,7 @@ def __init__(self, **params): def run(self): - if (self.direction == 'bidirectional'): + if self.direction == "bidirectional": f_output = self.one.execute() r_output = self.two.execute() @@ -87,15 +92,14 @@ def run(self): output = output.reshape(seq_length, 2, batch_size, hidden_size) else: output = self.one.execute() - if (self.direction == 'reverse'): + if self.direction == "reverse": # flip so it's back in the original order of the inputs output = np.flip(output, 0) return output -class OneDirectionGRU(): - +class OneDirectionGRU: def __init__(self, X, W, R, B, initial_h, LBR): self.X = X @@ -119,20 +123,20 @@ def execute(self): [r_z, r_r, r_h] = np.split(self.R, 3) [w_bz, w_br, w_bh, r_bz, r_br, r_bh] = np.split(self.B, 6) - #print_with_shape("w_z", w_z) - #print_with_shape("w_r", w_r) - #print_with_shape("w_h", w_h) + # print_with_shape("w_z", w_z) + # print_with_shape("w_r", w_r) + # print_with_shape("w_h", w_h) - #print_with_shape("r_z", r_z) - #print_with_shape("r_r", r_r) - #print_with_shape("r_h", r_h) + # print_with_shape("r_z", r_z) + # print_with_shape("r_r", r_r) + # print_with_shape("r_h", r_h) - #print_with_shape("w_bz", w_bz) - #print_with_shape("w_br", w_br) - #print_with_shape("w_bh", w_bh) - #print_with_shape("r_bz", r_bz) - #print_with_shape("r_br", r_br) - #print_with_shape("r_bh", r_bh) + # print_with_shape("w_bz", w_bz) + # print_with_shape("w_br", w_br) + # print_with_shape("w_bh", w_bh) + # print_with_shape("r_bz", r_bz) + # print_with_shape("r_br", r_br) + # print_with_shape("r_bh", r_bh) seq_len = self.X.shape[0] num_directions = 1 @@ -163,39 +167,41 @@ def execute(self): return output -class ONNXRuntimeTestContext(): - +class ONNXRuntimeTestContext: @staticmethod def OneDirectionWeights(): hidden_size = 2 - W = np.array([[ - [-0.494659, 0.0453352], # Wz - [-0.487793, 0.417264], - [-0.0091708, -0.255364], # Wr - [-0.106952, -0.266717], - [-0.0888852, -0.428709], # Wh - [-0.283349, 0.208792] - ]]).astype(np.float32) - - R = np.array([[ - [0.146626, -0.0620289], # Rz - [-0.0815302, 0.100482], - [-0.228172, 0.405972], # Rr - [0.31576, 0.281487], - [-0.394864, 0.42111], # Rh - [-0.386624, -0.390225] - ]]).astype(np.float32) - - W_B = np.array([[ - 0.381619, - 0.0323954, # Wbz - -0.258721, - 0.45056, # Wbr - -0.250755, - 0.0967895 - ]]).astype(np.float32) # Wbh + W = np.array( + [ + [ + [-0.494659, 0.0453352], # Wz + [-0.487793, 0.417264], + [-0.0091708, -0.255364], # Wr + [-0.106952, -0.266717], + [-0.0888852, -0.428709], # Wh + [-0.283349, 0.208792], + ] + ] + ).astype(np.float32) + + R = np.array( + [ + [ + [0.146626, -0.0620289], # Rz + [-0.0815302, 0.100482], + [-0.228172, 0.405972], # Rr + [0.31576, 0.281487], + [-0.394864, 0.42111], # Rh + [-0.386624, -0.390225], + ] + ] + ).astype(np.float32) + + W_B = np.array([[0.381619, 0.0323954, -0.258721, 0.45056, -0.250755, 0.0967895,]]).astype( # Wbz # Wbr + np.float32 + ) # Wbh R_B = np.zeros((1, 3 * hidden_size)).astype(np.float32) B = np.concatenate((W_B, R_B), axis=1) @@ -217,8 +223,7 @@ def BidirectionalWeights(): # replicate ONNXRuntime unit tests inputs to validate output -class GRU_ONNXRuntimeUnitTests(): - +class GRU_ONNXRuntimeUnitTests: @staticmethod def ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows(): @@ -228,14 +233,14 @@ def ForwardDefaultActivationsSimpleWeightsNoBiasTwoRows(): batch_size = 2 input_size = 1 hidden_size = 3 - input = np.array([1., 2., 10., 11.]).astype(np.float32).reshape(seq_length, batch_size, input_size) + input = np.array([1.0, 2.0, 10.0, 11.0]).astype(np.float32).reshape(seq_length, batch_size, input_size) W = np.array([0.1, 0.2, 0.3, 1, 2, 3, 10, 11, 12]).astype(np.float32).reshape(1, 3 * hidden_size, input_size) weight_scale = 0.1 R = weight_scale * np.ones((1, 3 * hidden_size, hidden_size)).astype(np.float32) - gru = GRU_Helper(X=input, W=W, R=R, direction='forward') + gru = GRU_Helper(X=input, W=W, R=R, direction="forward") fw_output = gru.run() print_results(fw_output) @@ -248,22 +253,25 @@ def ReverseDefaultActivationsSimpleWeightsNoBiasTwoRows(): batch_size = 2 input_size = 1 hidden_size = 3 - input = np.array([[[1.], [2.]], [[10.], [11.]]]).astype(np.float32) + input = np.array([[[1.0], [2.0]], [[10.0], [11.0]]]).astype(np.float32) W = np.array([0.1, 0.2, 0.3, 1, 2, 3, 10, 11, 12]).astype(np.float32).reshape(1, 3 * hidden_size, input_size) weight_scale = 0.1 R = weight_scale * np.ones((1, 3 * hidden_size, hidden_size)).astype(np.float32) - gru = GRU_Helper(X=input, W=W, R=R, direction='reverse') + gru = GRU_Helper(X=input, W=W, R=R, direction="reverse") fw_output = gru.run() print_results(fw_output) @staticmethod def BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=0): - print(GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBias.__name__ + - '.linear_before_reset=' + str(linear_before_reset)) + print( + GRU_ONNXRuntimeUnitTests.BidirectionalDefaultActivationsSimpleWeightsNoBias.__name__ + + ".linear_before_reset=" + + str(linear_before_reset) + ) seq_length = 2 batch_size = 3 if linear_before_reset else 2 @@ -271,9 +279,9 @@ def BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=0): hidden_size = 3 if linear_before_reset: - input = np.array([[[1.], [2.], [3.]], [[10.], [11.], [12.]]]).astype(np.float32) + input = np.array([[[1.0], [2.0], [3.0]], [[10.0], [11.0], [12.0]]]).astype(np.float32) else: - input = np.array([[[1.], [2.]], [[10.], [11.]]]).astype(np.float32) + input = np.array([[[1.0], [2.0]], [[10.0], [11.0]]]).astype(np.float32) W = np.array([0.1, 0.2, 0.3, 1, 2, 3, 10, 11, 12]).astype(np.float32).reshape(1, 3 * hidden_size, input_size) @@ -281,11 +289,13 @@ def BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=0): R = weight_scale * np.ones((1, 3 * hidden_size, hidden_size)).astype(np.float32) # duplicate the W and R inputs so we use the same values for both forward and reverse - gru = GRU_Helper(X=input, - W=np.tile(W, (2, 1)).reshape(2, 3 * hidden_size, input_size), - R=np.tile(R, (2, 1)).reshape(2, 3 * hidden_size, hidden_size), - direction='bidirectional', - linear_before_reset=linear_before_reset) + gru = GRU_Helper( + X=input, + W=np.tile(W, (2, 1)).reshape(2, 3 * hidden_size, input_size), + R=np.tile(R, (2, 1)).reshape(2, 3 * hidden_size, hidden_size), + direction="bidirectional", + linear_before_reset=linear_before_reset, + ) fw_output = gru.run() print_results(fw_output) @@ -293,33 +303,73 @@ def BidirectionalDefaultActivationsSimpleWeightsNoBias(linear_before_reset=0): @staticmethod def DefaultActivationsSimpleWeightsWithBias(rows=2, direction="forward", linear_before_reset=0): - print(GRU_ONNXRuntimeUnitTests.DefaultActivationsSimpleWeightsWithBias.__name__ + " batch_parallel=" + - str(rows != 1) + " direction=" + direction + " linear_before_reset=" + str(linear_before_reset)) + print( + GRU_ONNXRuntimeUnitTests.DefaultActivationsSimpleWeightsWithBias.__name__ + + " batch_parallel=" + + str(rows != 1) + + " direction=" + + direction + + " linear_before_reset=" + + str(linear_before_reset) + ) seq_length = 2 batch_size = rows input_size = 1 hidden_size = 3 - if (batch_size == 1): + if batch_size == 1: input = [-0.1, -0.3] else: input = [-0.1, 0.2, -0.3, 0.4] input = np.array(input).astype(np.float32).reshape(seq_length, batch_size, input_size) - W = np.array([0.1, 0.2, 0.3, 0.2, 0.3, 0.1, 0.3, 0.1, - 0.2]).astype(np.float32).reshape(1, 3 * hidden_size, input_size) + W = ( + np.array([0.1, 0.2, 0.3, 0.2, 0.3, 0.1, 0.3, 0.1, 0.2]) + .astype(np.float32) + .reshape(1, 3 * hidden_size, input_size) + ) weight_scale = 0.1 R = weight_scale * np.ones((1, 3 * hidden_size, hidden_size)).astype(np.float32) # Wb[zrh] Rb[zrh] - B = np.array( - [-0.01, 0.1, 0.01, -0.2, -0.02, 0.02, 0.3, -0.3, -0.3, -0.03, 0.5, -0.7, 0.05, -0.7, 0.3, 0.07, -0.03, - 0.5]).astype(np.float32).reshape(1, 6 * hidden_size) - - gru = GRU_Helper(X=input, W=W, R=R, B=B, direction=direction, linear_before_reset=linear_before_reset) + B = ( + np.array( + [ + -0.01, + 0.1, + 0.01, + -0.2, + -0.02, + 0.02, + 0.3, + -0.3, + -0.3, + -0.03, + 0.5, + -0.7, + 0.05, + -0.7, + 0.3, + 0.07, + -0.03, + 0.5, + ] + ) + .astype(np.float32) + .reshape(1, 6 * hidden_size) + ) + + gru = GRU_Helper( + X=input, + W=W, + R=R, + B=B, + direction=direction, + linear_before_reset=linear_before_reset, + ) fw_output = gru.run() print_results(fw_output) @@ -345,9 +395,9 @@ def ForwardDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset(): @staticmethod def ReverseDefaultActivationsSimpleWeightsWithBiasLinearBeforeReset(): - GRU_ONNXRuntimeUnitTests.DefaultActivationsSimpleWeightsWithBias(rows=1, - direction="reverse", - linear_before_reset=1) + GRU_ONNXRuntimeUnitTests.DefaultActivationsSimpleWeightsWithBias( + rows=1, direction="reverse", linear_before_reset=1 + ) @staticmethod def Legacy_TestGRUOpForwardBasic(): @@ -368,7 +418,7 @@ def Legacy_TestGRUOpBackwardBasic(): input = np.array([[[-0.185934, -0.269585]], [[-0.455351, -0.276391]]]).astype(np.float32) W, R, B = ONNXRuntimeTestContext.OneDirectionWeights() - gru = GRU_Helper(X=input, W=W, R=R, B=B, direction='reverse') + gru = GRU_Helper(X=input, W=W, R=R, B=B, direction="reverse") output = gru.run() print_results(output) @@ -380,7 +430,7 @@ def Legacy_TestGRUOpBidirectionalBasic(): input = np.array([[[-0.455351, -0.276391]], [[-0.185934, -0.269585]]]).astype(np.float32) W, R, B = ONNXRuntimeTestContext.BidirectionalWeights() - gru = GRU_Helper(X=input, W=W, R=R, B=B, direction='bidirectional') + gru = GRU_Helper(X=input, W=W, R=R, B=B, direction="bidirectional") output = gru.run() print_results(output) diff --git a/onnxruntime/test/providers/cpu/rnn/LSTM.py b/onnxruntime/test/providers/cpu/rnn/LSTM.py index c005ba0b18d49..039a419552586 100644 --- a/onnxruntime/test/providers/cpu/rnn/LSTM.py +++ b/onnxruntime/test/providers/cpu/rnn/LSTM.py @@ -1,24 +1,22 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import absolute_import, division, print_function, unicode_literals -import numpy as np # type: ignore from typing import Any, Tuple -#import onnx -#from ..base import Base -#from . import expect +import numpy as np # type: ignore + +# import onnx +# from ..base import Base +# from . import expect DebugOutput = True -np.set_printoptions(suppress=True) #, precision=16, floatmode='maxprec') +np.set_printoptions(suppress=True) # , precision=16, floatmode='maxprec') def print_with_shape(name, a, force_output=False): - if (force_output or DebugOutput): + if force_output or DebugOutput: print(name + " [shape: ", a.shape, "]\n", a) @@ -32,42 +30,53 @@ def print_results(Y, Y_h, Y_c): print("*************************") -class LSTM_Helper(): - +class LSTM_Helper: def __init__(self, **params): # type: (*Any) -> None - required_inputs = ['X', 'W', 'R'] + required_inputs = ["X", "W", "R"] for i in required_inputs: assert i in params, "Missing Required Input: {0}".format(i) - X = params['X'] - W = params['W'] - R = params['R'] + X = params["X"] + W = params["W"] + R = params["R"] num_directions = W.shape[0] sequence_length = X.shape[0] batch_size = X.shape[1] hidden_size = R.shape[-1] - B = params['B'] if 'B' in params else np.zeros(num_directions * 8 * - hidden_size).reshape(num_directions, 8 * hidden_size) - P = params['P'] if 'P' in params else np.zeros(num_directions * 3 * - hidden_size).reshape(num_directions, 3 * hidden_size) - h_0 = params['initial_h'] if 'initial_h' in params else np.zeros( - (num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) - c_0 = params['initial_c'] if 'initial_c' in params else np.zeros( - (num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) - - f = params['f'] if 'f' in params else ActivationFuncs.sigmoid - g = params['g'] if 'g' in params else ActivationFuncs.tanh - h = params['h'] if 'h' in params else ActivationFuncs.tanh - input_forget = params['input_forget'] if 'input_forget' in params else False - clip = params['clip'] if 'clip' in params else 9999.0 - - self.direction = params['direction'] if 'direction' in params else 'forward' - - if (num_directions == 1): - if (self.direction == 'forward'): + B = ( + params["B"] + if "B" in params + else np.zeros(num_directions * 8 * hidden_size).reshape(num_directions, 8 * hidden_size) + ) + P = ( + params["P"] + if "P" in params + else np.zeros(num_directions * 3 * hidden_size).reshape(num_directions, 3 * hidden_size) + ) + h_0 = ( + params["initial_h"] + if "initial_h" in params + else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) + ) + c_0 = ( + params["initial_c"] + if "initial_c" in params + else np.zeros((num_directions, batch_size, hidden_size)).reshape(num_directions, batch_size, hidden_size) + ) + + f = params["f"] if "f" in params else ActivationFuncs.sigmoid + g = params["g"] if "g" in params else ActivationFuncs.tanh + h = params["h"] if "h" in params else ActivationFuncs.tanh + input_forget = params["input_forget"] if "input_forget" in params else False + clip = params["clip"] if "clip" in params else 9999.0 + + self.direction = params["direction"] if "direction" in params else "forward" + + if num_directions == 1: + if self.direction == "forward": self.one = OneDirectionLSTM(X, W, R, B, P, h_0, c_0, f, g, h, input_forget, clip) else: # flip input so we process in reverse @@ -85,11 +94,24 @@ def __init__(self, **params): # type: (*Any) -> None c_0fw, c_0bw = np.vsplit(c_0, 2) self.one = OneDirectionLSTM(X, Wfw, Rfw, Bfw, Pfw, h_0fw, c_0fw, f, g, h, input_forget, clip) - self.two = OneDirectionLSTM(np.flip(X, 0), Wbw, Rbw, Bbw, Pfw, h_0bw, c_0fw, f, g, h, input_forget, clip) + self.two = OneDirectionLSTM( + np.flip(X, 0), + Wbw, + Rbw, + Bbw, + Pfw, + h_0bw, + c_0fw, + f, + g, + h, + input_forget, + clip, + ) def run(self): - if (self.direction == 'bidirectional'): + if self.direction == "bidirectional": f_output, f_Y_h, f_Y_c = self.one.execute() r_output, r_Y_h, r_Y_c = self.two.execute() @@ -103,8 +125,8 @@ def run(self): hidden_size = f_output.shape[3] output = np.empty((0, 2, batch_size, hidden_size), np.float32) - #Y_h = np.empty((0, 2, batch_size, hidden_size), np.float32) - #Y_c = np.empty((0, 2, hidden_size, hidden_size), np.float32) + # Y_h = np.empty((0, 2, batch_size, hidden_size), np.float32) + # Y_c = np.empty((0, 2, hidden_size, hidden_size), np.float32) for x in range(0, seq_length): output = np.append(output, f_output[x]) output = np.append(output, r_output_orig_input_order[x]) @@ -116,7 +138,7 @@ def run(self): else: output, Y_h, Y_c = self.one.execute() - if (self.direction == 'reverse'): + if self.direction == "reverse": # flip so it's back in the original order of the inputs output = np.flip(output, 0) @@ -124,7 +146,6 @@ def run(self): class ActivationFuncs: - @staticmethod def sigmoid(x): return 1 / (1 + np.exp(-x)) @@ -135,20 +156,21 @@ def tanh(x): class OneDirectionLSTM: - - def __init__(self, - X, - W, - R, - B, - P, - initial_h, - initial_c, - f=ActivationFuncs.sigmoid, - g=ActivationFuncs.tanh, - h=ActivationFuncs.tanh, - input_forget=False, - clip=9999.0): + def __init__( + self, + X, + W, + R, + B, + P, + initial_h, + initial_c, + f=ActivationFuncs.sigmoid, + g=ActivationFuncs.tanh, + h=ActivationFuncs.tanh, + input_forget=False, + clip=9999.0, + ): self.X = X # remove num_directions axis for W, R, B, P, H_0, C_0 @@ -184,7 +206,7 @@ def execute(self): # type: () -> Tuple[np.ndarray, np.ndarray] for x in np.split(self.X, self.X.shape[0], axis=0): print_with_shape("Xt1", x) - #gates = np.dot(x, np.transpose(self.W)) + np.dot(H_t, np.transpose(self.R)) + np.add(*np.split(self.B, 2)) + # gates = np.dot(x, np.transpose(self.W)) + np.dot(H_t, np.transpose(self.R)) + np.add(*np.split(self.B, 2)) print_with_shape("W^T", np.transpose(self.W)) # t0 == t-1, t1 == current @@ -204,7 +226,7 @@ def execute(self): # type: () -> Tuple[np.ndarray, np.ndarray] print_with_shape("ct_in", ct_in) i = self.f(np.clip((it_in + p_i * C_t), -self.clip, self.clip)) - if (self.input_forget): + if self.input_forget: f = 1.0 - i # this is what ONNXRuntime does else: f = self.f(np.clip((ft_in + p_f * C_t), -self.clip, self.clip)) @@ -229,7 +251,6 @@ def execute(self): # type: () -> Tuple[np.ndarray, np.ndarray] class LSTM: # Base): - @staticmethod def SimpleWeightsNoBiasTwoRows(direction): # type: () -> None @@ -241,15 +262,18 @@ def SimpleWeightsNoBiasTwoRows(direction): # type: () -> None hidden_size = 3 number_of_gates = 4 - input = np.array([[[1.], [2.]], [[10.], [11.]]]).astype(np.float32) + input = np.array([[[1.0], [2.0]], [[10.0], [11.0]]]).astype(np.float32) - W = np.array([0.1, 0.2, 0.3, 0.4, 1, 2, 3, 4, 10, 11, 12, - 13]).astype(np.float32).reshape(1, number_of_gates * hidden_size, input_size) + W = ( + np.array([0.1, 0.2, 0.3, 0.4, 1, 2, 3, 4, 10, 11, 12, 13]) + .astype(np.float32) + .reshape(1, number_of_gates * hidden_size, input_size) + ) weight_scale = 0.1 R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32) - if (direction == 'bidirectional'): + if direction == "bidirectional": W = W = np.tile(W, (2, 1)).reshape(2, number_of_gates * hidden_size, input_size) R = R = np.tile(R, (2, 1)).reshape(2, number_of_gates * hidden_size, hidden_size) @@ -271,11 +295,17 @@ def LargeBatchWithClip(clip): number_of_gates = 4 # sequentialvalues from 1 to 32 - input = np.array(range(1, seq_length * batch_size + 1, - 1)).astype(np.float32).reshape(seq_length, batch_size, input_size) - - W = np.array([0.1, 0.2, 0.3, 0.4, 1, 2, 3, 4, 10, 11, 12, - 13]).astype(np.float32).reshape(1, number_of_gates * hidden_size, input_size) + input = ( + np.array(range(1, seq_length * batch_size + 1, 1)) + .astype(np.float32) + .reshape(seq_length, batch_size, input_size) + ) + + W = ( + np.array([0.1, 0.2, 0.3, 0.4, 1, 2, 3, 4, 10, 11, 12, 13]) + .astype(np.float32) + .reshape(1, number_of_gates * hidden_size, input_size) + ) weight_scale = 0.1 R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32) @@ -297,8 +327,11 @@ def BatchParallelFalseSeqLengthGreaterThanOne(): input = np.array([1, 2]).astype(np.float32).reshape(seq_length, batch_size, input_size) - W = np.array([0.1, 0.2, 0.3, 0.4, 1, 2, 3, 4]).astype(np.float32).reshape(1, number_of_gates * hidden_size, - input_size) + W = ( + np.array([0.1, 0.2, 0.3, 0.4, 1, 2, 3, 4]) + .astype(np.float32) + .reshape(1, number_of_gates * hidden_size, input_size) + ) weight_scale = 0.1 R = weight_scale * np.ones((1, number_of_gates * hidden_size, hidden_size)).astype(np.float32) @@ -313,7 +346,7 @@ def export_initial_bias(): # type: () -> None print(LSTM.export_initial_bias.__name__) - input = np.array([[[1., 2., 3.], [4., 5., 6.], [7., 8., 9.]]]).astype(np.float32) + input = np.array([[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]]]).astype(np.float32) input_size = 3 hidden_size = 4 @@ -343,7 +376,7 @@ def export_initial_bias(): # type: () -> None @staticmethod def export_peepholes(): # type: () -> None - input = np.array([[[1., 2., 3., 4.], [5., 6., 7., 8.]]]).astype(np.float32) + input = np.array([[[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]]]).astype(np.float32) input_size = 4 hidden_size = 3 @@ -386,40 +419,88 @@ def OneDirectionWeights(): hidden_size = ONNXRuntimeTestContext.hidden_size input_size = ONNXRuntimeTestContext.input_size - W = np.array([ - -0.494659, 0.0453352, -0.487793, 0.417264, -0.0175329, 0.489074, -0.446013, 0.414029, -0.0091708, -0.255364, - -0.106952, -0.266717, -0.0888852, -0.428709, -0.283349, 0.208792 - ]).reshape(num_directions, 4 * hidden_size, input_size).astype(np.float32) - - R = np.array([ - 0.146626, -0.0620289, -0.0815302, 0.100482, -0.219535, -0.306635, -0.28515, -0.314112, -0.228172, 0.405972, - 0.31576, 0.281487, -0.394864, 0.42111, -0.386624, -0.390225 - ]).reshape(num_directions, 4 * hidden_size, hidden_size).astype(np.float32) - - P = np.array([0.2345, 0.5235, 0.4378, 0.3475, 0.8927, 0.3456]).reshape(num_directions, - 3 * hidden_size).astype(np.float32) + W = ( + np.array( + [ + -0.494659, + 0.0453352, + -0.487793, + 0.417264, + -0.0175329, + 0.489074, + -0.446013, + 0.414029, + -0.0091708, + -0.255364, + -0.106952, + -0.266717, + -0.0888852, + -0.428709, + -0.283349, + 0.208792, + ] + ) + .reshape(num_directions, 4 * hidden_size, input_size) + .astype(np.float32) + ) + + R = ( + np.array( + [ + 0.146626, + -0.0620289, + -0.0815302, + 0.100482, + -0.219535, + -0.306635, + -0.28515, + -0.314112, + -0.228172, + 0.405972, + 0.31576, + 0.281487, + -0.394864, + 0.42111, + -0.386624, + -0.390225, + ] + ) + .reshape(num_directions, 4 * hidden_size, hidden_size) + .astype(np.float32) + ) + + P = ( + np.array([0.2345, 0.5235, 0.4378, 0.3475, 0.8927, 0.3456]) + .reshape(num_directions, 3 * hidden_size) + .astype(np.float32) + ) # // [8*hidden] - B = np.array([ - 0.381619, - 0.0323954, - -0.14449, - 0.420804, - -0.258721, - 0.45056, - -0.250755, - 0.0967895, - - # peephole bias - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0, - 0.0 - ]).reshape(num_directions, 8 * hidden_size).astype(np.float32) + B = ( + np.array( + [ + 0.381619, + 0.0323954, + -0.14449, + 0.420804, + -0.258721, + 0.45056, + -0.250755, + 0.0967895, + # peephole bias + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ] + ) + .reshape(num_directions, 8 * hidden_size) + .astype(np.float32) + ) return W, R, B, P @@ -444,22 +525,23 @@ def DefaultInput(): batch_size = 1 input_size = 2 - input = np.array([-0.455351, -0.276391, -0.185934, -0.269585])\ - .reshape(seq_length, batch_size, input_size)\ - .astype(np.float32) + input = ( + np.array([-0.455351, -0.276391, -0.185934, -0.269585]) + .reshape(seq_length, batch_size, input_size) + .astype(np.float32) + ) return input class ONNXRuntimeUnitTests: - @staticmethod def ONNXRuntime_TestLSTMBidirectionalBasic(): print(ONNXRuntimeUnitTests.ONNXRuntime_TestLSTMBidirectionalBasic.__name__) input = ONNXRuntimeTestContext.DefaultInput() W, R, B, P = ONNXRuntimeTestContext.BidirectionalWeights() - lstm = LSTM_Helper(X=input, W=W, R=R, B=B, P=P, direction='bidirectional') + lstm = LSTM_Helper(X=input, W=W, R=R, B=B, P=P, direction="bidirectional") Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) @@ -498,7 +580,7 @@ def ONNXRuntime_TestLSTMBackward(): input = ONNXRuntimeTestContext.DefaultInput() W, R, B, P = ONNXRuntimeTestContext.OneDirectionWeights() - lstm = LSTM_Helper(X=input, W=W, R=R, B=B, P=P, direction='reverse') + lstm = LSTM_Helper(X=input, W=W, R=R, B=B, P=P, direction="reverse") Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) @@ -532,13 +614,15 @@ def ONNXRuntime_TestLSTMActivation(): input = ONNXRuntimeTestContext.DefaultInput() W, R, B, P = ONNXRuntimeTestContext.OneDirectionWeights() - lstm = LSTM_Helper(X=input, - W=W, - R=R, - B=B, - f=ActivationFuncs.tanh, - g=ActivationFuncs.sigmoid, - h=ActivationFuncs.tanh) + lstm = LSTM_Helper( + X=input, + W=W, + R=R, + B=B, + f=ActivationFuncs.tanh, + g=ActivationFuncs.sigmoid, + h=ActivationFuncs.tanh, + ) Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) @@ -552,31 +636,51 @@ def ONNXRuntime_TestLSTMBatchReallocation(): input = ONNXRuntimeTestContext.DefaultInput() W, R, B, P = ONNXRuntimeTestContext.OneDirectionWeights() - lstm = LSTM_Helper(X=input, - W=W, - R=R, - B=B, - f=ActivationFuncs.tanh, - g=ActivationFuncs.sigmoid, - h=ActivationFuncs.tanh) + lstm = LSTM_Helper( + X=input, + W=W, + R=R, + B=B, + f=ActivationFuncs.tanh, + g=ActivationFuncs.sigmoid, + h=ActivationFuncs.tanh, + ) Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) print("===============") batch_size = 3 - input = np.array([ - -0.455351, -0.476391, -0.555351, -0.376391, -0.655351, -0.276391, -0.185934, -0.869585, -0.285934, - -0.769585, -0.385934, -0.669585 - ]).reshape(seq_length, batch_size, input_size).astype(np.float32) + input = ( + np.array( + [ + -0.455351, + -0.476391, + -0.555351, + -0.376391, + -0.655351, + -0.276391, + -0.185934, + -0.869585, + -0.285934, + -0.769585, + -0.385934, + -0.669585, + ] + ) + .reshape(seq_length, batch_size, input_size) + .astype(np.float32) + ) W, R, B, P = ONNXRuntimeTestContext.OneDirectionWeights() - lstm = LSTM_Helper(X=input, - W=W, - R=R, - B=B, - f=ActivationFuncs.tanh, - g=ActivationFuncs.sigmoid, - h=ActivationFuncs.tanh) + lstm = LSTM_Helper( + X=input, + W=W, + R=R, + B=B, + f=ActivationFuncs.tanh, + g=ActivationFuncs.sigmoid, + h=ActivationFuncs.tanh, + ) Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) @@ -590,42 +694,62 @@ def ONNXRuntime_TestLSTMOutputWrite(): input = ONNXRuntimeTestContext.DefaultInput() W, R, B, P = ONNXRuntimeTestContext.BidirectionalWeights() - lstm = LSTM_Helper(X=input, - W=W, - R=R, - B=B, - direction='bidirectional', - f=ActivationFuncs.tanh, - g=ActivationFuncs.sigmoid, - h=ActivationFuncs.tanh) + lstm = LSTM_Helper( + X=input, + W=W, + R=R, + B=B, + direction="bidirectional", + f=ActivationFuncs.tanh, + g=ActivationFuncs.sigmoid, + h=ActivationFuncs.tanh, + ) Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) print("===============") batch_size = 3 - input = np.array([ - -0.455351, -0.776391, -0.355351, -0.576391, -0.255351, -0.376391, -0.185934, -0.169585, -0.285934, - -0.469585, -0.385934, -0.669585 - ]).reshape(seq_length, batch_size, input_size).astype(np.float32) + input = ( + np.array( + [ + -0.455351, + -0.776391, + -0.355351, + -0.576391, + -0.255351, + -0.376391, + -0.185934, + -0.169585, + -0.285934, + -0.469585, + -0.385934, + -0.669585, + ] + ) + .reshape(seq_length, batch_size, input_size) + .astype(np.float32) + ) W, R, B, P = ONNXRuntimeTestContext.BidirectionalWeights() - lstm = LSTM_Helper(X=input, - W=W, - R=R, - B=B, - direction='bidirectional', - f=ActivationFuncs.tanh, - g=ActivationFuncs.sigmoid, - h=ActivationFuncs.tanh) + lstm = LSTM_Helper( + X=input, + W=W, + R=R, + B=B, + direction="bidirectional", + f=ActivationFuncs.tanh, + g=ActivationFuncs.sigmoid, + h=ActivationFuncs.tanh, + ) Y, Y_h, Y_c = lstm.run() print_results(Y, Y_h, Y_c) DebugOutput = False -LSTM.SimpleWeightsNoBiasTwoRows('forward') -LSTM.SimpleWeightsNoBiasTwoRows('reverse') -LSTM.SimpleWeightsNoBiasTwoRows('bidirectional') +LSTM.SimpleWeightsNoBiasTwoRows("forward") +LSTM.SimpleWeightsNoBiasTwoRows("reverse") +LSTM.SimpleWeightsNoBiasTwoRows("bidirectional") LSTM.LargeBatchWithClip(99999.0) # too large to affect output LSTM.LargeBatchWithClip(4.0) LSTM.BatchParallelFalseSeqLengthGreaterThanOne() diff --git a/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py b/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py index 25f53ce878ef5..0003975d9c6da 100644 --- a/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py +++ b/onnxruntime/test/python/contrib_ops/onnx_contrib_ops_helper.py @@ -3,14 +3,15 @@ # # Helper functions for generating ONNX model and data to test ONNX Runtime contrib ops -import onnx import os -from onnx import numpy_helper -import subprocess import shutil +import subprocess + +import onnx +from onnx import numpy_helper TOP_DIR = os.path.realpath(os.path.dirname(__file__)) -DATA_DIR = os.path.join(TOP_DIR, '..', 'testdata/') +DATA_DIR = os.path.join(TOP_DIR, "..", "testdata/") def prepare_dir(path): @@ -23,59 +24,57 @@ def _extract_value_info(arr, name, ele_type=None): return onnx.helper.make_tensor_value_info( name=name, elem_type=ele_type if ele_type else onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[arr.dtype], - shape=arr.shape) + shape=arr.shape, + ) def generate_data(graph, inputs, outputs, name): output_dir = os.path.join(DATA_DIR, name) prepare_dir(output_dir) model = onnx.helper.make_model(graph) - with open(os.path.join(output_dir, 'model.onnx'), 'wb') as f: + with open(os.path.join(output_dir, "model.onnx"), "wb") as f: f.write(model.SerializeToString()) - data_set = os.path.join(output_dir, 'test_data_set_0') + data_set = os.path.join(output_dir, "test_data_set_0") prepare_dir(data_set) for j, input_np in enumerate(inputs): - tensor = numpy_helper.from_array( - input_np, model.graph.input[j].name) - with open(os.path.join( - data_set, 'input_{}.pb'.format(j)), 'wb') as f: + tensor = numpy_helper.from_array(input_np, model.graph.input[j].name) + with open(os.path.join(data_set, "input_{}.pb".format(j)), "wb") as f: f.write(tensor.SerializeToString()) for j, output_np in enumerate(outputs): - tensor = numpy_helper.from_array( - output_np, model.graph.output[j].name) - with open(os.path.join( - data_set, 'output_{}.pb'.format(j)), 'wb') as f: + tensor = numpy_helper.from_array(output_np, model.graph.output[j].name) + with open(os.path.join(data_set, "output_{}.pb".format(j)), "wb") as f: f.write(tensor.SerializeToString()) -def expect(node, # type: onnx.NodeProto - inputs, - outputs, - name, - **kwargs - ): # type: (...) -> None - present_inputs = [x for x in node.input if (x != '')] - present_outputs = [x for x in node.output if (x != '')] +def expect( + node, # type: onnx.NodeProto + inputs, + outputs, + name, + **kwargs +): # type: (...) -> None + present_inputs = [x for x in node.input if (x != "")] + present_outputs = [x for x in node.output if (x != "")] input_types = [None] * len(inputs) - if 'input_types' in kwargs: - input_types = kwargs[str('input_types')] - del kwargs[str('input_types')] + if "input_types" in kwargs: + input_types = kwargs[str("input_types")] + del kwargs[str("input_types")] output_types = [None] * len(outputs) - if 'output_types' in kwargs: - output_types = kwargs[str('output_types')] - del kwargs[str('output_types')] - inputs_vi = [_extract_value_info(arr, arr_name, input_type) - for arr, arr_name, input_type in zip(inputs, present_inputs, input_types)] - outputs_vi = [_extract_value_info(arr, arr_name, output_type) - for arr, arr_name, output_type in zip(outputs, present_outputs, output_types)] - graph = onnx.helper.make_graph( - nodes=[node], - name=name, - inputs=inputs_vi, - outputs=outputs_vi) + if "output_types" in kwargs: + output_types = kwargs[str("output_types")] + del kwargs[str("output_types")] + inputs_vi = [ + _extract_value_info(arr, arr_name, input_type) + for arr, arr_name, input_type in zip(inputs, present_inputs, input_types) + ] + outputs_vi = [ + _extract_value_info(arr, arr_name, output_type) + for arr, arr_name, output_type in zip(outputs, present_outputs, output_types) + ] + graph = onnx.helper.make_graph(nodes=[node], name=name, inputs=inputs_vi, outputs=outputs_vi) generate_data(graph, inputs, outputs, name) cwd = os.getcwd() - onnx_test_runner = os.path.join(cwd, 'onnx_test_runner') - subprocess.run([onnx_test_runner, DATA_DIR+name], check=True, cwd=cwd) + onnx_test_runner = os.path.join(cwd, "onnx_test_runner") + subprocess.run([onnx_test_runner, DATA_DIR + name], check=True, cwd=cwd) diff --git a/onnxruntime/test/python/contrib_ops/onnx_test_torch_embedding.py b/onnxruntime/test/python/contrib_ops/onnx_test_torch_embedding.py index 20f5acd53df56..f152962710e55 100644 --- a/onnxruntime/test/python/contrib_ops/onnx_test_torch_embedding.py +++ b/onnxruntime/test/python/contrib_ops/onnx_test_torch_embedding.py @@ -3,9 +3,10 @@ # # Test reference implementation and model for ONNX Runtime conrtib op torch_embedding -import onnx import unittest + import numpy as np +import onnx from onnx_contrib_ops_helper import expect @@ -16,48 +17,48 @@ def torch_embedding_reference_implementation(weight, indices, padding_idx=None, class ONNXReferenceImplementationTest(unittest.TestCase): def test_torch_embedding(self): node = onnx.helper.make_node( - 'TorchEmbedding', - inputs=['w', 'x'], - outputs=['y'], + "TorchEmbedding", + inputs=["w", "x"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(2, 4).astype(np.int64) w = np.random.randn(10, 3).astype(np.float32) y = torch_embedding_reference_implementation(w, x) - expect(node, inputs=[w, x], outputs=[y], name='test_torch_embedding') + expect(node, inputs=[w, x], outputs=[y], name="test_torch_embedding") def test_torch_embedding_long(self): node = onnx.helper.make_node( - 'TorchEmbedding', - inputs=['w', 'x'], - outputs=['y'], + "TorchEmbedding", + inputs=["w", "x"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(2, 4).astype(np.int64) w = np.random.randn(10, 3).astype(np.int64) y = torch_embedding_reference_implementation(w, x) - expect(node, inputs=[w, x], outputs=[y], name='test_torch_embedding_long') + expect(node, inputs=[w, x], outputs=[y], name="test_torch_embedding_long") def test_torch_embedding_zero_dim(self): node = onnx.helper.make_node( - 'TorchEmbedding', - inputs=['w', 'x'], - outputs=['y'], + "TorchEmbedding", + inputs=["w", "x"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(0, 4).astype(np.int64) w = np.random.randn(10, 3).astype(np.float32) y = torch_embedding_reference_implementation(w, x) - expect(node, inputs=[w, x], outputs=[y], name='test_torch_embedding_zero_dim') + expect(node, inputs=[w, x], outputs=[y], name="test_torch_embedding_zero_dim") def test_torch_embedding_padding_idx(self): node = onnx.helper.make_node( - 'TorchEmbedding', - inputs=['w', 'x', 'padding_idx'], - outputs=['y'], + "TorchEmbedding", + inputs=["w", "x", "padding_idx"], + outputs=["y"], domain="com.microsoft", ) @@ -65,13 +66,18 @@ def test_torch_embedding_padding_idx(self): w = np.random.randn(10, 3).astype(np.float32) padding_idx = np.random.randint(3, size=1).astype(np.int64) y = torch_embedding_reference_implementation(w, x, padding_idx) - expect(node, inputs=[w, x, padding_idx], outputs=[y], name='test_torch_embedding_padding_idx') + expect( + node, + inputs=[w, x, padding_idx], + outputs=[y], + name="test_torch_embedding_padding_idx", + ) def test_torch_embedding_scale_grad_by_freq(self): node = onnx.helper.make_node( - 'TorchEmbedding', - inputs=['w', 'x', 'padding_idx', 'scale'], - outputs=['y'], + "TorchEmbedding", + inputs=["w", "x", "padding_idx", "scale"], + outputs=["y"], domain="com.microsoft", ) @@ -80,8 +86,13 @@ def test_torch_embedding_scale_grad_by_freq(self): padding_idx = np.random.randint(3, size=1).astype(np.int64) scale = np.array([1]).astype(np.bool) y = torch_embedding_reference_implementation(w, x, padding_idx, scale) - expect(node, inputs=[w, x, padding_idx, scale], outputs=[y], name='test_torch_embedding_scale_grad_by_freq') + expect( + node, + inputs=[w, x, padding_idx, scale], + outputs=[y], + name="test_torch_embedding_scale_grad_by_freq", + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py b/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py index 87e9a1a819936..d09f8fb96f127 100644 --- a/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py +++ b/onnxruntime/test/python/contrib_ops/onnx_test_trilu.py @@ -3,9 +3,10 @@ # # Test reference implementation and model for ONNX Runtime conrtib op trilu -import onnx import unittest + import numpy as np +import onnx from onnx_contrib_ops_helper import expect @@ -20,150 +21,150 @@ def tril_reference_implementation(x, k=0): class ONNXReferenceImplementationTest(unittest.TestCase): def test_triu(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x'], - outputs=['y'], + "Trilu", + inputs=["x"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 4, 5).astype(np.float32) y = triu_reference_implementation(x) - expect(node, inputs=[x], outputs=[y], name='test_triu') + expect(node, inputs=[x], outputs=[y], name="test_triu") def test_triu_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([-1]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_neg") def test_triu_out_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([-7]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_out_neg") def test_triu_pos(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([2]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_pos') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_pos") def test_triu_out_pos(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([6]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_out_pos') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_out_pos") def test_triu_square(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x'], - outputs=['y'], + "Trilu", + inputs=["x"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 5, 5).astype(np.float32) y = triu_reference_implementation(x) - expect(node, inputs=[x], outputs=[y], name='test_triu_square') + expect(node, inputs=[x], outputs=[y], name="test_triu_square") def test_triu_square_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 5, 5).astype(np.float32) k = np.array([-1]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_square_neg") def test_triu_one_row_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 1, 5).astype(np.float32) k = np.array([-7]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_one_row_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_one_row_neg") def test_triu_square_pos(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 5, 5).astype(np.float32) k = np.array([2]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_square_pos') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_square_pos") def test_triu_zero(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], domain="com.microsoft", ) x = np.random.randn(3, 0, 5).astype(np.float32) k = np.array([6]).astype(np.int64) y = triu_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_triu_zero') + expect(node, inputs=[x, k], outputs=[y], name="test_triu_zero") def test_tril(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x'], - outputs=['y'], + "Trilu", + inputs=["x"], + outputs=["y"], upper=0, domain="com.microsoft", ) x = np.random.randn(3, 4, 5).astype(np.float32) y = tril_reference_implementation(x) - expect(node, inputs=[x], outputs=[y], name='test_tril') + expect(node, inputs=[x], outputs=[y], name="test_tril") def test_tril_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -171,13 +172,13 @@ def test_tril_neg(self): x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([-1]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_neg") def test_tril_out_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -185,13 +186,13 @@ def test_tril_out_neg(self): x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([-7]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_out_neg") def test_tril_pos(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -199,13 +200,13 @@ def test_tril_pos(self): x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([2]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_pos') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_pos") def test_tril_out_pos(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -213,26 +214,26 @@ def test_tril_out_pos(self): x = np.random.randn(3, 4, 5).astype(np.float32) k = np.array([6]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_out_pos') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_out_pos") def test_tril_square(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x'], - outputs=['y'], + "Trilu", + inputs=["x"], + outputs=["y"], upper=0, domain="com.microsoft", ) x = np.random.randn(3, 5, 5).astype(np.float32) y = tril_reference_implementation(x) - expect(node, inputs=[x], outputs=[y], name='test_tril_square') + expect(node, inputs=[x], outputs=[y], name="test_tril_square") def test_tril_square_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -240,13 +241,13 @@ def test_tril_square_neg(self): x = np.random.randn(3, 5, 5).astype(np.float32) k = np.array([-1]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_square_neg") def test_tril_one_row_neg(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -254,13 +255,13 @@ def test_tril_one_row_neg(self): x = np.random.randn(3, 1, 5).astype(np.float32) k = np.array([-7]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_one_row_neg') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_one_row_neg") def test_tril_square_pos(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -268,13 +269,13 @@ def test_tril_square_pos(self): x = np.random.randn(3, 5, 5).astype(np.float32) k = np.array([2]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_square_pos') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_square_pos") def test_tril_zero(self): node = onnx.helper.make_node( - 'Trilu', - inputs=['x', 'k'], - outputs=['y'], + "Trilu", + inputs=["x", "k"], + outputs=["y"], upper=0, domain="com.microsoft", ) @@ -282,8 +283,8 @@ def test_tril_zero(self): x = np.random.randn(3, 0, 5).astype(np.float32) k = np.array([6]).astype(np.int64) y = tril_reference_implementation(x, k) - expect(node, inputs=[x, k], outputs=[y], name='test_tril_zero') + expect(node, inputs=[x, k], outputs=[y], name="test_tril_zero") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/helper.py b/onnxruntime/test/python/helper.py index ba1d01637d23a..66a4f27319c1a 100644 --- a/onnxruntime/test/python/helper.py +++ b/onnxruntime/test/python/helper.py @@ -1,5 +1,6 @@ import os + def get_name(name): if os.path.exists(name): return name @@ -12,4 +13,3 @@ def get_name(name): if os.path.exists(res): return res raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res)) - diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index 3165c6514c83b..90f152b605373 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -3,17 +3,18 @@ import argparse import json -import sys import os import platform +import sys import unittest + +import numpy as np import onnx import onnx.backend.test -import numpy as np import onnxruntime.backend as c2 -pytest_plugins = 'onnx.backend.test.report', +pytest_plugins = ("onnx.backend.test.report",) class OrtBackendTest(onnx.backend.test.BackendTest): @@ -27,7 +28,8 @@ def assert_similar_array(ref_output, output): if ref_output.dtype == np.object: np.testing.assert_array_equal(ref_output, output) else: - np.testing.assert_allclose(ref_output, output, rtol=1e-3, atol=1e-5) + np.testing.assert_allclose(ref_output, output, rtol=1e-3, atol=1e-5) + np.testing.assert_equal(len(ref_outputs), len(outputs)) for i in range(len(outputs)): if isinstance(outputs[i], list): @@ -41,98 +43,109 @@ def create_backend_test(testname=None): backend_test = OrtBackendTest(c2, __name__) # Type not supported - backend_test.exclude(r'(FLOAT16)') + backend_test.exclude(r"(FLOAT16)") if testname: - backend_test.include(testname + '.*') + backend_test.include(testname + ".*") else: # read filters data with open( - os.path.join(os.path.dirname(os.path.realpath(__file__)), 'testdata', - 'onnx_backend_test_series_filters.jsonc')) as f: + os.path.join( + os.path.dirname(os.path.realpath(__file__)), + "testdata", + "onnx_backend_test_series_filters.jsonc", + ) + ) as f: filters_lines = f.readlines() - filters_lines = [x.split('//')[0] for x in filters_lines] - filters = json.loads('\n'.join(filters_lines)) + filters_lines = [x.split("//")[0] for x in filters_lines] + filters = json.loads("\n".join(filters_lines)) - current_failing_tests = filters['current_failing_tests'] + current_failing_tests = filters["current_failing_tests"] - if platform.architecture()[0] == '32bit': - current_failing_tests += filters['current_failing_tests_x86'] + if platform.architecture()[0] == "32bit": + current_failing_tests += filters["current_failing_tests_x86"] - if c2.supports_device('DNNL'): - current_failing_tests += filters['current_failing_tests_DNNL'] + if c2.supports_device("DNNL"): + current_failing_tests += filters["current_failing_tests_DNNL"] - if c2.supports_device('NNAPI'): - current_failing_tests += filters['current_failing_tests_NNAPI'] + if c2.supports_device("NNAPI"): + current_failing_tests += filters["current_failing_tests_NNAPI"] - if c2.supports_device('OPENVINO_GPU_FP32') or c2.supports_device('OPENVINO_GPU_FP16'): - current_failing_tests += filters['current_failing_tests_OPENVINO_GPU'] + if c2.supports_device("OPENVINO_GPU_FP32") or c2.supports_device("OPENVINO_GPU_FP16"): + current_failing_tests += filters["current_failing_tests_OPENVINO_GPU"] - if c2.supports_device('OPENVINO_MYRIAD'): - current_failing_tests += filters['current_failing_tests_OPENVINO_GPU'] - current_failing_tests += filters['current_failing_tests_OPENVINO_MYRIAD'] + if c2.supports_device("OPENVINO_MYRIAD"): + current_failing_tests += filters["current_failing_tests_OPENVINO_GPU"] + current_failing_tests += filters["current_failing_tests_OPENVINO_MYRIAD"] - if c2.supports_device('OPENVINO_CPU_FP32'): - current_failing_tests += filters['current_failing_tests_OPENVINO_CPU_FP32'] + if c2.supports_device("OPENVINO_CPU_FP32"): + current_failing_tests += filters["current_failing_tests_OPENVINO_CPU_FP32"] - if c2.supports_device('MIGRAPHX'): + if c2.supports_device("MIGRAPHX"): current_failing_tests += [ - '^test_constant_pad_cpu', '^test_round_cpu', '^test_lrn_default_cpu', '^test_lrn_cpu', - '^test_dynamicquantizelinear_expanded_cpu', '^test_dynamicquantizelinear_max_adjusted_cpu', - '^test_dynamicquantizelinear_max_adjusted_expanded_cpu', '^test_dynamicquantizelinear_min_adjusted_cpu', - '^test_dynamicquantizelinear_min_adjusted_expanded_cpu', - '^test_range_float_type_positive_delta_expanded_cpu', - '^test_range_int32_type_negative_delta_expanded_cpu', - '^test_operator_symbolic_override_nested_cpu', - '^test_negative_log_likelihood_loss', - '^test_softmax_cross_entropy', - '^test_greater_equal', - '^test_if_seq_cpu', - '^test_loop11_cpu', - '^test_loop13_seq_cpu', - '^test_sequence_insert_at_back_cpu', - '^test_sequence_insert_at_front_cpu', - '^test_nonmaxsuppression_two_classes_cpu', - '^test_nonmaxsuppression_two_batches_cpu', - '^test_nonmaxsuppression_suppress_by_IOU_cpu', - '^test_nonmaxsuppression_suppress_by_IOU_and_scores_cpu', - '^test_nonmaxsuppression_limit_output_size_cpu', - '^test_nonmaxsuppression_identical_boxes_cpu', - '^test_nonmaxsuppression_flipped_coordinates_cpu', - '^test_nonmaxsuppression_center_point_box_format_cpu' + "^test_constant_pad_cpu", + "^test_round_cpu", + "^test_lrn_default_cpu", + "^test_lrn_cpu", + "^test_dynamicquantizelinear_expanded_cpu", + "^test_dynamicquantizelinear_max_adjusted_cpu", + "^test_dynamicquantizelinear_max_adjusted_expanded_cpu", + "^test_dynamicquantizelinear_min_adjusted_cpu", + "^test_dynamicquantizelinear_min_adjusted_expanded_cpu", + "^test_range_float_type_positive_delta_expanded_cpu", + "^test_range_int32_type_negative_delta_expanded_cpu", + "^test_operator_symbolic_override_nested_cpu", + "^test_negative_log_likelihood_loss", + "^test_softmax_cross_entropy", + "^test_greater_equal", + "^test_if_seq_cpu", + "^test_loop11_cpu", + "^test_loop13_seq_cpu", + "^test_sequence_insert_at_back_cpu", + "^test_sequence_insert_at_front_cpu", + "^test_nonmaxsuppression_two_classes_cpu", + "^test_nonmaxsuppression_two_batches_cpu", + "^test_nonmaxsuppression_suppress_by_IOU_cpu", + "^test_nonmaxsuppression_suppress_by_IOU_and_scores_cpu", + "^test_nonmaxsuppression_limit_output_size_cpu", + "^test_nonmaxsuppression_identical_boxes_cpu", + "^test_nonmaxsuppression_flipped_coordinates_cpu", + "^test_nonmaxsuppression_center_point_box_format_cpu", ] # Skip these tests for a "pure" DML onnxruntime python wheel. We keep these tests enabled for instances where both DML and CUDA # EPs are available (Windows GPU CI pipeline has this config) - these test will pass because CUDA has higher precendence than DML # and the nodes are assigned to only the CUDA EP (which supports these tests) - if c2.supports_device('DML') and not c2.supports_device('GPU'): + if c2.supports_device("DML") and not c2.supports_device("GPU"): current_failing_tests += [ - '^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu', - '^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu', - '^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu', - '^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu', - '^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_cpu', - '^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded_cpu', - '^test_asin_example_cpu', - '^test_dynamicquantizelinear_cpu', - '^test_dynamicquantizelinear_expanded_cpu', - '^test_resize_downsample_scales_linear_cpu', - '^test_resize_downsample_sizes_linear_pytorch_half_pixel_cpu', - '^test_resize_downsample_sizes_nearest_cpu', - '^test_resize_upsample_sizes_nearest_cpu', - '^test_roialign_cpu' + "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu", + "^test_negative_log_likelihood_loss_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu", + "^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_cpu", + "^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_expanded_cpu", + "^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_cpu", + "^test_softmax_cross_entropy_input_shape_is_NCd1d2d3_none_no_weight_negative_ignore_index_log_prob_expanded_cpu", + "^test_asin_example_cpu", + "^test_dynamicquantizelinear_cpu", + "^test_dynamicquantizelinear_expanded_cpu", + "^test_resize_downsample_scales_linear_cpu", + "^test_resize_downsample_sizes_linear_pytorch_half_pixel_cpu", + "^test_resize_downsample_sizes_nearest_cpu", + "^test_resize_upsample_sizes_nearest_cpu", + "^test_roialign_cpu", ] - filters = current_failing_tests + \ - filters['tests_with_pre_opset7_dependencies'] + \ - filters['unsupported_usages'] + \ - filters['failing_permanently'] + \ - filters['test_with_types_disabled_due_to_binary_size_concerns'] + filters = ( + current_failing_tests + + filters["tests_with_pre_opset7_dependencies"] + + filters["unsupported_usages"] + + filters["failing_permanently"] + + filters["test_with_types_disabled_due_to_binary_size_concerns"] + ) - backend_test.exclude('(' + '|'.join(filters) + ')') - print('excluded tests:', filters) + backend_test.exclude("(" + "|".join(filters) + ")") + print("excluded tests:", filters) - # exclude TRT EP temporarily and only test CUDA EP to retain previous behavior + # exclude TRT EP temporarily and only test CUDA EP to retain previous behavior os.environ["ORT_ONNX_BACKEND_EXCLUDE_PROVIDERS"] = "TensorrtExecutionProvider" # import all test cases at global scope to make @@ -143,19 +156,22 @@ def create_backend_test(testname=None): def parse_args(): - parser = argparse.ArgumentParser(os.path.basename(__file__), - description='Run the ONNX backend tests using ONNXRuntime.') + parser = argparse.ArgumentParser( + os.path.basename(__file__), + description="Run the ONNX backend tests using ONNXRuntime.", + ) # Add an argument to match a single test name, by adding the name to the 'include' filter. # Using -k with python unittest (https://docs.python.org/3/library/unittest.html#command-line-options) # doesn't work as it filters on the test method name (Runner._add_model_test) rather than inidividual # test case names. parser.add_argument( - '-t', - '--test-name', - dest='testname', + "-t", + "--test-name", + dest="testname", type=str, - help="Only run tests that match this value. Matching is regex based, and '.*' is automatically appended") + help="Only run tests that match this value. Matching is regex based, and '.*' is automatically appended", + ) # parse just our args. python unittest has its own args and arg parsing, and that runs inside unittest.main() args, left = parser.parse_known_args() @@ -164,7 +180,7 @@ def parse_args(): return args -if __name__ == '__main__': +if __name__ == "__main__": args = parse_args() backend_test = create_backend_test(args.testname) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py index 9bbd57da0d985..f269c53f2e9f0 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer.py @@ -1,28 +1,43 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import copy import os -import unittest -import pytest import sys -import copy -import numpy as np -from numpy.testing import assert_allclose, assert_array_equal +import unittest +import numpy as np import onnx +import pytest import torch import torch.nn as nn import torch.nn.functional as F +from helper import get_name +from numpy.testing import assert_allclose, assert_array_equal from torchvision import datasets, transforms -from helper import get_name import onnxruntime -from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription, LossScaler, generate_sample, save_checkpoint, load_checkpoint +from onnxruntime.capi.ort_trainer import ( + IODescription, + LossScaler, + ModelDescription, + ORTTrainer, + generate_sample, + load_checkpoint, + save_checkpoint, +) SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) + def ort_trainer_learning_rate_description(): - return IODescription('Learning_Rate', [1, ], torch.float32) + return IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ) def remove_extra_info(model_desc): @@ -35,18 +50,44 @@ def remove_extra_info(model_desc): output_desc.num_classes_ = None return simple_model_desc + def bert_model_description(): vocab_size = 30528 - input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=vocab_size) - segment_ids_desc = IODescription('segment_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) - input_mask_desc = IODescription('input_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) - masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, - num_classes=vocab_size) - next_sentence_labels_desc = IODescription('next_sentence_labels', ['batch', ], torch.int64, num_classes=2) - loss_desc = IODescription('loss', [], torch.float32) + input_ids_desc = IODescription( + "input_ids", + ["batch", "max_seq_len_in_batch"], + torch.int64, + num_classes=vocab_size, + ) + segment_ids_desc = IODescription("segment_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) + input_mask_desc = IODescription("input_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2) + masked_lm_labels_desc = IODescription( + "masked_lm_labels", + ["batch", "max_seq_len_in_batch"], + torch.int64, + num_classes=vocab_size, + ) + next_sentence_labels_desc = IODescription( + "next_sentence_labels", + [ + "batch", + ], + torch.int64, + num_classes=2, + ) + loss_desc = IODescription("loss", [], torch.float32) + + return ModelDescription( + [ + input_ids_desc, + segment_ids_desc, + input_mask_desc, + masked_lm_labels_desc, + next_sentence_labels_desc, + ], + [loss_desc], + ) - return ModelDescription([input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, - next_sentence_labels_desc], [loss_desc]) def map_optimizer_attributes(name): no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] @@ -56,18 +97,22 @@ def map_optimizer_attributes(name): else: return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6} + def generate_sample_batch(desc, batch_size, device): desc_ = copy.deepcopy(desc) desc_.shape_[0] = batch_size sample = generate_sample(desc_, device) return sample -def create_ort_trainer(gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - loss_scaler=None, - deepspeed_zero_stage=0): + +def create_ort_trainer( + gradient_accumulation_steps, + use_mixed_precision, + allreduce_post_accumulation, + use_simple_model_desc=True, + loss_scaler=None, + deepspeed_zero_stage=0, +): model_desc = bert_model_description() simple_model_desc = remove_extra_info(model_desc) if use_simple_model_desc else model_desc learning_rate_description = ort_trainer_learning_rate_description() @@ -75,34 +120,45 @@ def create_ort_trainer(gradient_accumulation_steps, onnx_model = onnx.load(get_name("bert_toy_postprocessed.onnx")) - model = ORTTrainer(onnx_model, None, simple_model_desc, "LambOptimizer", - map_optimizer_attributes, - learning_rate_description, - device, - gradient_accumulation_steps=gradient_accumulation_steps, - world_rank=0, world_size=1, - loss_scaler=loss_scaler, - use_mixed_precision=use_mixed_precision, - allreduce_post_accumulation=allreduce_post_accumulation, - deepspeed_zero_stage = deepspeed_zero_stage) + model = ORTTrainer( + onnx_model, + None, + simple_model_desc, + "LambOptimizer", + map_optimizer_attributes, + learning_rate_description, + device, + gradient_accumulation_steps=gradient_accumulation_steps, + world_rank=0, + world_size=1, + loss_scaler=loss_scaler, + use_mixed_precision=use_mixed_precision, + allreduce_post_accumulation=allreduce_post_accumulation, + deepspeed_zero_stage=deepspeed_zero_stage, + ) return model, model_desc, device -def runBertTrainingTest(gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc=True, - use_internel_loss_scale=False): + +def runBertTrainingTest( + gradient_accumulation_steps, + use_mixed_precision, + allreduce_post_accumulation, + use_simple_model_desc=True, + use_internel_loss_scale=False, +): torch.manual_seed(1) onnxruntime.set_seed(1) loss_scaler = LossScaler("ort_test_input_loss_scalar", True) if use_internel_loss_scale else None - model, model_desc, device = create_ort_trainer(gradient_accumulation_steps, - use_mixed_precision, - allreduce_post_accumulation, - use_simple_model_desc, - loss_scaler) + model, model_desc, device = create_ort_trainer( + gradient_accumulation_steps, + use_mixed_precision, + allreduce_post_accumulation, + use_simple_model_desc, + loss_scaler, + ) if loss_scaler is None: loss_scaler = LossScaler(model.loss_scale_input_name, True) @@ -115,14 +171,38 @@ def runBertTrainingTest(gradient_accumulation_steps, batch_size = 16 num_batches = 8 for batch in range(num_batches): - input_ids_batches = [*input_ids_batches, generate_sample_batch(model_desc.inputs_[0], batch_size, device)] - segment_ids_batches = [*segment_ids_batches, generate_sample_batch(model_desc.inputs_[1], batch_size, device)] - input_mask_batches = [*input_mask_batches, generate_sample_batch(model_desc.inputs_[2], batch_size, device)] - masked_lm_labels_batches = [*masked_lm_labels_batches, generate_sample_batch(model_desc.inputs_[3], batch_size, device)] - next_sentence_labels_batches = [*next_sentence_labels_batches, generate_sample_batch(model_desc.inputs_[4], batch_size, device)] - - lr_batch_list = [0.0000000e+00, 4.6012269e-07, 9.2024538e-07, 1.3803681e-06, 1.8404908e-06, - 2.3006135e-06, 2.7607362e-06, 3.2208588e-06, 3.6809815e-06] + input_ids_batches = [ + *input_ids_batches, + generate_sample_batch(model_desc.inputs_[0], batch_size, device), + ] + segment_ids_batches = [ + *segment_ids_batches, + generate_sample_batch(model_desc.inputs_[1], batch_size, device), + ] + input_mask_batches = [ + *input_mask_batches, + generate_sample_batch(model_desc.inputs_[2], batch_size, device), + ] + masked_lm_labels_batches = [ + *masked_lm_labels_batches, + generate_sample_batch(model_desc.inputs_[3], batch_size, device), + ] + next_sentence_labels_batches = [ + *next_sentence_labels_batches, + generate_sample_batch(model_desc.inputs_[4], batch_size, device), + ] + + lr_batch_list = [ + 0.0000000e00, + 4.6012269e-07, + 9.2024538e-07, + 1.3803681e-06, + 1.8404908e-06, + 2.3006135e-06, + 2.7607362e-06, + 3.2208588e-06, + 3.6809815e-06, + ] actual_losses = [] actual_all_finites = [] @@ -136,12 +216,14 @@ def runBertTrainingTest(gradient_accumulation_steps, lr = lr_batch_list[batch_count] learning_rate = torch.tensor([lr]).to(device) - training_args = [input_ids, - segment_ids, - input_mask, - masked_lm_labels, - next_sentence_labels, - learning_rate] + training_args = [ + input_ids, + segment_ids, + input_mask, + masked_lm_labels, + next_sentence_labels, + learning_rate, + ] if use_mixed_precision: if not use_internel_loss_scale: loss_scale = torch.tensor([loss_scaler.loss_scale_]).to(device) @@ -152,7 +234,10 @@ def runBertTrainingTest(gradient_accumulation_steps, actual_loss, actual_all_finite = actual_loss if not use_internel_loss_scale: loss_scaler.update_loss_scale(actual_all_finite.item()) - actual_all_finites = [*actual_all_finites, actual_all_finite.cpu().numpy().item(0)] + actual_all_finites = [ + *actual_all_finites, + actual_all_finite.cpu().numpy().item(0), + ] actual_losses = [*actual_losses, actual_loss.cpu().numpy().item(0)] else: @@ -162,7 +247,14 @@ def runBertTrainingTest(gradient_accumulation_steps, if batch_count == num_batches - 1: # test eval_step api with fetches at the end of the training. # if eval_step is called during the training, it will affect the actual training loss (training session is stateful). - eval_loss = model.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels, fetches=['loss']) + eval_loss = model.eval_step( + input_ids, + segment_ids, + input_mask, + masked_lm_labels, + next_sentence_labels, + fetches=["loss"], + ) eval_loss = eval_loss.cpu().numpy().item(0) # If using internal loss scale, all_finites are handled internally too. @@ -171,7 +263,8 @@ def runBertTrainingTest(gradient_accumulation_steps, else: return actual_losses, eval_loss -class MNISTWrapper(): + +class MNISTWrapper: class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(MNISTWrapper.NeuralNet, self).__init__() @@ -213,9 +306,15 @@ def train_with_trainer(self, learningRate, trainer, device, train_loader, epoch) args_log_interval = 100 if batch_idx % args_log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) actual_losses = [*actual_losses, loss.cpu().numpy().item()] return actual_losses @@ -228,42 +327,65 @@ def test_with_trainer(self, trainer, device, test_loader): for data, target in test_loader: data, target = data.to(device), target.to(device) data = data.reshape(data.shape[0], -1) - output = F.log_softmax(trainer.eval_step((data), fetches=['probability']), dim=1) - test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + output = F.log_softmax(trainer.eval_step((data), fetches=["probability"]), dim=1) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) return test_loss, correct / len(test_loader.dataset) def mnist_model_description(): - input_desc = IODescription('input1', ['batch', 784], torch.float32) - label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10) - loss_desc = IODescription('loss', [], torch.float32) - probability_desc = IODescription('probability', ['batch', 10], torch.float32) + input_desc = IODescription("input1", ["batch", 784], torch.float32) + label_desc = IODescription( + "label", + [ + "batch", + ], + torch.int64, + num_classes=10, + ) + loss_desc = IODescription("loss", [], torch.float32) + probability_desc = IODescription("probability", ["batch", 10], torch.float32) return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) def get_loaders(self): args_batch_size = 64 args_test_batch_size = 1000 - kwargs = {'num_workers': 0, 'pin_memory': True} + kwargs = {"num_workers": 0, "pin_memory": True} # set shuffle to False to get deterministic data set among different torch version train_loader = torch.utils.data.DataLoader( - datasets.MNIST(os.path.join(SCRIPT_DIR, 'data'), train=True, download=True, - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args_batch_size, shuffle=False, **kwargs) + datasets.MNIST( + os.path.join(SCRIPT_DIR, "data"), + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args_batch_size, + shuffle=False, + **kwargs + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST(os.path.join(SCRIPT_DIR, 'data'), train=False, transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args_test_batch_size, shuffle=False, **kwargs) + datasets.MNIST( + os.path.join(SCRIPT_DIR, "data"), + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args_test_batch_size, + shuffle=False, + **kwargs + ) return train_loader, test_loader @@ -287,15 +409,39 @@ def get_model_with_internal_loss(self): model_desc = MNISTWrapper.mnist_model_description() return model, model_desc - def get_trainer(self, model, model_desc, device, onnx_opset_ver=12, frozen_weights=[], - internal_loss_fn=False, get_lr_this_step=None, optimizer="SGDOptimizer"): + def get_trainer( + self, + model, + model_desc, + device, + onnx_opset_ver=12, + frozen_weights=[], + internal_loss_fn=False, + get_lr_this_step=None, + optimizer="SGDOptimizer", + ): loss_fn = MNISTWrapper.my_loss if not internal_loss_fn else None - return ORTTrainer(model, loss_fn, model_desc, optimizer, None, IODescription('Learning_Rate', [1, ], - torch.float32), device, _opset_version=onnx_opset_ver, frozen_weights=frozen_weights, - get_lr_this_step=get_lr_this_step) + return ORTTrainer( + model, + loss_fn, + model_desc, + optimizer, + None, + IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ), + device, + _opset_version=onnx_opset_ver, + frozen_weights=frozen_weights, + get_lr_this_step=get_lr_this_step, + ) -class TestOrtTrainer(unittest.TestCase): +class TestOrtTrainer(unittest.TestCase): def run_mnist_training_and_testing(onnx_opset_ver): torch.manual_seed(1) device = torch.device("cuda") @@ -307,11 +453,28 @@ def run_mnist_training_and_testing(onnx_opset_ver): learningRate = 0.01 args_epochs = 2 - expected_losses = [2.312044143676758, 0.8018650412559509, 0.5819257497787476, 0.47025489807128906, - 0.35800155997276306, 0.41124576330184937, 0.2731882333755493, 0.4201386570930481, - 0.39458805322647095, 0.38380366563796997, 0.2722422480583191, 0.24230478703975677, - 0.23505745828151703, 0.33442264795303345, 0.21140924096107483, 0.31545233726501465, - 0.18556523323059082, 0.3453553020954132, 0.29598352313041687, 0.3595045208930969] + expected_losses = [ + 2.312044143676758, + 0.8018650412559509, + 0.5819257497787476, + 0.47025489807128906, + 0.35800155997276306, + 0.41124576330184937, + 0.2731882333755493, + 0.4201386570930481, + 0.39458805322647095, + 0.38380366563796997, + 0.2722422480583191, + 0.24230478703975677, + 0.23505745828151703, + 0.33442264795303345, + 0.21140924096107483, + 0.31545233726501465, + 0.18556523323059082, + 0.3453553020954132, + 0.29598352313041687, + 0.3595045208930969, + ] expected_test_losses = [0.3145490005493164, 0.256188737487793] expected_test_accuracies = [0.9075, 0.9265] @@ -319,7 +482,10 @@ def run_mnist_training_and_testing(onnx_opset_ver): actual_losses = [] actual_test_losses, actual_accuracies = [], [] for epoch in range(1, args_epochs + 1): - actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch)] + actual_losses = [ + *actual_losses, + *mnist.train_with_trainer(learningRate, trainer, device, train_loader, epoch), + ] test_loss, accuracy = mnist.test_with_trainer(trainer, device, test_loader) actual_test_losses = [*actual_test_losses, test_loss] @@ -328,9 +494,8 @@ def run_mnist_training_and_testing(onnx_opset_ver): # if you update outcomes, also do so for resume from checkpoint test # args_checkpoint_epoch = 1 # if epoch == args_checkpoint_epoch: - # state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()} - # torch.save(state, get_name("ckpt_mnist.pt")) - + # state = {'rng_state': torch.get_rng_state(), 'model': trainer.state_dict()} + # torch.save(state, get_name("ckpt_mnist.pt")) print("actual_losses=", actual_losses) print("actual_test_losses=", actual_test_losses) @@ -340,11 +505,21 @@ def run_mnist_training_and_testing(onnx_opset_ver): # import pdb; pdb.set_trace() rtol = 1e-03 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch") - assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch") + assert_allclose( + expected_test_losses, + actual_test_losses, + rtol=rtol, + err_msg="test loss mismatch", + ) + assert_allclose( + expected_test_accuracies, + actual_accuracies, + rtol=rtol, + err_msg="test accuracy mismatch", + ) def testMNISTTrainingAndTestingOpset12(self): - TestOrtTrainer.run_mnist_training_and_testing(onnx_opset_ver = 12) + TestOrtTrainer.run_mnist_training_and_testing(onnx_opset_ver=12) def testMNISTResumeTrainingAndTesting(self): torch.manual_seed(1) @@ -358,9 +533,18 @@ def testMNISTResumeTrainingAndTesting(self): args_epochs = 2 args_checkpoint_epoch = 1 # should match those in test without checkpointing - expected_losses = [0.26509523391723633, 0.24135658144950867, 0.2397943139076233, 0.3351520597934723, - 0.20998981595039368, 0.31488314270973206, 0.18481917679309845, 0.34727591276168823, - 0.2971782684326172, 0.3609251379966736] + expected_losses = [ + 0.26509523391723633, + 0.24135658144950867, + 0.2397943139076233, + 0.3351520597934723, + 0.20998981595039368, + 0.31488314270973206, + 0.18481917679309845, + 0.34727591276168823, + 0.2971782684326172, + 0.3609251379966736, + ] expected_test_losses = [0.25632242965698243] expected_test_accuracies = [0.9264] @@ -371,12 +555,15 @@ def testMNISTResumeTrainingAndTesting(self): # restore from checkpoint resume_trainer = mnist.get_trainer(model, model_desc, device) checkpoint = torch.load(get_name("ckpt_mnist.pt"), map_location="cpu") - torch.set_rng_state(checkpoint['rng_state']) - resume_trainer.load_state_dict(checkpoint['model'], strict=True) + torch.set_rng_state(checkpoint["rng_state"]) + resume_trainer.load_state_dict(checkpoint["model"], strict=True) # continue .. for epoch in range(args_checkpoint_epoch + 1, args_epochs + 1): - actual_losses = [*actual_losses, *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch)] + actual_losses = [ + *actual_losses, + *mnist.train_with_trainer(learningRate, resume_trainer, device, train_loader, epoch), + ] test_loss, accuracy = mnist.test_with_trainer(resume_trainer, device, test_loader) actual_test_losses = [*actual_test_losses, test_loss] @@ -390,8 +577,18 @@ def testMNISTResumeTrainingAndTesting(self): # import pdb; pdb.set_trace() rtol = 1e-03 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose(expected_test_losses, actual_test_losses, rtol=rtol, err_msg="test loss mismatch") - assert_allclose(expected_test_accuracies, actual_accuracies, rtol=rtol, err_msg="test accuracy mismatch") + assert_allclose( + expected_test_losses, + actual_test_losses, + rtol=rtol, + err_msg="test loss mismatch", + ) + assert_allclose( + expected_test_accuracies, + actual_accuracies, + rtol=rtol, + err_msg="test accuracy mismatch", + ) def testMNISTStateDict(self): torch.manual_seed(1) @@ -415,12 +612,18 @@ def testMNISTStateDict(self): loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) state_dict = trainer.state_dict() - assert state_dict.keys() == {'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight', 'bias_buffer'} + assert state_dict.keys() == { + "fc1.bias", + "fc1.weight", + "fc2.bias", + "fc2.weight", + "bias_buffer", + } def testMNISTSaveAsONNX(self): torch.manual_seed(1) device = torch.device("cuda") - onnx_file_name = 'mnist.onnx' + onnx_file_name = "mnist.onnx" if os.path.exists(onnx_file_name): os.remove(onnx_file_name) @@ -452,7 +655,7 @@ def testMNISTDevice(self): train_loader, test_loader = mnist.get_loaders() model, model_desc = mnist.get_model() - for model_device in [torch.device('cpu'), torch.device('cuda')]: + for model_device in [torch.device("cpu"), torch.device("cuda")]: model.to(model_device) trainer = mnist.get_trainer(model, model_desc, device) learningRate = 0.02 @@ -482,8 +685,9 @@ def testMNISTInitializerNames(self): loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - assert (set([n.name for n in trainer.onnx_model_.graph.initializer])-set(['bias_buffer'])) \ - == set([n for n, t in model.named_parameters()]) + assert (set([n.name for n in trainer.onnx_model_.graph.initializer]) - set(["bias_buffer"])) == set( + [n for n, t in model.named_parameters()] + ) def testMNISTInitializerNamesWithInternalLoss(self): torch.manual_seed(1) @@ -493,13 +697,17 @@ def testMNISTInitializerNamesWithInternalLoss(self): train_loader, test_loader = mnist.get_loaders() model, model_desc = mnist.get_model_with_internal_loss() - def get_lr_this_step(global_step): learningRate = 0.02 return torch.tensor([learningRate]) - trainer = mnist.get_trainer(model, model_desc, device, internal_loss_fn=True, - get_lr_this_step=get_lr_this_step) + trainer = mnist.get_trainer( + model, + model_desc, + device, + internal_loss_fn=True, + get_lr_this_step=get_lr_this_step, + ) epoch = 0 data, target = next(iter(train_loader)) @@ -508,8 +716,9 @@ def get_lr_this_step(global_step): loss, _ = trainer.train_step(data, target) - assert set([n.name for n in trainer.onnx_model_.graph.initializer]) \ - == set([n for n, t in model.named_parameters()]) + assert set([n.name for n in trainer.onnx_model_.graph.initializer]) == set( + [n for n, t in model.named_parameters()] + ) def testMNISTFrozenWeight(self): torch.manual_seed(1) @@ -519,7 +728,7 @@ def testMNISTFrozenWeight(self): train_loader, test_loader = mnist.get_loaders() model, model_desc = mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight']) + trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) learningRate = 0.02 epoch = 0 @@ -530,15 +739,14 @@ def testMNISTFrozenWeight(self): loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - fc1_trainstep_1 = trainer.state_dict()['fc1.weight'] - fc2_trainstep_1 = trainer.state_dict()['fc2.weight'] + fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] + fc2_trainstep_1 = trainer.state_dict()["fc2.weight"] loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - fc1_trainstep_2 = trainer.state_dict()['fc1.weight'] - fc2_trainstep_2 = trainer.state_dict()['fc2.weight'] - assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and \ - not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) + fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] + fc2_trainstep_2 = trainer.state_dict()["fc2.weight"] + assert np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and not np.array_equal(fc2_trainstep_1, fc2_trainstep_2) def testMNISTTorchBuffer(self): torch.manual_seed(1) @@ -559,15 +767,16 @@ def testMNISTTorchBuffer(self): loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - fc1_trainstep_1 = trainer.state_dict()['fc1.weight'] - bias_buffer_trainstep_1 = trainer.state_dict()['bias_buffer'] + fc1_trainstep_1 = trainer.state_dict()["fc1.weight"] + bias_buffer_trainstep_1 = trainer.state_dict()["bias_buffer"] loss, _ = trainer.train_step(data, target, torch.tensor([learningRate])) - fc1_trainstep_2 = trainer.state_dict()['fc1.weight'] - bias_buffer_trainstep_2 = trainer.state_dict()['bias_buffer'] - assert not np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and \ - np.array_equal(bias_buffer_trainstep_1, bias_buffer_trainstep_2) + fc1_trainstep_2 = trainer.state_dict()["fc1.weight"] + bias_buffer_trainstep_2 = trainer.state_dict()["bias_buffer"] + assert not np.array_equal(fc1_trainstep_1, fc1_trainstep_2) and np.array_equal( + bias_buffer_trainstep_1, bias_buffer_trainstep_2 + ) def testMNISTFrozenWeightCheckpoint(self): torch.manual_seed(1) @@ -577,7 +786,7 @@ def testMNISTFrozenWeightCheckpoint(self): train_loader, test_loader = mnist.get_loaders() model, model_desc = mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=['fc1.weight']) + trainer = mnist.get_trainer(model, model_desc, device, frozen_weights=["fc1.weight"]) learningRate = 0.02 epoch = 0 @@ -600,7 +809,7 @@ def testMNISTFrozenWeightCheckpoint(self): state_dict = trainer.state_dict() new_model, _ = mnist.get_model() - trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=['fc1.weight']) + trainer = mnist.get_trainer(new_model, model_desc, device, frozen_weights=["fc1.weight"]) trainer.load_state_dict(state_dict) ckpt_loss, _ = trainer.eval_step(data, target) @@ -617,8 +826,13 @@ def testMNISTTrainingCheckpoint(self): train_loader, test_loader = mnist.get_loaders() model, model_desc = mnist.get_model() - trainer = mnist.get_trainer(model, model_desc, device, - optimizer='LambOptimizer', frozen_weights=['fc1.weight']) + trainer = mnist.get_trainer( + model, + model_desc, + device, + optimizer="LambOptimizer", + frozen_weights=["fc1.weight"], + ) learningRate = 0.02 epoch = 0 @@ -642,8 +856,13 @@ def testMNISTTrainingCheckpoint(self): state_dict = trainer.state_dict() new_model, _ = mnist.get_model() - trainer = mnist.get_trainer(new_model, model_desc, device, - optimizer='LambOptimizer', frozen_weights=['fc1.weight']) + trainer = mnist.get_trainer( + new_model, + model_desc, + device, + optimizer="LambOptimizer", + frozen_weights=["fc1.weight"], + ) trainer.load_state_dict(state_dict) ckpt_loss, _ = trainer.eval_step(data, target) @@ -655,10 +874,22 @@ def testMNISTTrainingCheckpoint(self): assert np.array_equal(state_dict[key], loaded_state_dict[key]) def testBertTrainingBasic(self): - expected_losses = [11.027887, 11.108191, 11.055356, 11.040912, 10.960277, 11.02691, 11.082471, 10.920979] + expected_losses = [ + 11.027887, + 11.108191, + 11.055356, + 11.040912, + 10.960277, + 11.02691, + 11.082471, + 10.920979, + ] expected_eval_loss = [10.958977] actual_losses, actual_eval_loss = runBertTrainingTest( - gradient_accumulation_steps=1, use_mixed_precision=False, allreduce_post_accumulation=False) + gradient_accumulation_steps=1, + use_mixed_precision=False, + allreduce_post_accumulation=False, + ) # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs # print('losses expected: ', expected_losses) @@ -669,14 +900,31 @@ def testBertTrainingBasic(self): rtol = 1e-03 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + assert_allclose( + expected_eval_loss, + actual_eval_loss, + rtol=rtol, + err_msg="evaluation loss mismatch", + ) def testBertTrainingGradientAccumulation(self): - expected_losses = [11.027887, 11.108191, 11.055354, 11.040904, 10.960266, 11.026897, 11.082475, 10.920998] + expected_losses = [ + 11.027887, + 11.108191, + 11.055354, + 11.040904, + 10.960266, + 11.026897, + 11.082475, + 10.920998, + ] expected_eval_loss = [10.958998] actual_losses, actual_eval_loss = runBertTrainingTest( - gradient_accumulation_steps=4, use_mixed_precision=False, allreduce_post_accumulation=False) + gradient_accumulation_steps=4, + use_mixed_precision=False, + allreduce_post_accumulation=False, + ) # to update expected outcomes, enable pdb and run the test with -s and copy paste outputs # print('losses expected: ', expected_losses) @@ -687,45 +935,56 @@ def testBertTrainingGradientAccumulation(self): rtol = 1e-03 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + assert_allclose( + expected_eval_loss, + actual_eval_loss, + rtol=rtol, + err_msg="evaluation loss mismatch", + ) def testBertCheckpointingBasic(self): - model,_,_ = create_ort_trainer(gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None) + model, _, _ = create_ort_trainer( + gradient_accumulation_steps=1, + use_mixed_precision=False, + allreduce_post_accumulation=True, + use_simple_model_desc=True, + loss_scaler=None, + ) sd = model.state_dict() # modify one of the default values - sd['bert.encoder.layer.0.attention.output.LayerNorm.weight'] +=1 + sd["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 1 model.load_state_dict(sd) - ckpt_dir = 'testdata' - save_checkpoint(model, ckpt_dir, 'bert_toy_save_test') + ckpt_dir = "testdata" + save_checkpoint(model, ckpt_dir, "bert_toy_save_test") del model # create new model - model2,_,_ = create_ort_trainer(gradient_accumulation_steps=1, - use_mixed_precision=False, - allreduce_post_accumulation=True, - use_simple_model_desc=True, - loss_scaler=None) + model2, _, _ = create_ort_trainer( + gradient_accumulation_steps=1, + use_mixed_precision=False, + allreduce_post_accumulation=True, + use_simple_model_desc=True, + loss_scaler=None, + ) # load changed checkpoint - load_checkpoint(model2, ckpt_dir, 'bert_toy_save_test') + load_checkpoint(model2, ckpt_dir, "bert_toy_save_test") loaded_sd = model2.state_dict() - for k,v in loaded_sd.items(): + for k, v in loaded_sd.items(): assert torch.all(torch.eq(v, sd[k])) def testWrapModelLossFnStateDict(self): torch.manual_seed(1) device = torch.device("cuda") + class LinearModel(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(2, 4) + def forward(self, y=None, x=None): if y is not None: return self.linear(x) + y @@ -735,11 +994,19 @@ def forward(self, y=None, x=None): pt_model = LinearModel() data = torch.randn(2, 2) label = torch.tensor([0, 1], dtype=torch.int64) - input_desc = IODescription('x', [2, 2], torch.float32) - label_desc = IODescription('label', [2, ], torch.int64, num_classes=4) - output_desc = IODescription('output', [2, 4], torch.float32) - loss_desc = IODescription('loss', [], torch.float32) + input_desc = IODescription("x", [2, 2], torch.float32) + label_desc = IODescription( + "label", + [ + 2, + ], + torch.int64, + num_classes=4, + ) + output_desc = IODescription("output", [2, 4], torch.float32) + loss_desc = IODescription("loss", [], torch.float32) model_desc = ModelDescription([input_desc, label_desc], [loss_desc, output_desc]) + def loss_fn(x, label): return F.nll_loss(F.log_softmax(x, dim=1), label) @@ -748,12 +1015,25 @@ def get_lr_this_step(global_step): return torch.tensor([learningRate]) ort_trainer = ORTTrainer( - pt_model, loss_fn, model_desc, "SGDOptimizer", None, - IODescription('Learning_Rate', [1, ], torch.float32), device, - get_lr_this_step=get_lr_this_step) + pt_model, + loss_fn, + model_desc, + "SGDOptimizer", + None, + IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ), + device, + get_lr_this_step=get_lr_this_step, + ) ort_trainer.train_step(x=data, label=label) state_dict = ort_trainer.state_dict() - assert state_dict.keys() == {'linear.bias', 'linear.weight'} + assert state_dict.keys() == {"linear.bias", "linear.weight"} + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py index 1260d85037417..0c9e703c61fe5 100644 --- a/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py +++ b/onnxruntime/test/python/onnxruntime_test_ort_trainer_with_mixed_precision.py @@ -2,54 +2,101 @@ # Licensed under the MIT License. import unittest -from numpy.testing import assert_allclose, assert_array_equal +from numpy.testing import assert_allclose, assert_array_equal from onnxruntime_test_ort_trainer import runBertTrainingTest + class TestOrtTrainer(unittest.TestCase): def testBertTrainingMixedPrecision(self): expected_losses = [ - 11.034248352050781, 11.125300407409668, 11.006105422973633, 11.047048568725586, - 11.027417182922363, 11.015759468078613, 11.060905456542969, 10.971782684326172] + 11.034248352050781, + 11.125300407409668, + 11.006105422973633, + 11.047048568725586, + 11.027417182922363, + 11.015759468078613, + 11.060905456542969, + 10.971782684326172, + ] expected_all_finites = [True, True, True, True, True, True, True, True] expected_eval_loss = [10.959012985229492] actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest( - gradient_accumulation_steps=1, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False) + gradient_accumulation_steps=1, + use_mixed_precision=True, + allreduce_post_accumulation=False, + use_simple_model_desc=False, + ) rtol = 1e-02 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + assert_allclose( + expected_eval_loss, + actual_eval_loss, + rtol=rtol, + err_msg="evaluation loss mismatch", + ) def testBertTrainingMixedPrecisionInternalLossScale(self): expected_losses = [ - 11.034248352050781, 11.125300407409668, 11.006105422973633, 11.047048568725586, - 11.027417182922363, 11.015759468078613, 11.060905456542969, 10.971782684326172] + 11.034248352050781, + 11.125300407409668, + 11.006105422973633, + 11.047048568725586, + 11.027417182922363, + 11.015759468078613, + 11.060905456542969, + 10.971782684326172, + ] expected_eval_loss = [10.959012985229492] actual_losses, actual_eval_loss = runBertTrainingTest( gradient_accumulation_steps=1, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False, - use_internel_loss_scale=True) + use_internel_loss_scale=True, + ) rtol = 1e-02 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") - assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + assert_allclose( + expected_eval_loss, + actual_eval_loss, + rtol=rtol, + err_msg="evaluation loss mismatch", + ) def testBertTrainingGradientAccumulationMixedPrecision(self): expected_losses = [ - 11.034248352050781, 11.125300407409668, 11.006077766418457, 11.047025680541992, - 11.027434349060059, 11.0156831741333, 11.060973167419434, 10.971841812133789] + 11.034248352050781, + 11.125300407409668, + 11.006077766418457, + 11.047025680541992, + 11.027434349060059, + 11.0156831741333, + 11.060973167419434, + 10.971841812133789, + ] expected_all_finites = [True, True] expected_eval_loss = [10.95903205871582] actual_losses, actual_all_finites, actual_eval_loss = runBertTrainingTest( - gradient_accumulation_steps=4, use_mixed_precision=True, allreduce_post_accumulation=False, use_simple_model_desc=False) + gradient_accumulation_steps=4, + use_mixed_precision=True, + allreduce_post_accumulation=False, + use_simple_model_desc=False, + ) rtol = 1e-02 assert_allclose(expected_losses, actual_losses, rtol=rtol, err_msg="loss mismatch") assert_array_equal(expected_all_finites, actual_all_finites, "all_finite mismatch") - assert_allclose(expected_eval_loss, actual_eval_loss, rtol=rtol, err_msg="evaluation loss mismatch") + assert_allclose( + expected_eval_loss, + actual_eval_loss, + rtol=rtol, + err_msg="evaluation loss mismatch", + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 8f29c974f974d..49a0a4f941054 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -3,19 +3,20 @@ # -*- coding: UTF-8 -*- import gc -import numpy as np -import onnxruntime as onnxrt import os import platform import sys import threading import unittest +import numpy as np from helper import get_name + +import onnxruntime as onnxrt from onnxruntime.capi.onnxruntime_pybind11_state import Fail # handle change from python 3.8 and on where loading a dll from the current directory needs to be explicitly allowed. -if platform.system() == 'Windows' and sys.version_info.major >= 3 and sys.version_info.minor >= 8: +if platform.system() == "Windows" and sys.version_info.major >= 3 and sys.version_info.minor >= 8: os.add_dll_directory(os.getcwd()) available_providers = [provider for provider in onnxrt.get_available_providers()] @@ -32,12 +33,11 @@ # * testSequenceInsert # * testSequenceLength available_providers_without_tvm = [ - provider for provider in onnxrt.get_available_providers() - if provider not in {'TvmExecutionProvider'}] + provider for provider in onnxrt.get_available_providers() if provider not in {"TvmExecutionProvider"} +] class TestInferenceSession(unittest.TestCase): - def run_model(self, session_object, run_options): x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) input_name = session_object.get_inputs()[0].name @@ -49,6 +49,7 @@ def testTvmImported(self): if "TvmExecutionProvider" not in onnxrt.get_available_providers(): return import tvm + self.assertTrue(tvm is not None) def testModelSerialization(self): @@ -57,22 +58,28 @@ def testModelSerialization(self): so.log_severity_level = 1 so.logid = "TestModelSerialization" so.optimized_model_filepath = "./PythonApiTestOptimizedModel.onnx" - onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so, providers=['CPUExecutionProvider']) + onnxrt.InferenceSession( + get_name("mul_1.onnx"), + sess_options=so, + providers=["CPUExecutionProvider"], + ) self.assertTrue(os.path.isfile(so.optimized_model_filepath)) except Fail as onnxruntime_error: - if str(onnxruntime_error) == "[ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains" \ - " compiled nodes. Please disable any execution providers which generate compiled nodes.": + if ( + str(onnxruntime_error) == "[ONNXRuntimeError] : 1 : FAIL : Unable to serialize model as it contains" + " compiled nodes. Please disable any execution providers which generate compiled nodes." + ): pass else: raise onnxruntime_error def testGetProviders(self): - self.assertTrue('CPUExecutionProvider' in onnxrt.get_available_providers()) + self.assertTrue("CPUExecutionProvider" in onnxrt.get_available_providers()) # get_all_providers() returns the default EP order from highest to lowest. # CPUExecutionProvider should always be last. - self.assertTrue('CPUExecutionProvider' == onnxrt.get_all_providers()[-1]) + self.assertTrue("CPUExecutionProvider" == onnxrt.get_all_providers()[-1]) sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) - self.assertTrue('CPUExecutionProvider' in sess.get_providers()) + self.assertTrue("CPUExecutionProvider" in sess.get_providers()) def testEnablingAndDisablingTelemetry(self): onnxrt.disable_telemetry_events() @@ -82,65 +89,68 @@ def testEnablingAndDisablingTelemetry(self): onnxrt.enable_telemetry_events() def testSetProviders(self): - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CUDAExecutionProvider']) + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CUDAExecutionProvider"]) # confirm that CUDA Provider is in list of registered providers. - self.assertTrue('CUDAExecutionProvider' in sess.get_providers()) + self.assertTrue("CUDAExecutionProvider" in sess.get_providers()) # reset the session and register only CPU Provider. - sess.set_providers(['CPUExecutionProvider']) + sess.set_providers(["CPUExecutionProvider"]) # confirm only CPU Provider is registered now. - self.assertEqual(['CPUExecutionProvider'], sess.get_providers()) + self.assertEqual(["CPUExecutionProvider"], sess.get_providers()) def testSetProvidersWithOptions(self): - if 'TensorrtExecutionProvider' in onnxrt.get_available_providers(): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['TensorrtExecutionProvider']) - self.assertIn('TensorrtExecutionProvider', sess.get_providers()) + if "TensorrtExecutionProvider" in onnxrt.get_available_providers(): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["TensorrtExecutionProvider"]) + self.assertIn("TensorrtExecutionProvider", sess.get_providers()) options = sess.get_provider_options() - option = options['TensorrtExecutionProvider'] - self.assertIn('device_id', option) - self.assertIn('trt_max_partition_iterations', option) - self.assertIn('trt_min_subgraph_size', option) - self.assertIn('trt_max_workspace_size', option) - self.assertIn('trt_dump_subgraphs', option) - self.assertIn('trt_engine_cache_enable', option) - self.assertIn('trt_engine_cache_path', option) - self.assertIn('trt_force_sequential_engine_build', option) - - max_partition_iterations = option['trt_max_partition_iterations'] + option = options["TensorrtExecutionProvider"] + self.assertIn("device_id", option) + self.assertIn("trt_max_partition_iterations", option) + self.assertIn("trt_min_subgraph_size", option) + self.assertIn("trt_max_workspace_size", option) + self.assertIn("trt_dump_subgraphs", option) + self.assertIn("trt_engine_cache_enable", option) + self.assertIn("trt_engine_cache_path", option) + self.assertIn("trt_force_sequential_engine_build", option) + + max_partition_iterations = option["trt_max_partition_iterations"] new_max_partition_iterations = int(max_partition_iterations) + 1 - min_subgraph_size = option['trt_min_subgraph_size'] + min_subgraph_size = option["trt_min_subgraph_size"] new_min_subgraph_size = int(min_subgraph_size) + 1 - ori_max_workspace_size = option['trt_max_workspace_size'] + ori_max_workspace_size = option["trt_max_workspace_size"] new_max_workspace_size = int(ori_max_workspace_size) // 2 option = {} - option['trt_max_partition_iterations'] = new_max_partition_iterations - option['trt_min_subgraph_size'] = new_min_subgraph_size - option['trt_max_workspace_size'] = new_max_workspace_size + option["trt_max_partition_iterations"] = new_max_partition_iterations + option["trt_min_subgraph_size"] = new_min_subgraph_size + option["trt_max_workspace_size"] = new_max_workspace_size dump_subgraphs = "true" - option['trt_dump_subgraphs'] = dump_subgraphs + option["trt_dump_subgraphs"] = dump_subgraphs engine_cache_enable = "true" - option['trt_engine_cache_enable'] = engine_cache_enable - engine_cache_path = './engine_cache' - option['trt_engine_cache_path'] = engine_cache_path + option["trt_engine_cache_enable"] = engine_cache_enable + engine_cache_path = "./engine_cache" + option["trt_engine_cache_path"] = engine_cache_path force_sequential_engine_build = "true" - option['trt_force_sequential_engine_build'] = force_sequential_engine_build - sess.set_providers(['TensorrtExecutionProvider'], [option]) + option["trt_force_sequential_engine_build"] = force_sequential_engine_build + sess.set_providers(["TensorrtExecutionProvider"], [option]) options = sess.get_provider_options() - option = options['TensorrtExecutionProvider'] - self.assertEqual(option['trt_max_partition_iterations'], str(new_max_partition_iterations)) - self.assertEqual(option['trt_min_subgraph_size'], str(new_min_subgraph_size)) - self.assertEqual(option['trt_max_workspace_size'], str(new_max_workspace_size)) - self.assertEqual(option['trt_dump_subgraphs'], '1') - self.assertEqual(option['trt_engine_cache_enable'], '1') - self.assertEqual(option['trt_engine_cache_path'], str(engine_cache_path)) - self.assertEqual(option['trt_force_sequential_engine_build'], '1') + option = options["TensorrtExecutionProvider"] + self.assertEqual( + option["trt_max_partition_iterations"], + str(new_max_partition_iterations), + ) + self.assertEqual(option["trt_min_subgraph_size"], str(new_min_subgraph_size)) + self.assertEqual(option["trt_max_workspace_size"], str(new_max_workspace_size)) + self.assertEqual(option["trt_dump_subgraphs"], "1") + self.assertEqual(option["trt_engine_cache_enable"], "1") + self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path)) + self.assertEqual(option["trt_force_sequential_engine_build"], "1") # We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability - ''' + """ int8_use_native_calibration_table = "false" option['trt_int8_use_native_calibration_table'] = int8_use_native_calibration_table int8_enable = "true" @@ -149,76 +159,85 @@ def testSetProvidersWithOptions(self): option['trt_int8_calibration_table_name'] = calib_table_name with self.assertRaises(RuntimeError): sess.set_providers(['TensorrtExecutionProvider'], [option]) - ''' + """ - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - import sys + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): import ctypes + import sys + CUDA_SUCCESS = 0 - def runBaseTest1(): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CUDAExecutionProvider']) - self.assertTrue('CUDAExecutionProvider' in sess.get_providers()) - option1 = {'device_id': 0} - sess.set_providers(['CUDAExecutionProvider'], [option1]) - self.assertEqual(['CUDAExecutionProvider', 'CPUExecutionProvider'], sess.get_providers()) - option2 = {'device_id': -1} + def runBaseTest1(): + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CUDAExecutionProvider"]) + self.assertTrue("CUDAExecutionProvider" in sess.get_providers()) + + option1 = {"device_id": 0} + sess.set_providers(["CUDAExecutionProvider"], [option1]) + self.assertEqual( + ["CUDAExecutionProvider", "CPUExecutionProvider"], + sess.get_providers(), + ) + option2 = {"device_id": -1} with self.assertRaises(RuntimeError): - sess.set_providers(['CUDAExecutionProvider'], [option2]) - sess.set_providers(['CUDAExecutionProvider', 'CPUExecutionProvider'], [option1, {}]) - self.assertEqual(['CUDAExecutionProvider', 'CPUExecutionProvider'], sess.get_providers()) + sess.set_providers(["CUDAExecutionProvider"], [option2]) + sess.set_providers(["CUDAExecutionProvider", "CPUExecutionProvider"], [option1, {}]) + self.assertEqual( + ["CUDAExecutionProvider", "CPUExecutionProvider"], + sess.get_providers(), + ) def runBaseTest2(): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CUDAExecutionProvider']) - self.assertIn('CUDAExecutionProvider', sess.get_providers()) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CUDAExecutionProvider"]) + self.assertIn("CUDAExecutionProvider", sess.get_providers()) # test get/set of "gpu_mem_limit" configuration. options = sess.get_provider_options() - self.assertIn('CUDAExecutionProvider', options) - option = options['CUDAExecutionProvider'] - self.assertIn('gpu_mem_limit', option) - ori_mem_limit = option['gpu_mem_limit'] + self.assertIn("CUDAExecutionProvider", options) + option = options["CUDAExecutionProvider"] + self.assertIn("gpu_mem_limit", option) + ori_mem_limit = option["gpu_mem_limit"] new_mem_limit = int(ori_mem_limit) // 2 - option['gpu_mem_limit'] = new_mem_limit - sess.set_providers(['CUDAExecutionProvider'], [option]) + option["gpu_mem_limit"] = new_mem_limit + sess.set_providers(["CUDAExecutionProvider"], [option]) options = sess.get_provider_options() - self.assertEqual(options['CUDAExecutionProvider']['gpu_mem_limit'], str(new_mem_limit)) + self.assertEqual( + options["CUDAExecutionProvider"]["gpu_mem_limit"], + str(new_mem_limit), + ) - option['gpu_mem_limit'] = ori_mem_limit - sess.set_providers(['CUDAExecutionProvider'], [option]) + option["gpu_mem_limit"] = ori_mem_limit + sess.set_providers(["CUDAExecutionProvider"], [option]) options = sess.get_provider_options() - self.assertEqual(options['CUDAExecutionProvider']['gpu_mem_limit'], ori_mem_limit) + self.assertEqual(options["CUDAExecutionProvider"]["gpu_mem_limit"], ori_mem_limit) def test_get_and_set_option_with_values(option_name, option_values): provider_options = sess.get_provider_options() - self.assertIn('CUDAExecutionProvider', provider_options) - cuda_options = options['CUDAExecutionProvider'] + self.assertIn("CUDAExecutionProvider", provider_options) + cuda_options = options["CUDAExecutionProvider"] self.assertIn(option_name, cuda_options) for option_value in option_values: cuda_options[option_name] = option_value - sess.set_providers(['CUDAExecutionProvider'], [cuda_options]) + sess.set_providers(["CUDAExecutionProvider"], [cuda_options]) new_provider_options = sess.get_provider_options() self.assertEqual( - new_provider_options.get('CUDAExecutionProvider', {}).get(option_name), - str(option_value)) + new_provider_options.get("CUDAExecutionProvider", {}).get(option_name), + str(option_value), + ) - test_get_and_set_option_with_values( - 'arena_extend_strategy', ['kNextPowerOfTwo', 'kSameAsRequested']) + test_get_and_set_option_with_values("arena_extend_strategy", ["kNextPowerOfTwo", "kSameAsRequested"]) - test_get_and_set_option_with_values( - 'cudnn_conv_algo_search', ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"]) + test_get_and_set_option_with_values("cudnn_conv_algo_search", ["DEFAULT", "EXHAUSTIVE", "HEURISTIC"]) - test_get_and_set_option_with_values( - 'do_copy_in_default_stream', [0, 1]) + test_get_and_set_option_with_values("do_copy_in_default_stream", [0, 1]) - option['gpu_external_alloc'] = '0' - option['gpu_external_free'] = '0' - option['gpu_external_empty_cache'] = '0' - sess.set_providers(['CUDAExecutionProvider'], [option]) + option["gpu_external_alloc"] = "0" + option["gpu_external_free"] = "0" + option["gpu_external_empty_cache"] = "0" + sess.set_providers(["CUDAExecutionProvider"], [option]) options = sess.get_provider_options() - self.assertEqual(options['CUDAExecutionProvider']['gpu_external_alloc'], '0') - self.assertEqual(options['CUDAExecutionProvider']['gpu_external_free'], '0') - self.assertEqual(options['CUDAExecutionProvider']['gpu_external_empty_cache'], '0') + self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_alloc"], "0") + self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_free"], "0") + self.assertEqual(options["CUDAExecutionProvider"]["gpu_external_empty_cache"], "0") # # Note: Tests that throw an exception leave an empty session due to how set_providers currently works, # so run them last. Each set_providers call will attempt to re-create a session, so it's @@ -226,21 +245,21 @@ def test_get_and_set_option_with_values(option_name, option_values): # Alternatively a valid call to set_providers could be used to recreate the underlying session # after a failed call. # - option['arena_extend_strategy'] = 'wrong_value' + option["arena_extend_strategy"] = "wrong_value" with self.assertRaises(RuntimeError): - sess.set_providers(['CUDAExecutionProvider'], [option]) + sess.set_providers(["CUDAExecutionProvider"], [option]) - option['gpu_mem_limit'] = -1024 + option["gpu_mem_limit"] = -1024 with self.assertRaises(RuntimeError): - sess.set_providers(['CUDAExecutionProvider'], [option]) + sess.set_providers(["CUDAExecutionProvider"], [option]) - option['gpu_mem_limit'] = 1024.1024 + option["gpu_mem_limit"] = 1024.1024 with self.assertRaises(RuntimeError): - sess.set_providers(['CUDAExecutionProvider'], [option]) + sess.set_providers(["CUDAExecutionProvider"], [option]) - option['gpu_mem_limit'] = 'wrong_value' + option["gpu_mem_limit"] = "wrong_value" with self.assertRaises(RuntimeError): - sess.set_providers(['CUDAExecutionProvider'], [option]) + sess.set_providers(["CUDAExecutionProvider"], [option]) def getCudaDeviceCount(): import ctypes @@ -260,16 +279,20 @@ def getCudaDeviceCount(): def setDeviceIdTest(i): import ctypes + import onnxruntime as onnxrt device = ctypes.c_int() result = ctypes.c_int() error_str = ctypes.c_char_p() - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CPUExecutionProvider']) - option = {'device_id': i} - sess.set_providers(['CUDAExecutionProvider'], [option]) - self.assertEqual(['CUDAExecutionProvider', 'CPUExecutionProvider'], sess.get_providers()) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) + option = {"device_id": i} + sess.set_providers(["CUDAExecutionProvider"], [option]) + self.assertEqual( + ["CUDAExecutionProvider", "CPUExecutionProvider"], + sess.get_providers(), + ) result = cuda.cuCtxGetDevice(ctypes.byref(device)) if result != CUDA_SUCCESS: cuda.cuGetErrorString(result, ctypes.byref(error_str)) @@ -287,21 +310,21 @@ def runAdvancedTest(): for i in range(num_device): setDeviceIdTest(i) - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CPUExecutionProvider']) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) # configure session with invalid option values and that should fail with self.assertRaises(RuntimeError): - option = {'device_id': num_device} - sess.set_providers(['CUDAExecutionProvider'], [option]) - option = {'device_id': 'invalid_value'} - sess.set_providers(['CUDAExecutionProvider'], [option]) + option = {"device_id": num_device} + sess.set_providers(["CUDAExecutionProvider"], [option]) + option = {"device_id": "invalid_value"} + sess.set_providers(["CUDAExecutionProvider"], [option]) # configure session with invalid option should fail with self.assertRaises(RuntimeError): - option = {'invalid_option': 123} - sess.set_providers(['CUDAExecutionProvider'], [option]) + option = {"invalid_option": 123} + sess.set_providers(["CUDAExecutionProvider"], [option]) - libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll') + libnames = ("libcuda.so", "libcuda.dylib", "cuda.dll") for libname in libnames: try: cuda = ctypes.CDLL(libname) @@ -320,15 +343,15 @@ def runAdvancedTest(): def testInvalidSetProviders(self): with self.assertRaises(RuntimeError) as context: - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CPUExecutionProvider']) - sess.set_providers(['InvalidProvider']) - self.assertTrue('Unknown Provider Type: InvalidProvider' in str(context.exception)) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) + sess.set_providers(["InvalidProvider"]) + self.assertTrue("Unknown Provider Type: InvalidProvider" in str(context.exception)) def testSessionProviders(self): - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): # create session from scratch, but constrain it to only use the CPU. - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CPUExecutionProvider']) - self.assertEqual(['CPUExecutionProvider'], sess.get_providers()) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) + self.assertEqual(["CPUExecutionProvider"], sess.get_providers()) def testRunModel(self): sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) @@ -400,15 +423,20 @@ def testRunModelMultipleThreads(self): # We keep this test enabled for instances where both DML and CUDA EPs are available # (Windows GPU CI pipeline has this config) - this test will pass because CUDA has higher precedence # than DML and the nodes are assigned to only the CUDA EP (which supports this test). - if 'DmlExecutionProvider' in available_providers and 'CUDAExecutionProvider' not in available_providers: - print("Skipping testRunModelMultipleThreads as the DML EP does not support calling Run()" - " on different threads using the same session object.") + if "DmlExecutionProvider" in available_providers and "CUDAExecutionProvider" not in available_providers: + print( + "Skipping testRunModelMultipleThreads as the DML EP does not support calling Run()" + " on different threads using the same session object." + ) else: so = onnxrt.SessionOptions() so.log_verbosity_level = 1 so.logid = "MultiThreadsTest" - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so, - providers=available_providers_without_tvm) + sess = onnxrt.InferenceSession( + get_name("mul_1.onnx"), + sess_options=so, + providers=available_providers_without_tvm, + ) ro1 = onnxrt.RunOptions() ro1.logid = "thread1" t1 = threading.Thread(target=self.run_model, args=(sess, ro1)) @@ -430,14 +458,14 @@ def testListAsInput(self): def testStringListAsInput(self): sess = onnxrt.InferenceSession(get_name("identity_string.onnx"), providers=available_providers_without_tvm) - x = np.array(['this', 'is', 'identity', 'test'], dtype=str).reshape((2, 2)) + x = np.array(["this", "is", "identity", "test"], dtype=str).reshape((2, 2)) x_name = sess.get_inputs()[0].name res = sess.run([], {x_name: x.tolist()}) np.testing.assert_equal(x, res[0]) def testRunDevice(self): device = onnxrt.get_device() - self.assertTrue('CPU' in device or 'GPU' in device) + self.assertTrue("CPU" in device or "GPU" in device) def testRunModelSymbolicInput(self): sess = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=available_providers_without_tvm) @@ -446,12 +474,12 @@ def testRunModelSymbolicInput(self): self.assertEqual(input_name, "X") input_shape = sess.get_inputs()[0].shape # Input X has an unknown dimension. - self.assertEqual(input_shape, ['None', 2]) + self.assertEqual(input_shape, ["None", 2]) output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "Y") output_shape = sess.get_outputs()[0].shape # Output X has an unknown dimension. - self.assertEqual(output_shape, ['None', 1]) + self.assertEqual(output_shape, ["None", 1]) res = sess.run([output_name], {input_name: x}) output_expected = np.array([[5.0], [11.0], [17.0]], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) @@ -468,21 +496,21 @@ def testBooleanInputs(self): a_shape = sess.get_inputs()[0].shape self.assertEqual(a_shape, [2, 2]) a_type = sess.get_inputs()[0].type - self.assertEqual(a_type, 'tensor(bool)') + self.assertEqual(a_type, "tensor(bool)") b_name = sess.get_inputs()[1].name self.assertEqual(b_name, "input:0") b_shape = sess.get_inputs()[1].shape self.assertEqual(b_shape, [2, 2]) b_type = sess.get_inputs()[0].type - self.assertEqual(b_type, 'tensor(bool)') + self.assertEqual(b_type, "tensor(bool)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output:0") output_shape = sess.get_outputs()[0].shape self.assertEqual(output_shape, [2, 2]) output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(bool)') + self.assertEqual(output_type, "tensor(bool)") output_expected = np.array([[True, False], [False, False]], dtype=bool) res = sess.run([output_name], {a_name: a, b_name: b}) @@ -490,84 +518,84 @@ def testBooleanInputs(self): def testStringInput1(self): sess = onnxrt.InferenceSession(get_name("identity_string.onnx"), providers=available_providers_without_tvm) - x = np.array(['this', 'is', 'identity', 'test'], dtype=str).reshape((2, 2)) + x = np.array(["this", "is", "identity", "test"], dtype=str).reshape((2, 2)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "input:0") x_shape = sess.get_inputs()[0].shape self.assertEqual(x_shape, [2, 2]) x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(string)') + self.assertEqual(x_type, "tensor(string)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output:0") output_shape = sess.get_outputs()[0].shape self.assertEqual(output_shape, [2, 2]) output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(string)') + self.assertEqual(output_type, "tensor(string)") res = sess.run([output_name], {x_name: x}) np.testing.assert_equal(x, res[0]) def testStringInput2(self): sess = onnxrt.InferenceSession(get_name("identity_string.onnx"), providers=available_providers_without_tvm) - x = np.array(['Olá', '你好', '여보세요', 'hello'], dtype=str).reshape((2, 2)) + x = np.array(["Olá", "你好", "여보세요", "hello"], dtype=str).reshape((2, 2)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "input:0") x_shape = sess.get_inputs()[0].shape self.assertEqual(x_shape, [2, 2]) x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(string)') + self.assertEqual(x_type, "tensor(string)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output:0") output_shape = sess.get_outputs()[0].shape self.assertEqual(output_shape, [2, 2]) output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(string)') + self.assertEqual(output_type, "tensor(string)") res = sess.run([output_name], {x_name: x}) np.testing.assert_equal(x, res[0]) def testInputBytes(self): sess = onnxrt.InferenceSession(get_name("identity_string.onnx"), providers=available_providers_without_tvm) - x = np.array([b'this', b'is', b'identity', b'test']).reshape((2, 2)) + x = np.array([b"this", b"is", b"identity", b"test"]).reshape((2, 2)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "input:0") x_shape = sess.get_inputs()[0].shape self.assertEqual(x_shape, [2, 2]) x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(string)') + self.assertEqual(x_type, "tensor(string)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output:0") output_shape = sess.get_outputs()[0].shape self.assertEqual(output_shape, [2, 2]) output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(string)') + self.assertEqual(output_type, "tensor(string)") res = sess.run([output_name], {x_name: x}) - np.testing.assert_equal(x, res[0].astype('|S8')) + np.testing.assert_equal(x, res[0].astype("|S8")) def testInputObject(self): sess = onnxrt.InferenceSession(get_name("identity_string.onnx"), providers=available_providers_without_tvm) - x = np.array(['this', 'is', 'identity', 'test'], object).reshape((2, 2)) + x = np.array(["this", "is", "identity", "test"], object).reshape((2, 2)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "input:0") x_shape = sess.get_inputs()[0].shape self.assertEqual(x_shape, [2, 2]) x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(string)') + self.assertEqual(x_type, "tensor(string)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output:0") output_shape = sess.get_outputs()[0].shape self.assertEqual(output_shape, [2, 2]) output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(string)') + self.assertEqual(output_type, "tensor(string)") res = sess.run([output_name], {x_name: x}) np.testing.assert_equal(x, res[0]) @@ -576,34 +604,34 @@ def testInputVoid(self): sess = onnxrt.InferenceSession(get_name("identity_string.onnx"), providers=available_providers_without_tvm) # numpy 1.20+ doesn't automatically pad the bytes based entries in the array when dtype is np.void, # so we use inputs where that is the case - x = np.array([b'must', b'have', b'same', b'size'], dtype=np.void).reshape((2, 2)) + x = np.array([b"must", b"have", b"same", b"size"], dtype=np.void).reshape((2, 2)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "input:0") x_shape = sess.get_inputs()[0].shape self.assertEqual(x_shape, [2, 2]) x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(string)') + self.assertEqual(x_type, "tensor(string)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output:0") output_shape = sess.get_outputs()[0].shape self.assertEqual(output_shape, [2, 2]) output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(string)') + self.assertEqual(output_type, "tensor(string)") res = sess.run([output_name], {x_name: x}) - expr = np.array([['must', 'have'], ['same', 'size']], dtype=object) + expr = np.array([["must", "have"], ["same", "size"]], dtype=object) np.testing.assert_equal(expr, res[0]) def testRaiseWrongNumInputs(self): with self.assertRaises(ValueError) as context: sess = onnxrt.InferenceSession(get_name("logicaland.onnx"), providers=onnxrt.get_available_providers()) a = np.array([[True, True], [False, False]], dtype=bool) - res = sess.run([], {'input:0': a}) + res = sess.run([], {"input:0": a}) - self.assertTrue('Model requires 2 inputs' in str(context.exception)) + self.assertTrue("Model requires 2 inputs" in str(context.exception)) def testModelMeta(self): model_path = "../models/opset8/test_squeezenet/model.onnx" @@ -611,36 +639,42 @@ def testModelMeta(self): return sess = onnxrt.InferenceSession(model_path, providers=onnxrt.get_available_providers()) modelmeta = sess.get_modelmeta() - self.assertEqual('onnx-caffe2', modelmeta.producer_name) - self.assertEqual('squeezenet_old', modelmeta.graph_name) - self.assertEqual('', modelmeta.domain) - self.assertEqual('', modelmeta.description) - self.assertEqual('', modelmeta.graph_description) + self.assertEqual("onnx-caffe2", modelmeta.producer_name) + self.assertEqual("squeezenet_old", modelmeta.graph_name) + self.assertEqual("", modelmeta.domain) + self.assertEqual("", modelmeta.description) + self.assertEqual("", modelmeta.graph_description) def testProfilerWithSessionOptions(self): so = onnxrt.SessionOptions() so.enable_profiling = True - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so, - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("mul_1.onnx"), + sess_options=so, + providers=onnxrt.get_available_providers(), + ) x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - sess.run([], {'X': x}) + sess.run([], {"X": x}) profile_file = sess.end_profiling() - tags = ['pid', 'dur', 'ts', 'ph', 'X', 'name', 'args'] + tags = ["pid", "dur", "ts", "ph", "X", "name", "args"] with open(profile_file) as f: lines = f.readlines() - self.assertTrue('[' in lines[0]) - for i in range(1, len(lines)-1): + self.assertTrue("[" in lines[0]) + for i in range(1, len(lines) - 1): for tag in tags: self.assertTrue(tag in lines[i]) - self.assertTrue(']' in lines[-1]) + self.assertTrue("]" in lines[-1]) def testProfilerGetStartTimeNs(self): def getSingleSessionProfilingStartTime(): so = onnxrt.SessionOptions() so.enable_profiling = True - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so, - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("mul_1.onnx"), + sess_options=so, + providers=onnxrt.get_available_providers(), + ) return sess.get_profiling_start_time_ns() # Get 1st profiling's start time @@ -658,42 +692,45 @@ def testGraphOptimizationLevel(self): # default should be all optimizations optimization self.assertEqual(opt.graph_optimization_level, onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL) opt.graph_optimization_level = onnxrt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - self.assertEqual(opt.graph_optimization_level, onnxrt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED) - sess = onnxrt.InferenceSession(get_name("logicaland.onnx"), sess_options=opt, - providers=available_providers) + self.assertEqual( + opt.graph_optimization_level, + onnxrt.GraphOptimizationLevel.ORT_ENABLE_EXTENDED, + ) + sess = onnxrt.InferenceSession(get_name("logicaland.onnx"), sess_options=opt, providers=available_providers) a = np.array([[True, True], [False, False]], dtype=bool) b = np.array([[True, False], [True, False]], dtype=bool) - res = sess.run([], {'input1:0': a, 'input:0': b}) + res = sess.run([], {"input1:0": a, "input:0": b}) def testSequenceLength(self): - sess = onnxrt.InferenceSession(get_name("sequence_length.onnx"), - providers=available_providers_without_tvm) + sess = onnxrt.InferenceSession(get_name("sequence_length.onnx"), providers=available_providers_without_tvm) x = [ np.array([1.0, 0.0, 3.0, 44.0, 23.0, 11.0], dtype=np.float32).reshape((2, 3)), - np.array([1.0, 0.0, 3.0, 44.0, 23.0, 11.0], dtype=np.float32).reshape((2, 3)) + np.array([1.0, 0.0, 3.0, 44.0, 23.0, 11.0], dtype=np.float32).reshape((2, 3)), ] x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "X") x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'seq(tensor(float))') + self.assertEqual(x_type, "seq(tensor(float))") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "Y") output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'tensor(int64)') + self.assertEqual(output_type, "tensor(int64)") output_expected = np.array(2, dtype=np.int64) res = sess.run([output_name], {x_name: x}) self.assertEqual(output_expected, res[0]) def testSequenceConstruct(self): - sess = onnxrt.InferenceSession(get_name("sequence_construct.onnx"), - providers=available_providers_without_tvm) + sess = onnxrt.InferenceSession( + get_name("sequence_construct.onnx"), + providers=available_providers_without_tvm, + ) - self.assertEqual(sess.get_inputs()[0].type, 'tensor(int64)') - self.assertEqual(sess.get_inputs()[1].type, 'tensor(int64)') + self.assertEqual(sess.get_inputs()[0].type, "tensor(int64)") + self.assertEqual(sess.get_inputs()[1].type, "tensor(int64)") self.assertEqual(sess.get_inputs()[0].name, "tensor1") self.assertEqual(sess.get_inputs()[1].name, "tensor2") @@ -701,29 +738,34 @@ def testSequenceConstruct(self): output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output_sequence") output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'seq(tensor(int64))') + self.assertEqual(output_type, "seq(tensor(int64))") output_expected = [ np.array([1, 0, 3, 44, 23, 11], dtype=np.int64).reshape((2, 3)), - np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)) + np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)), ] res = sess.run( - [output_name], { + [output_name], + { "tensor1": np.array([1, 0, 3, 44, 23, 11], dtype=np.int64).reshape((2, 3)), - "tensor2": np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)) - }) + "tensor2": np.array([1, 2, 3, 4, 5, 6], dtype=np.int64).reshape((2, 3)), + }, + ) np.testing.assert_array_equal(output_expected, res[0]) def testSequenceInsert(self): opt = onnxrt.SessionOptions() opt.execution_mode = onnxrt.ExecutionMode.ORT_SEQUENTIAL - sess = onnxrt.InferenceSession(get_name("sequence_insert.onnx"), sess_options=opt, - providers=available_providers_without_tvm) + sess = onnxrt.InferenceSession( + get_name("sequence_insert.onnx"), + sess_options=opt, + providers=available_providers_without_tvm, + ) - self.assertEqual(sess.get_inputs()[0].type, 'seq(tensor(int64))') - self.assertEqual(sess.get_inputs()[1].type, 'tensor(int64)') + self.assertEqual(sess.get_inputs()[0].type, "seq(tensor(int64))") + self.assertEqual(sess.get_inputs()[1].type, "tensor(int64)") self.assertEqual(sess.get_inputs()[0].name, "input_seq") self.assertEqual(sess.get_inputs()[1].name, "tensor") @@ -731,13 +773,16 @@ def testSequenceInsert(self): output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "output_sequence") output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'seq(tensor(int64))') + self.assertEqual(output_type, "seq(tensor(int64))") output_expected = [np.array([1, 0, 3, 44, 23, 11], dtype=np.int64).reshape((2, 3))] - res = sess.run([output_name], { - "tensor": np.array([1, 0, 3, 44, 23, 11], dtype=np.int64).reshape((2, 3)), - "input_seq": [] - }) + res = sess.run( + [output_name], + { + "tensor": np.array([1, 0, 3, 44, 23, 11], dtype=np.int64).reshape((2, 3)), + "input_seq": [], + }, + ) np.testing.assert_array_equal(output_expected, res[0]) def testOrtExecutionMode(self): @@ -748,20 +793,25 @@ def testOrtExecutionMode(self): def testLoadingSessionOptionsFromModel(self): try: - os.environ['ORT_LOAD_CONFIG_FROM_MODEL'] = str(1) - sess = onnxrt.InferenceSession(get_name("model_with_valid_ort_config_json.onnx"), - providers=onnxrt.get_available_providers()) + os.environ["ORT_LOAD_CONFIG_FROM_MODEL"] = str(1) + sess = onnxrt.InferenceSession( + get_name("model_with_valid_ort_config_json.onnx"), + providers=onnxrt.get_available_providers(), + ) session_options = sess.get_session_options() self.assertEqual(session_options.inter_op_num_threads, 5) # from the ORT config self.assertEqual(session_options.intra_op_num_threads, 2) # from the ORT config - self.assertEqual(session_options.execution_mode, - onnxrt.ExecutionMode.ORT_SEQUENTIAL) # default option (not from the ORT config) + self.assertEqual( + session_options.execution_mode, onnxrt.ExecutionMode.ORT_SEQUENTIAL + ) # default option (not from the ORT config) - self.assertEqual(session_options.graph_optimization_level, - onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL) # from the ORT config + self.assertEqual( + session_options.graph_optimization_level, + onnxrt.GraphOptimizationLevel.ORT_ENABLE_ALL, + ) # from the ORT config self.assertEqual(session_options.enable_profiling, True) # from the ORT config @@ -770,14 +820,17 @@ def testLoadingSessionOptionsFromModel(self): finally: # Make sure the usage of the feature is disabled after this test - os.environ['ORT_LOAD_CONFIG_FROM_MODEL'] = str(0) + os.environ["ORT_LOAD_CONFIG_FROM_MODEL"] = str(0) def testSessionOptionsAddFreeDimensionOverrideByDenotation(self): so = onnxrt.SessionOptions() so.add_free_dimension_override_by_denotation("DATA_BATCH", 3) so.add_free_dimension_override_by_denotation("DATA_CHANNEL", 5) - sess = onnxrt.InferenceSession(get_name("abs_free_dimensions.onnx"), sess_options=so, - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("abs_free_dimensions.onnx"), + sess_options=so, + providers=onnxrt.get_available_providers(), + ) input_name = sess.get_inputs()[0].name self.assertEqual(input_name, "x") input_shape = sess.get_inputs()[0].shape @@ -788,8 +841,11 @@ def testSessionOptionsAddFreeDimensionOverrideByName(self): so = onnxrt.SessionOptions() so.add_free_dimension_override_by_name("Dim1", 4) so.add_free_dimension_override_by_name("Dim2", 6) - sess = onnxrt.InferenceSession(get_name("abs_free_dimensions.onnx"), sess_options=so, - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("abs_free_dimensions.onnx"), + sess_options=so, + providers=onnxrt.get_available_providers(), + ) input_name = sess.get_inputs()[0].name self.assertEqual(input_name, "x") input_shape = sess.get_inputs()[0].shape @@ -809,13 +865,16 @@ def testInvalidSessionOptionsConfigEntry(self): with self.assertRaises(RuntimeError) as context: so.get_session_config_entry(invalide_key) self.assertTrue( - 'SessionOptions does not have configuration with key: ' + invalide_key in str(context.exception)) + "SessionOptions does not have configuration with key: " + invalide_key in str(context.exception) + ) def testSessionOptionsAddInitializer(self): # Create an initializer and add it to a SessionOptions instance so = onnxrt.SessionOptions() # This initializer is different from the actual initializer in the model for "W" - ortvalue_initializer = onnxrt.OrtValue.ortvalue_from_numpy(np.array([[2.0, 1.0], [4.0, 3.0], [6.0, 5.0]], dtype=np.float32)) + ortvalue_initializer = onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[2.0, 1.0], [4.0, 3.0], [6.0, 5.0]], dtype=np.float32) + ) # The user should manage the life cycle of this OrtValue and should keep it in scope # as long as any session that is going to be reliant on it is in scope so.add_initializer("W", ortvalue_initializer) @@ -823,9 +882,17 @@ def testSessionOptionsAddInitializer(self): # Create an InferenceSession that only uses the CPU EP and validate that it uses the # initializer provided via the SessionOptions instance (overriding the model initializer) # We only use the CPU EP because the initializer we created is on CPU and we want the model to use that - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so, providers=['CPUExecutionProvider']) - res = sess.run(["Y"], {"X": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)}) - self.assertTrue(np.array_equal(res[0], np.array([[2.0, 2.0], [12.0, 12.0], [30.0, 30.0]], dtype=np.float32))) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so, providers=["CPUExecutionProvider"]) + res = sess.run( + ["Y"], + {"X": np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32)}, + ) + self.assertTrue( + np.array_equal( + res[0], + np.array([[2.0, 2.0], [12.0, 12.0], [30.0, 30.0]], dtype=np.float32), + ) + ) def testSessionOptionsAddExternalInitializers(self): # Create an external initializer data in OrtValue @@ -834,22 +901,25 @@ def testSessionOptionsAddExternalInitializers(self): so = onnxrt.SessionOptions() so.add_external_initializers(["Pads_not_on_disk"], [ortvalue_initializer]) # This should not throw - onnxrt.InferenceSession(get_name("model_with_external_initializer_come_from_user.onnx"), sess_options=so, providers=['CPUExecutionProvider']) - + onnxrt.InferenceSession( + get_name("model_with_external_initializer_come_from_user.onnx"), + sess_options=so, + providers=["CPUExecutionProvider"], + ) def testRegisterCustomOpsLibrary(self): if sys.platform.startswith("win"): - shared_library = 'custom_op_library.dll' + shared_library = "custom_op_library.dll" if not os.path.exists(shared_library): raise FileNotFoundError("Unable to find '{0}'".format(shared_library)) elif sys.platform.startswith("darwin"): - shared_library = 'libcustom_op_library.dylib' + shared_library = "libcustom_op_library.dylib" if not os.path.exists(shared_library): raise FileNotFoundError("Unable to find '{0}'".format(shared_library)) else: - shared_library = './libcustom_op_library.so' + shared_library = "./libcustom_op_library.so" if not os.path.exists(shared_library): raise FileNotFoundError("Unable to find '{0}'".format(shared_library)) @@ -863,14 +933,14 @@ def testRegisterCustomOpsLibrary(self): # Model loading successfully indicates that the custom op node could be resolved successfully sess1 = onnxrt.InferenceSession(custom_op_model, sess_options=so1, providers=available_providers_without_tvm) - #Run with input data + # Run with input data input_name_0 = sess1.get_inputs()[0].name input_name_1 = sess1.get_inputs()[1].name output_name = sess1.get_outputs()[0].name - input_0 = np.ones((3,5)).astype(np.float32) - input_1 = np.zeros((3,5)).astype(np.float32) + input_0 = np.ones((3, 5)).astype(np.float32) + input_1 = np.zeros((3, 5)).astype(np.float32) res = sess1.run([output_name], {input_name_0: input_0, input_name_1: input_1}) - output_expected = np.ones((3,5)).astype(np.float32) + output_expected = np.ones((3, 5)).astype(np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) # Create an alias of SessionOptions instance @@ -891,8 +961,7 @@ def testOrtValue(self): numpy_arr_output = np.array([[1.0, 4.0], [9.0, 16.0], [25.0, 36.0]], dtype=np.float32) def test_session_with_ortvalue_input(ortvalue): - sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) res = sess.run(["Y"], {"X": ortvalue}) self.assertTrue(np.array_equal(res[0], numpy_arr_output)) @@ -909,8 +978,8 @@ def test_session_with_ortvalue_input(ortvalue): # The constructed OrtValue should still be valid after being used in a session self.assertTrue(np.array_equal(ortvalue1.numpy(), numpy_arr_input)) - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - ortvalue2 = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input, 'cuda', 0) + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + ortvalue2 = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input, "cuda", 0) self.assertEqual(ortvalue2.device_name(), "cuda") self.assertEqual(ortvalue2.shape(), [3, 2]) self.assertEqual(ortvalue2.data_type(), "tensor(float)") @@ -924,20 +993,22 @@ def test_session_with_ortvalue_input(ortvalue): self.assertTrue(np.array_equal(ortvalue2.numpy(), numpy_arr_input)) def testOrtValue_ghIssue9799(self): - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - session = onnxrt.InferenceSession(get_name("identity_9799.onnx"), - providers=onnxrt.get_available_providers()) + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + session = onnxrt.InferenceSession( + get_name("identity_9799.onnx"), + providers=onnxrt.get_available_providers(), + ) for seq_length in range(40, 200): inps = np.ones((seq_length, 16, 7, 5, 3, 3)).astype(np.float32) - ort_val = onnxrt.OrtValue.ortvalue_from_numpy(inps, 'cuda', 0) - upstreams_onnxrt = {'input': ort_val} - outs = session.run(output_names=['output'], input_feed=upstreams_onnxrt)[0] + ort_val = onnxrt.OrtValue.ortvalue_from_numpy(inps, "cuda", 0) + upstreams_onnxrt = {"input": ort_val} + outs = session.run(output_names=["output"], input_feed=upstreams_onnxrt)[0] self.assertTrue(np.allclose(inps, outs)) def testSparseTensorCooFormat(self): - cpu_device = onnxrt.OrtDevice.make('cpu', 0) - shape = [9,9] + cpu_device = onnxrt.OrtDevice.make("cpu", 0) + shape = [9, 9] values = np.array([1.0, 2.0, 3.0], dtype=np.float32) # Linear indices indices = np.array([3, 5, 15], dtype=np.int64) @@ -945,7 +1016,7 @@ def testSparseTensorCooFormat(self): self.assertEqual(sparse_tensor.format(), onnxrt.OrtSparseFormat.ORT_SPARSE_COO) self.assertEqual(sparse_tensor.dense_shape(), shape) self.assertEqual(sparse_tensor.data_type(), "sparse_tensor(float)") - self.assertEqual(sparse_tensor.device_name(), 'cpu') + self.assertEqual(sparse_tensor.device_name(), "cpu") # Get Data View on a numeric type. values_ret = sparse_tensor.values() @@ -967,12 +1038,12 @@ def testSparseTensorCooFormat(self): gc.collect() # Test string data on cpu only, need to subst values only - str_values = np.array(['xyz', 'yxz', 'zyx'], dtype=str) + str_values = np.array(["xyz", "yxz", "zyx"], dtype=str) str_sparse_tensor = onnxrt.SparseTensor.sparse_coo_from_numpy(shape, str_values, indices, cpu_device) self.assertEqual(str_sparse_tensor.format(), onnxrt.OrtSparseFormat.ORT_SPARSE_COO) self.assertEqual(str_sparse_tensor.dense_shape(), shape) self.assertEqual(str_sparse_tensor.data_type(), "sparse_tensor(string)") - self.assertEqual(str_sparse_tensor.device_name(), 'cpu') + self.assertEqual(str_sparse_tensor.device_name(), "cpu") # Get string values back str_values_ret = str_sparse_tensor.values() @@ -983,13 +1054,13 @@ def testSparseTensorCooFormat(self): self.assertFalse(str_indices_ret.flags.writeable) self.assertTrue(np.array_equal(indices, str_indices_ret)) - cuda_device = onnxrt.OrtDevice.make('cuda', 0) - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): + cuda_device = onnxrt.OrtDevice.make("cuda", 0) + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): # Test to_cuda copy_on_cuda = sparse_tensor.to_cuda(cuda_device) self.assertEqual(copy_on_cuda.dense_shape(), shape) self.assertEqual(copy_on_cuda.data_type(), "sparse_tensor(float)") - self.assertEqual(copy_on_cuda.device_name(), 'cuda') + self.assertEqual(copy_on_cuda.device_name(), "cuda") # Test that gpu copy would fail to copy to cuda with self.assertRaises(RuntimeError): @@ -1003,16 +1074,18 @@ def testSparseTensorCooFormat(self): sparse_tensor.to_cuda(cuda_device) def testSparseTensorCsrFormat(self): - cpu_device = onnxrt.OrtDevice.make('cpu', 0) - shape = [9,9] + cpu_device = onnxrt.OrtDevice.make("cpu", 0) + shape = [9, 9] values = np.array([1.0, 2.0, 3.0], dtype=np.float32) inner_indices = np.array([1, 1, 1], dtype=np.int64) outer_indices = np.array([0, 1, 2, 3, 3, 3, 3, 3, 3, 3], dtype=np.int64) - sparse_tensor = onnxrt.SparseTensor.sparse_csr_from_numpy(shape, values, inner_indices, outer_indices, cpu_device) + sparse_tensor = onnxrt.SparseTensor.sparse_csr_from_numpy( + shape, values, inner_indices, outer_indices, cpu_device + ) self.assertEqual(sparse_tensor.format(), onnxrt.OrtSparseFormat.ORT_SPARSE_CSRC) self.assertEqual(sparse_tensor.dense_shape(), shape) self.assertEqual(sparse_tensor.data_type(), "sparse_tensor(float)") - self.assertEqual(sparse_tensor.device_name(), 'cpu') + self.assertEqual(sparse_tensor.device_name(), "cpu") # Test CSR(C) indices inner_indices_ret = sparse_tensor.as_csrc_view().inner() @@ -1024,26 +1097,27 @@ def testSparseTensorCsrFormat(self): self.assertTrue(np.array_equal(outer_indices, outer_indices_ret)) # Test with strings - str_values = np.array(['xyz', 'yxz', 'zyx'], dtype=str) - str_sparse_tensor = onnxrt.SparseTensor.sparse_csr_from_numpy(shape, str_values, inner_indices, outer_indices, cpu_device) + str_values = np.array(["xyz", "yxz", "zyx"], dtype=str) + str_sparse_tensor = onnxrt.SparseTensor.sparse_csr_from_numpy( + shape, str_values, inner_indices, outer_indices, cpu_device + ) self.assertEqual(str_sparse_tensor.format(), onnxrt.OrtSparseFormat.ORT_SPARSE_CSRC) self.assertEqual(str_sparse_tensor.dense_shape(), shape) self.assertEqual(str_sparse_tensor.data_type(), "sparse_tensor(string)") - self.assertEqual(str_sparse_tensor.device_name(), 'cpu') + self.assertEqual(str_sparse_tensor.device_name(), "cpu") - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - cuda_device = onnxrt.OrtDevice.make('cuda', 0) + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + cuda_device = onnxrt.OrtDevice.make("cuda", 0) cuda_sparse_tensor = sparse_tensor.to_cuda(cuda_device) - self.assertEqual(cuda_sparse_tensor.device_name(), 'cuda') + self.assertEqual(cuda_sparse_tensor.device_name(), "cuda") self.assertEqual(cuda_sparse_tensor.format(), onnxrt.OrtSparseFormat.ORT_SPARSE_CSRC) self.assertEqual(cuda_sparse_tensor.dense_shape(), shape) self.assertEqual(cuda_sparse_tensor.data_type(), "sparse_tensor(float)") - def testRunModelWithCudaCopyStream(self): available_providers = onnxrt.get_available_providers() - if (not 'CUDAExecutionProvider' in available_providers): + if not "CUDAExecutionProvider" in available_providers: print("Skipping testRunModelWithCudaCopyStream when CUDA is not available") else: # adapted from issue #4829 for a race condition when copy is not on default stream @@ -1052,57 +1126,79 @@ def testRunModelWithCudaCopyStream(self): # 2. it's easier to repro on slower GPU (like M60, Geforce 1070) # to repro #4829, set the CUDA EP do_copy_in_default_stream option to False - providers = [("CUDAExecutionProvider", {"do_copy_in_default_stream": True}), "CPUExecutionProvider"] + providers = [ + ("CUDAExecutionProvider", {"do_copy_in_default_stream": True}), + "CPUExecutionProvider", + ] session = onnxrt.InferenceSession(get_name("issue4829.onnx"), providers=providers) - shape = np.array([2,2], dtype=np.int64) + shape = np.array([2, 2], dtype=np.int64) for iteration in range(100000): - result = session.run(output_names=['output'], input_feed={'shape': shape}) + result = session.run(output_names=["output"], input_feed={"shape": shape}) def testSharedAllocatorUsingCreateAndRegisterAllocator(self): # Create and register an arena based allocator # ort_arena_cfg = onnxrt.OrtArenaCfg(0, -1, -1, -1) (create an OrtArenaCfg like this template if you want to use non-default parameters) - ort_memory_info = onnxrt.OrtMemoryInfo("Cpu", onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, 0, onnxrt.OrtMemType.DEFAULT) + ort_memory_info = onnxrt.OrtMemoryInfo( + "Cpu", + onnxrt.OrtAllocatorType.ORT_ARENA_ALLOCATOR, + 0, + onnxrt.OrtMemType.DEFAULT, + ) # Use this option if using non-default OrtArenaCfg : onnxrt.create_and_register_allocator(ort_memory_info, ort_arena_cfg) onnxrt.create_and_register_allocator(ort_memory_info, None) # Create a session that will use the registered arena based allocator so1 = onnxrt.SessionOptions() so1.log_severity_level = 1 - so1.add_session_config_entry("session.use_env_allocators", "1"); - onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so1, providers=onnxrt.get_available_providers()) + so1.add_session_config_entry("session.use_env_allocators", "1") + onnxrt.InferenceSession( + get_name("mul_1.onnx"), + sess_options=so1, + providers=onnxrt.get_available_providers(), + ) # Create a session that will NOT use the registered arena based allocator so2 = onnxrt.SessionOptions() so2.log_severity_level = 1 - onnxrt.InferenceSession(get_name("mul_1.onnx"), sess_options=so2, providers=onnxrt.get_available_providers()) + onnxrt.InferenceSession( + get_name("mul_1.onnx"), + sess_options=so2, + providers=onnxrt.get_available_providers(), + ) def testMemoryArenaShrinkage(self): - if platform.architecture()[0] == '32bit' or 'ppc' in platform.machine() or 'powerpc' in platform.machine(): + if platform.architecture()[0] == "32bit" or "ppc" in platform.machine() or "powerpc" in platform.machine(): # on x86 or ppc builds, the CPU allocator does not use an arena print("Skipping testMemoryArenaShrinkage in 32bit or powerpc platform.") else: x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - sess1 = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=['CPUExecutionProvider']) + sess1 = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=["CPUExecutionProvider"]) input_name = sess1.get_inputs()[0].name # Shrink CPU memory after execution ro1 = onnxrt.RunOptions() ro1.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu:0") - self.assertEqual(ro1.get_run_config_entry("memory.enable_memory_arena_shrinkage"), "cpu:0") + self.assertEqual( + ro1.get_run_config_entry("memory.enable_memory_arena_shrinkage"), + "cpu:0", + ) sess1.run([], {input_name: x}, ro1) available_providers = onnxrt.get_available_providers() - if 'CUDAExecutionProvider' in available_providers: + if "CUDAExecutionProvider" in available_providers: sess2 = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=available_providers) input_name = sess2.get_inputs()[0].name # Shrink CPU and GPU memory after execution ro2 = onnxrt.RunOptions() ro2.add_run_config_entry("memory.enable_memory_arena_shrinkage", "cpu:0;gpu:0") - self.assertEqual(ro2.get_run_config_entry("memory.enable_memory_arena_shrinkage"), "cpu:0;gpu:0") + self.assertEqual( + ro2.get_run_config_entry("memory.enable_memory_arena_shrinkage"), + "cpu:0;gpu:0", + ) sess2.run([], {input_name: x}, ro2) def testCheckAndNormalizeProviderArgs(self): @@ -1111,8 +1207,10 @@ def testCheckAndNormalizeProviderArgs(self): valid_providers = ["a", "b", "c"] def check_success(providers, provider_options, expected_providers, expected_provider_options): - actual_providers, actual_provider_options = check_and_normalize_provider_args( - providers, provider_options, valid_providers) + ( + actual_providers, + actual_provider_options, + ) = check_and_normalize_provider_args(providers, provider_options, valid_providers) self.assertEqual(actual_providers, expected_providers) self.assertEqual(actual_provider_options, expected_provider_options) @@ -1135,7 +1233,7 @@ def check_failure(providers, provider_options): # disable this test # provider not valid - #check_failure(["d"], None) + # check_failure(["d"], None) # providers not sequence check_failure(3, None) @@ -1157,19 +1255,20 @@ def check_failure(providers, provider_options): def testRegisterCustomEPsLibrary(self): from onnxruntime.capi import _pybind_state as C + available_eps = C.get_available_providers() - #skip amd gpu build - if 'kRocmExecutionProvider' in available_eps: + # skip amd gpu build + if "kRocmExecutionProvider" in available_eps: return if sys.platform.startswith("win"): - shared_library = 'test_execution_provider.dll' + shared_library = "test_execution_provider.dll" elif sys.platform.startswith("darwin"): # exclude for macos return else: - shared_library = './libtest_execution_provider.so' + shared_library = "./libtest_execution_provider.so" if not os.path.exists(shared_library): raise FileNotFoundError("Unable to find '{0}'".format(shared_library)) @@ -1181,11 +1280,19 @@ def testRegisterCustomEPsLibrary(self): session_options = C.get_default_session_options() sess = C.InferenceSession(session_options, custom_op_model, True, True) - sess.initialize_session(['my_ep'], - [{'shared_lib_path': shared_library, - 'device_id':'1', 'some_config':'val'}], - set()) + sess.initialize_session( + ["my_ep"], + [ + { + "shared_lib_path": shared_library, + "device_id": "1", + "some_config": "val", + } + ], + set(), + ) print("Create session with customize execution provider successfully!") -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main(verbosity=1) diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend.py b/onnxruntime/test/python/onnxruntime_test_python_backend.py index a88d62e9d4a70..26752d687f97c 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_backend.py +++ b/onnxruntime/test/python/onnxruntime_test_python_backend.py @@ -3,14 +3,16 @@ # -*- coding: UTF-8 -*- import unittest + import numpy as np +from helper import get_name from numpy.testing import assert_allclose + import onnxruntime as onnxrt import onnxruntime.backend as backend -from helper import get_name -class TestBackend(unittest.TestCase): +class TestBackend(unittest.TestCase): def testRunModel(self): name = get_name("mul_1.onnx") rep = backend.prepare(name) @@ -45,12 +47,12 @@ def testAllocationPlanWorksWithOnlyExecutePathToFetchesOption(self): run_options.only_execute_path_to_fetches = True inp0, inp1 = np.ones((10,), dtype=np.float32), np.ones((10,), dtype=np.float32) - session_run_results = sess.run(['outp0'], {'inp0': inp0, 'inp1': inp1}, run_options) + session_run_results = sess.run(["outp0"], {"inp0": inp0, "inp1": inp1}, run_options) assert_allclose(session_run_results[0], -(inp0 + inp1)) - session_run_results = sess.run(['outp1'], {'inp0': inp0, 'inp1': inp1}, run_options) + session_run_results = sess.run(["outp1"], {"inp0": inp0, "inp1": inp1}, run_options) assert_allclose(session_run_results[0], -(inp0 - inp1)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py b/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py index 91fd5685cb754..b93cf865d4aa0 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py +++ b/onnxruntime/test/python/onnxruntime_test_python_backend_mlops.py @@ -3,12 +3,14 @@ # -*- coding: UTF-8 -*- import unittest + import numpy as np -from onnxruntime import datasets +from helper import get_name +from onnx import load + import onnxruntime.backend as backend +from onnxruntime import datasets from onnxruntime.backend.backend import OnnxRuntimeBackend as ort_backend -from onnx import load -from helper import get_name def check_list_of_map_to_float(testcase, expected_rows, actual_rows): @@ -21,14 +23,15 @@ def check_list_of_map_to_float(testcase, expected_rows, actual_rows): for i in range(num_rows): # use np.testing.assert_allclose so we can specify the tolerance - np.testing.assert_allclose([expected_rows[i][key] for key in sorted_keys], - [actual_rows[i][key] for key in sorted_keys], - rtol=1e-05, - atol=1e-07) + np.testing.assert_allclose( + [expected_rows[i][key] for key in sorted_keys], + [actual_rows[i][key] for key in sorted_keys], + rtol=1e-05, + atol=1e-07, + ) class TestBackend(unittest.TestCase): - def testRunModelNonTensor(self): name = get_name("pipeline_vectorize.onnx") rep = backend.prepare(name) @@ -46,19 +49,19 @@ def testRunModelProto(self): res = rep.run(x) output_expected = np.array([0, 0, 0], dtype=np.float32) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - output_expected = [{ - 0: 0.950599730014801, - 1: 0.027834169566631317, - 2: 0.02156602405011654 - }, { - 0: 0.9974970817565918, - 1: 5.6299926654901356e-05, - 2: 0.0024466661270707846 - }, { - 0: 0.9997311234474182, - 1: 1.1918064757310276e-07, - 2: 0.00026869276189245284 - }] + output_expected = [ + {0: 0.950599730014801, 1: 0.027834169566631317, 2: 0.02156602405011654}, + { + 0: 0.9974970817565918, + 1: 5.6299926654901356e-05, + 2: 0.0024466661270707846, + }, + { + 0: 0.9997311234474182, + 1: 1.1918064757310276e-07, + 2: 0.00026869276189245284, + }, + ] check_list_of_map_to_float(self, output_expected, res[1]) @@ -71,21 +74,22 @@ def testRunModelProtoApi(self): output_expected = np.array([0, 0, 0], dtype=np.float32) np.testing.assert_allclose(output_expected, outputs[0], rtol=1e-05, atol=1e-08) - output_expected = [{ - 0: 0.950599730014801, - 1: 0.027834169566631317, - 2: 0.02156602405011654 - }, { - 0: 0.9974970817565918, - 1: 5.6299926654901356e-05, - 2: 0.0024466661270707846 - }, { - 0: 0.9997311234474182, - 1: 1.1918064757310276e-07, - 2: 0.00026869276189245284 - }] + output_expected = [ + {0: 0.950599730014801, 1: 0.027834169566631317, 2: 0.02156602405011654}, + { + 0: 0.9974970817565918, + 1: 5.6299926654901356e-05, + 2: 0.0024466661270707846, + }, + { + 0: 0.9997311234474182, + 1: 1.1918064757310276e-07, + 2: 0.00026869276189245284, + }, + ] check_list_of_map_to_float(self, output_expected, outputs[1]) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index c03a2931b1098..b71b3a07cd41f 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -1,65 +1,79 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import gc +import os +import sys +import threading +import time + # -*- coding: UTF-8 -*- import unittest -import os + import numpy as np -import gc +from helper import get_name import onnxruntime as onnxrt -import threading -import sys -from helper import get_name from onnxruntime.capi.onnxruntime_pybind11_state import Fail -import time + class TestInferenceSessionWithCudaGraph(unittest.TestCase): - def testOrtValueUpdateInPlace(self): - x0 = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) - ortvalue_cpu = onnxrt.OrtValue.ortvalue_from_numpy(x0) - np.testing.assert_allclose(x0, ortvalue_cpu.numpy()) - - x1 = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]], dtype=np.float32) - ortvalue_cpu.update_inplace(x1) - np.testing.assert_allclose(x1, ortvalue_cpu.numpy()) - - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - ortvalue_gpu = onnxrt.OrtValue.ortvalue_from_numpy(x0, 'cuda', 0) - np.testing.assert_allclose(x0, ortvalue_gpu.numpy()) - - ortvalue_gpu.update_inplace(x1) - np.testing.assert_allclose(x1, ortvalue_gpu.numpy()) - - def testRunModelWithCudaGraph(self): - if 'CUDAExecutionProvider' in onnxrt.get_available_providers(): - providers = [('CUDAExecutionProvider', {'enable_cuda_graph': True})] - INPUT_SIZE = 1280 - x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]*INPUT_SIZE, dtype=np.float32) - y = np.array([[0.0], [0.0], [0.0]]*INPUT_SIZE, dtype=np.float32) - x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, 'cuda', 0) - y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, 'cuda', 0) - - session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) - io_binding = session.io_binding() - - # Bind the input and output - io_binding.bind_ortvalue_input('X', x_ortvalue) - io_binding.bind_ortvalue_output('Y', y_ortvalue) - - # One regular run for the necessary memory allocation and cuda graph capturing - session.run_with_iobinding(io_binding) - expected_y = np.array([[5.0], [11.0], [17.0]]*INPUT_SIZE, dtype=np.float32) - np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - - # After capturing, CUDA graph replay happens from this Run onwards - session.run_with_iobinding(io_binding) - np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - - # Update input and then replay CUDA graph - x_ortvalue.update_inplace(np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]]*INPUT_SIZE, dtype=np.float32)) - session.run_with_iobinding(io_binding) - np.testing.assert_allclose(np.array([[50.0], [110.0], [170.0]]*INPUT_SIZE, dtype=np.float32), y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) - -if __name__ == '__main__': + def testOrtValueUpdateInPlace(self): + x0 = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) + ortvalue_cpu = onnxrt.OrtValue.ortvalue_from_numpy(x0) + np.testing.assert_allclose(x0, ortvalue_cpu.numpy()) + + x1 = np.array([[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]], dtype=np.float32) + ortvalue_cpu.update_inplace(x1) + np.testing.assert_allclose(x1, ortvalue_cpu.numpy()) + + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + ortvalue_gpu = onnxrt.OrtValue.ortvalue_from_numpy(x0, "cuda", 0) + np.testing.assert_allclose(x0, ortvalue_gpu.numpy()) + + ortvalue_gpu.update_inplace(x1) + np.testing.assert_allclose(x1, ortvalue_gpu.numpy()) + + def testRunModelWithCudaGraph(self): + if "CUDAExecutionProvider" in onnxrt.get_available_providers(): + providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] + INPUT_SIZE = 1280 + x = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] * INPUT_SIZE, dtype=np.float32) + y = np.array([[0.0], [0.0], [0.0]] * INPUT_SIZE, dtype=np.float32) + x_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0) + y_ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0) + + session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers) + io_binding = session.io_binding() + + # Bind the input and output + io_binding.bind_ortvalue_input("X", x_ortvalue) + io_binding.bind_ortvalue_output("Y", y_ortvalue) + + # One regular run for the necessary memory allocation and cuda graph capturing + session.run_with_iobinding(io_binding) + expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32) + np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) + + # After capturing, CUDA graph replay happens from this Run onwards + session.run_with_iobinding(io_binding) + np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) + + # Update input and then replay CUDA graph + x_ortvalue.update_inplace( + np.array( + [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0]] * INPUT_SIZE, + dtype=np.float32, + ) + ) + session.run_with_iobinding(io_binding) + np.testing.assert_allclose( + np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), + y_ortvalue.numpy(), + rtol=1e-05, + atol=1e-05, + ) + + +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py index 03a9399ac1791..489e05626608c 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_iobinding.py +++ b/onnxruntime/test/python/onnxruntime_test_python_iobinding.py @@ -1,25 +1,33 @@ +import unittest + import numpy as np +from helper import get_name from numpy.testing import assert_almost_equal -from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE -from onnx.defs import onnx_opset_version from onnx import helper +from onnx.defs import onnx_opset_version +from onnx.mapping import NP_TYPE_TO_TENSOR_TYPE + import onnxruntime as onnxrt -from onnxruntime.capi._pybind_state import ( # pylint: disable=E0611 - OrtDevice as C_OrtDevice, OrtValue as C_OrtValue, SessionIOBinding) -import unittest +from onnxruntime.capi._pybind_state import OrtDevice as C_OrtDevice # pylint: disable=E0611 +from onnxruntime.capi._pybind_state import OrtValue as C_OrtValue +from onnxruntime.capi._pybind_state import SessionIOBinding -from helper import get_name class TestIOBinding(unittest.TestCase): - def create_ortvalue_input_on_gpu(self): - return onnxrt.OrtValue.ortvalue_from_numpy(np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), 'cuda', 0) + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), "cuda", 0 + ) def create_ortvalue_alternate_input_on_gpu(self): - return onnxrt.OrtValue.ortvalue_from_numpy(np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), 'cuda', 0) + return onnxrt.OrtValue.ortvalue_from_numpy( + np.array([[2.0, 4.0], [6.0, 8.0], [10.0, 12.0]], dtype=np.float32), + "cuda", + 0, + ) def create_uninitialized_ortvalue_input_on_gpu(self): - return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, 'cuda', 0) + return onnxrt.OrtValue.ortvalue_from_shape_and_type([3, 2], np.float32, "cuda", 0) def create_numpy_input(self): return np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) @@ -35,16 +43,16 @@ def test_bind_input_to_cpu_arr(self): session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() - + # Bind Numpy object (input) that's on CPU to wherever the model needs it - io_binding.bind_cpu_input('X', self.create_numpy_input()) - + io_binding.bind_cpu_input("X", self.create_numpy_input()) + # Bind output to CPU - io_binding.bind_output('Y') - + io_binding.bind_output("Y") + # Invoke Run session.run_with_iobinding(io_binding) - + # Sync if different CUDA streams io_binding.synchronize_outputs() @@ -57,46 +65,70 @@ def test_bind_input_to_cpu_arr(self): def test_bind_input_types(self): opset = onnx_opset_version() - devices = [(C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), ['CPUExecutionProvider'])] + devices = [ + ( + C_OrtDevice(C_OrtDevice.cpu(), C_OrtDevice.default_memory(), 0), + ["CPUExecutionProvider"], + ) + ] if "CUDAExecutionProvider" in onnxrt.get_all_providers(): - devices.append((C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0), ['CUDAExecutionProvider'])) - + devices.append( + ( + C_OrtDevice(C_OrtDevice.cuda(), C_OrtDevice.default_memory(), 0), + ["CUDAExecutionProvider"], + ) + ) + for device, provider in devices: - for dtype in [np.float32, np.float64, np.int32, np.uint32, - np.int64, np.uint64, np.int16, np.uint16, - np.int8, np.uint8, np.float16, np.bool_]: + for dtype in [ + np.float32, + np.float64, + np.int32, + np.uint32, + np.int64, + np.uint64, + np.int16, + np.uint16, + np.int8, + np.uint8, + np.float16, + np.bool_, + ]: with self.subTest(dtype=dtype, device=str(device)): x = np.arange(8).reshape((-1, 2)).astype(dtype) proto_dtype = NP_TYPE_TO_TENSOR_TYPE[x.dtype] - X = helper.make_tensor_value_info('X', proto_dtype, [None, x.shape[1]]) - Y = helper.make_tensor_value_info('Y', proto_dtype, [None, x.shape[1]]) + X = helper.make_tensor_value_info("X", proto_dtype, [None, x.shape[1]]) + Y = helper.make_tensor_value_info("Y", proto_dtype, [None, x.shape[1]]) # inference - node_add = helper.make_node('Identity', ['X'], ['Y']) + node_add = helper.make_node("Identity", ["X"], ["Y"]) # graph - graph_def = helper.make_graph([node_add], 'lr', [X], [Y], []) + graph_def = helper.make_graph([node_add], "lr", [X], [Y], []) model_def = helper.make_model( - graph_def, producer_name='dummy', ir_version=7, + graph_def, + producer_name="dummy", + ir_version=7, producer_version="0", - opset_imports=[helper.make_operatorsetid('', opset)]) + opset_imports=[helper.make_operatorsetid("", opset)], + ) sess = onnxrt.InferenceSession(model_def.SerializeToString(), providers=provider) bind = SessionIOBinding(sess._sess) ort_value = C_OrtValue.ortvalue_from_numpy(x, device) - bind.bind_ortvalue_input('X', ort_value) - bind.bind_output('Y', device) + bind.bind_ortvalue_input("X", ort_value) + bind.bind_output("Y", device) sess._sess.run_with_iobinding(bind, None) ortvalue = bind.get_outputs()[0] y = ortvalue.numpy() assert_almost_equal(x, y) bind = SessionIOBinding(sess._sess) - bind.bind_input('X', device, dtype, x.shape, ort_value.data_ptr()) - bind.bind_output('Y', device) + bind.bind_input("X", device, dtype, x.shape, ort_value.data_ptr()) + bind.bind_output("Y", device) sess._sess.run_with_iobinding(bind, None) ortvalue = bind.get_outputs()[0] y = ortvalue.numpy() @@ -107,22 +139,22 @@ def test_bind_input_only(self): session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() - + # Bind input to CUDA - io_binding.bind_input('X', 'cuda', 0, np.float32, [3, 2], input.data_ptr()) + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) # Sync if different CUDA streams io_binding.synchronize_inputs() # Bind output to CPU - io_binding.bind_output('Y') - + io_binding.bind_output("Y") + # Invoke Run session.run_with_iobinding(io_binding) # Sync if different CUDA streams io_binding.synchronize_outputs() - + # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) ort_output = io_binding.copy_outputs_to_cpu()[0] @@ -134,13 +166,13 @@ def test_bind_input_and_preallocated_output(self): session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() - + # Bind input to CUDA - io_binding.bind_input('X', 'cuda', 0, np.float32, [3, 2], input.data_ptr()) + io_binding.bind_input("X", "cuda", 0, np.float32, [3, 2], input.data_ptr()) # Bind output to CUDA output = self.create_uninitialized_ortvalue_input_on_gpu() - io_binding.bind_output('Y', 'cuda', 0, np.float32, [3, 2], output.data_ptr()) + io_binding.bind_output("Y", "cuda", 0, np.float32, [3, 2], output.data_ptr()) # Sync if different CUDA streams io_binding.synchronize_inputs() @@ -150,28 +182,34 @@ def test_bind_input_and_preallocated_output(self): # Sync if different CUDA streams io_binding.synchronize_outputs() - + # Get outputs over to CPU (the outputs which were bound to CUDA will get copied over to the host here) ort_output_vals = io_binding.copy_outputs_to_cpu()[0] # Validate results self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals)) - + # Validate if ORT actually wrote to pre-allocated buffer by copying the Torch allocated buffer # to the host and validating its contents ort_output_vals_in_cpu = output.numpy() # Validate results self.assertTrue(np.array_equal(self.create_expected_output(), ort_output_vals_in_cpu)) - def test_bind_input_and_non_preallocated_output(self): session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() - + # Bind input to CUDA - io_binding.bind_input('X', 'cuda', 0, np.float32, [3, 2], self.create_ortvalue_input_on_gpu().data_ptr()) + io_binding.bind_input( + "X", + "cuda", + 0, + np.float32, + [3, 2], + self.create_ortvalue_input_on_gpu().data_ptr(), + ) # Bind output to CUDA - io_binding.bind_output('Y', 'cuda') + io_binding.bind_output("Y", "cuda") # Sync if different CUDA streams io_binding.synchronize_inputs() @@ -188,7 +226,7 @@ def test_bind_input_and_non_preallocated_output(self): self.assertEqual(ort_outputs[0].device_name(), "cuda") # Validate results (by copying results to CPU by creating a Numpy object) self.assertTrue(np.array_equal(self.create_expected_output(), ort_outputs[0].numpy())) - + # We should be able to repeat the above process as many times as we want - try once more ort_outputs = io_binding.get_outputs() self.assertEqual(len(ort_outputs), 1) @@ -198,7 +236,14 @@ def test_bind_input_and_non_preallocated_output(self): # Change the bound input and validate the results in the same bound OrtValue # Bind alternate input to CUDA - io_binding.bind_input('X', 'cuda', 0, np.float32, [3, 2], self.create_ortvalue_alternate_input_on_gpu().data_ptr()) + io_binding.bind_input( + "X", + "cuda", + 0, + np.float32, + [3, 2], + self.create_ortvalue_alternate_input_on_gpu().data_ptr(), + ) # Sync if different CUDA streams io_binding.synchronize_inputs() @@ -219,14 +264,14 @@ def test_bind_input_and_non_preallocated_output(self): def test_bind_input_and_bind_output_with_ortvalues(self): session = onnxrt.InferenceSession(get_name("mul_1.onnx"), providers=onnxrt.get_available_providers()) io_binding = session.io_binding() - + # Bind ortvalue as input input_ortvalue = self.create_ortvalue_input_on_gpu() - io_binding.bind_ortvalue_input('X', input_ortvalue) + io_binding.bind_ortvalue_input("X", input_ortvalue) # Bind ortvalue as output output_ortvalue = self.create_uninitialized_ortvalue_input_on_gpu() - io_binding.bind_ortvalue_output('Y', output_ortvalue) + io_binding.bind_ortvalue_output("Y", output_ortvalue) # Sync if different CUDA streams io_binding.synchronize_inputs() @@ -242,7 +287,7 @@ def test_bind_input_and_bind_output_with_ortvalues(self): # Bind another ortvalue as input input_ortvalue_2 = self.create_ortvalue_alternate_input_on_gpu() - io_binding.bind_ortvalue_input('X', input_ortvalue_2) + io_binding.bind_ortvalue_input("X", input_ortvalue_2) # Sync if different CUDA streams io_binding.synchronize_inputs() @@ -257,5 +302,5 @@ def test_bind_input_and_bind_output_with_ortvalues(self): self.assertTrue(np.array_equal(self.create_expected_output_alternate(), output_ortvalue.numpy())) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_keras.py b/onnxruntime/test/python/onnxruntime_test_python_keras.py index 02e7cdb8e7d71..fb94f67757844 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_keras.py +++ b/onnxruntime/test/python/onnxruntime_test_python_keras.py @@ -4,16 +4,17 @@ # -*- coding: UTF-8 -*- # Taken from https://github.com/onnx/onnxmltools/blob/master/tests/end2end/test_custom_op.py. import unittest + import numpy as np import onnxmltools -import onnxruntime as onnxrt -from keras import backend as K from keras import Sequential -from keras.layers import Layer, Conv2D, MaxPooling2D +from keras import backend as K +from keras.layers import Conv2D, Layer, MaxPooling2D +import onnxruntime as onnxrt -class ScaledTanh(Layer): +class ScaledTanh(Layer): def __init__(self, alpha=1.0, beta=1.0, **kwargs): super(ScaledTanh, self).__init__(**kwargs) self.alpha = alpha @@ -31,16 +32,17 @@ def compute_output_shape(self, input_shape): def custom_activation(scope, operator, container): # type:(ScopeBase, OperatorBase, ModelContainer) -> None - container.add_node('ScaledTanh', - operator.input_full_names, - operator.output_full_names, - op_version=1, - alpha=operator.original_operator.alpha, - beta=operator.original_operator.beta) + container.add_node( + "ScaledTanh", + operator.input_full_names, + operator.output_full_names, + op_version=1, + alpha=operator.original_operator.alpha, + beta=operator.original_operator.beta, + ) class TestInferenceSessionKeras(unittest.TestCase): - def testRunModelConv(self): # keras model @@ -49,16 +51,19 @@ def testRunModelConv(self): model = Sequential() model.add( - Conv2D(2, - kernel_size=(1, 2), - strides=(1, 1), - padding='valid', - input_shape=(H, W, C), - data_format='channels_last')) + Conv2D( + 2, + kernel_size=(1, 2), + strides=(1, 1), + padding="valid", + input_shape=(H, W, C), + data_format="channels_last", + ) + ) model.add(ScaledTanh(0.9, 2.0)) - model.add(MaxPooling2D((2, 2), strides=(2, 2), data_format='channels_last')) + model.add(MaxPooling2D((2, 2), strides=(2, 2), data_format="channels_last")) - model.compile(optimizer='sgd', loss='mse') + model.compile(optimizer="sgd", loss="mse") actual = model.predict(x) self.assertIsNotNone(actual) @@ -75,5 +80,5 @@ def testRunModelConv(self): np.testing.assert_allclose(actual, actual_rt[0], rtol=1e-05, atol=1e-08) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_python_mlops.py b/onnxruntime/test/python/onnxruntime_test_python_mlops.py index 740f240ed01c3..b6604a6d51e8a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_mlops.py +++ b/onnxruntime/test/python/onnxruntime_test_python_mlops.py @@ -1,65 +1,68 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import os + # -*- coding: UTF-8 -*- import unittest -import os + import numpy as np -import onnxruntime as onnxrt from helper import get_name +import onnxruntime as onnxrt -class TestInferenceSession(unittest.TestCase): +class TestInferenceSession(unittest.TestCase): def testZipMapStringFloat(self): - sess = onnxrt.InferenceSession(get_name("zipmap_stringfloat.onnx"), - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("zipmap_stringfloat.onnx"), + providers=onnxrt.get_available_providers(), + ) x = np.array([1.0, 0.0, 3.0, 44.0, 23.0, 11.0], dtype=np.float32).reshape((2, 3)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "X") x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(float)') + self.assertEqual(x_type, "tensor(float)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "Z") output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'seq(map(string,tensor(float)))') - - output_expected = [{ - 'class2': 0.0, - 'class1': 1.0, - 'class3': 3.0 - }, { - 'class2': 23.0, - 'class1': 44.0, - 'class3': 11.0 - }] + self.assertEqual(output_type, "seq(map(string,tensor(float)))") + + output_expected = [ + {"class2": 0.0, "class1": 1.0, "class3": 3.0}, + {"class2": 23.0, "class1": 44.0, "class3": 11.0}, + ] res = sess.run([output_name], {x_name: x}) self.assertEqual(output_expected, res[0]) def testZipMapInt64Float(self): - sess = onnxrt.InferenceSession(get_name("zipmap_int64float.onnx"), - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("zipmap_int64float.onnx"), + providers=onnxrt.get_available_providers(), + ) x = np.array([1.0, 0.0, 3.0, 44.0, 23.0, 11.0], dtype=np.float32).reshape((2, 3)) x_name = sess.get_inputs()[0].name self.assertEqual(x_name, "X") x_type = sess.get_inputs()[0].type - self.assertEqual(x_type, 'tensor(float)') + self.assertEqual(x_type, "tensor(float)") output_name = sess.get_outputs()[0].name self.assertEqual(output_name, "Z") output_type = sess.get_outputs()[0].type - self.assertEqual(output_type, 'seq(map(int64,tensor(float)))') + self.assertEqual(output_type, "seq(map(int64,tensor(float)))") output_expected = [{10: 1.0, 20: 0.0, 30: 3.0}, {10: 44.0, 20: 23.0, 30: 11.0}] res = sess.run([output_name], {x_name: x}) self.assertEqual(output_expected, res[0]) def testDictVectorizer(self): - sess = onnxrt.InferenceSession(get_name("pipeline_vectorize.onnx"), - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("pipeline_vectorize.onnx"), + providers=onnxrt.get_available_providers(), + ) input_name = sess.get_inputs()[0].name self.assertEqual(input_name, "float_input") input_type = str(sess.get_inputs()[0].type) @@ -84,7 +87,10 @@ def testDictVectorizer(self): try: res = sess.run([output_name], {input_name: xwrong}) except RuntimeError as e: - self.assertIn("Unexpected key type , it cannot be linked to C type int64_t", str(e)) + self.assertIn( + "Unexpected key type , it cannot be linked to C type int64_t", + str(e), + ) # numpy type x = {np.int64(k): np.float32(v) for k, v in x.items()} @@ -103,8 +109,7 @@ def testDictVectorizer(self): np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) def testLabelEncoder(self): - sess = onnxrt.InferenceSession(get_name("LabelEncoder.onnx"), - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession(get_name("LabelEncoder.onnx"), providers=onnxrt.get_available_providers()) input_name = sess.get_inputs()[0].name self.assertEqual(input_name, "input") input_type = str(sess.get_inputs()[0].type) @@ -119,18 +124,18 @@ def testLabelEncoder(self): self.assertEqual(output_shape, [1, 1]) # Array - x = np.array([['4']]) + x = np.array([["4"]]) res = sess.run([output_name], {input_name: x}) output_expected = np.array([[3]], dtype=np.int64) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) # Python type - x = np.array(['4'], ndmin=2) + x = np.array(["4"], ndmin=2) res = sess.run([output_name], {input_name: x}) output_expected = np.array([3], ndmin=2, dtype=np.int64) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) - x = np.array(['4'], ndmin=2, dtype=object) + x = np.array(["4"], ndmin=2, dtype=object) res = sess.run([output_name], {input_name: x}) output_expected = np.array([3], ndmin=2, dtype=np.int64) np.testing.assert_allclose(output_expected, res[0], rtol=1e-05, atol=1e-08) @@ -142,23 +147,28 @@ def test_run_model_mlnet(self): # where one node is assigned to CUDA and one node to DML, as it doesn't have the data transfer capabilities to # deal with potentially different device memory. Hence, use a session with only DML and CPU (excluding CUDA) # for this test as it breaks with both CUDA and DML registered. - if ('CUDAExecutionProvider' in available_providers and 'DmlExecutionProvider' in available_providers): - sess = onnxrt.InferenceSession(get_name("mlnet_encoder.onnx"), None, - ['DmlExecutionProvider', 'CPUExecutionProvider']) + if "CUDAExecutionProvider" in available_providers and "DmlExecutionProvider" in available_providers: + sess = onnxrt.InferenceSession( + get_name("mlnet_encoder.onnx"), + None, + ["DmlExecutionProvider", "CPUExecutionProvider"], + ) else: sess = onnxrt.InferenceSession(get_name("mlnet_encoder.onnx"), providers=available_providers) names = [_.name for _ in sess.get_outputs()] - self.assertEqual(['C00', 'C12'], names) - c0 = np.array([5.], dtype=np.float32).reshape(1, 1) + self.assertEqual(["C00", "C12"], names) + c0 = np.array([5.0], dtype=np.float32).reshape(1, 1) - c1 = np.array([b'A\0A\0', b"B\0B\0", b"C\0C\0"], np.void).reshape(1, 3) - res = sess.run(None, {'C0': c0, 'C1': c1}) + c1 = np.array([b"A\0A\0", b"B\0B\0", b"C\0C\0"], np.void).reshape(1, 3) + res = sess.run(None, {"C0": c0, "C1": c1}) mat = res[1] total = mat.sum() self.assertEqual(total, 2) - self.assertEqual(list(mat.ravel()), - list(np.array([[[0., 0., 0., 0.], [1., 0., 0., 0.], [0., 0., 1., 0.]]]).ravel())) + self.assertEqual( + list(mat.ravel()), + list(np.array([[[0.0, 0.0, 0.0, 0.0], [1.0, 0.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0]]]).ravel()), + ) # In memory, the size of each element is fixed and equal to the # longest element. We cannot use bytes because numpy is trimming @@ -166,8 +176,8 @@ def test_run_model_mlnet(self): # (to save space). It does not have this behaviour for void # but as a result, numpy does not know anymore the size # of each element, they all have the same size. - c1 = np.array([b'A\0A\0\0', b"B\0B\0\0", b"C\0C\0\0"], np.void).reshape(1, 3) - res = sess.run(None, {'C0': c0, 'C1': c1}) + c1 = np.array([b"A\0A\0\0", b"B\0B\0\0", b"C\0C\0\0"], np.void).reshape(1, 3) + res = sess.run(None, {"C0": c0, "C1": c1}) mat = res[1] total = mat.sum() self.assertEqual(total, 0) @@ -179,24 +189,34 @@ def test_run_model_tree_ensemble_aionnxml_3(self): # first threshold of the tree is 1.7999999523162842 # all number 1.79* are the same once converting to float32. # predictions must be the same with float32 and different with float64. - iris = np.array([[0, 1, 1.7999999523162842, 3], - [0, 1, 1.7999999523, 3], - [0, 1, 1.79999995232, 3]], dtype=np.float64) + iris = np.array( + [ + [0, 1, 1.7999999523162842, 3], + [0, 1, 1.7999999523, 3], + [0, 1, 1.79999995232, 3], + ], + dtype=np.float64, + ) sess = onnxrt.InferenceSession(model, providers=available_providers) - got = sess.run(None, {'X': iris}) + got = sess.run(None, {"X": iris}) self.assertEqual(got[0].dtype, np.float64) self.assertEqual(got[0].shape, (3, 1)) res64 = got[0].tolist() self.assertEqual(res64, [[0.7284910678863525], [0.7284910678863525], [0.9134243130683899]]) - iris = np.array([[0, 1, 1.7999999523162842, 3], - [0, 1, 1.7999999523, 3], - [0, 1, 1.79999995232, 3]], dtype=np.float32) - got = sess.run(None, {'X': iris.astype(np.float64)}) + iris = np.array( + [ + [0, 1, 1.7999999523162842, 3], + [0, 1, 1.7999999523, 3], + [0, 1, 1.79999995232, 3], + ], + dtype=np.float32, + ) + got = sess.run(None, {"X": iris.astype(np.float64)}) self.assertEqual(got[0].dtype, np.float64) self.assertEqual(got[0].shape, (3, 1)) res32 = got[0].tolist() self.assertEqual(res32, [[0.7284910678863525], [0.7284910678863525], [0.7284910678863525]]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py index 463ddb2d3e4f2..0db5fe9a5b498 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_nuphar.py +++ b/onnxruntime/test/python/onnxruntime_test_python_nuphar.py @@ -1,65 +1,70 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -*- coding: UTF-8 -*- -import numpy as np -import onnx -from onnx import helper, numpy_helper -import onnxruntime as onnxrt -from helper import get_name import os -from onnxruntime.nuphar.rnn_benchmark import perf_test, generate_model -from onnxruntime.nuphar.model_tools import validate_with_ort, run_shape_inference import shutil -import sys import subprocess +import sys import tarfile import unittest import urllib.request +# -*- coding: UTF-8 -*- +import numpy as np +import onnx +from helper import get_name +from onnx import helper, numpy_helper + +import onnxruntime as onnxrt +from onnxruntime.nuphar.model_tools import run_shape_inference, validate_with_ort +from onnxruntime.nuphar.rnn_benchmark import generate_model, perf_test + + def reference_gemm(a, b, c, alpha, beta, transA, transB): a = a if transA == 0 else a.T b = b if transB == 0 else b.T y = alpha * np.dot(a, b) + beta * c return y + def set_gemm_node_attrs(attrs, config): - if config['alpha'] != 1.0: - attrs['alpha'] = config['alpha'] - if config['beta'] != 1.0: - attrs['beta'] = config['beta'] - if config['transA']: - attrs['transA'] = 1 - if config['transB']: - attrs['transB'] = 1 + if config["alpha"] != 1.0: + attrs["alpha"] = config["alpha"] + if config["beta"] != 1.0: + attrs["beta"] = config["beta"] + if config["transA"]: + attrs["transA"] = 1 + if config["transB"]: + attrs["transB"] = 1 + def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={}, extend=False): - M = config['M'] - K = config['K'] - N = config['N'] + M = config["M"] + K = config["K"] + N = config["N"] - shape_a = [K, M] if config['transA'] else [M, K] - shape_b = [N, K] if config['transB'] else [K, N] + shape_a = [K, M] if config["transA"] else [M, K] + shape_b = [N, K] if config["transB"] else [K, N] shape_c = [M, N] # when A/B/C are graph input of the main graph which contains # a Scan node, then they need an extra 'seq' dimension - input_shape_a = ['seq'] + shape_a if extend else shape_a - input_shape_b = ['seq'] + shape_b if extend else shape_b - input_shape_c = ['seq'] + shape_c if extend else shape_c + input_shape_a = ["seq"] + shape_a if extend else shape_a + input_shape_b = ["seq"] + shape_b if extend else shape_b + input_shape_c = ["seq"] + shape_c if extend else shape_c np.random.seed(12345) a = np.random.ranf(shape_a).astype(np.float32) b = np.random.ranf(shape_b).astype(np.float32) - c = np.random.ranf(shape_c).astype(np.float32) if config['withC'] else np.array(0) + c = np.random.ranf(shape_c).astype(np.float32) if config["withC"] else np.array(0) - init_a = a if config['initA'] else None - init_b = b if config['initB'] else None - init_c = c if config['initC'] else None + init_a = a if config["initA"] else None + init_b = b if config["initB"] else None + init_c = c if config["initC"] else None - A = config['A'] - B = config['B'] - C = config['C'] + A = config["A"] + B = config["B"] + C = config["C"] # A is an initializer if A in added_inputs_initializers: @@ -69,9 +74,7 @@ def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={ if init_a is not None: graph.initializer.add().CopyFrom(numpy_helper.from_array(init_a, A)) else: - graph.input.add().CopyFrom(helper.make_tensor_value_info(A, - onnx.TensorProto.FLOAT, - input_shape_a)) + graph.input.add().CopyFrom(helper.make_tensor_value_info(A, onnx.TensorProto.FLOAT, input_shape_a)) # B is an initializer if B in added_inputs_initializers: @@ -81,11 +84,9 @@ def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={ if init_b is not None: graph.initializer.add().CopyFrom(numpy_helper.from_array(init_b, B)) else: - graph.input.add().CopyFrom(helper.make_tensor_value_info(B, - onnx.TensorProto.FLOAT, - input_shape_b)) + graph.input.add().CopyFrom(helper.make_tensor_value_info(B, onnx.TensorProto.FLOAT, input_shape_b)) - if config['withC']: + if config["withC"]: if C in added_inputs_initializers: c = added_inputs_initializers[C] else: @@ -93,93 +94,84 @@ def generate_gemm_inputs_initializers(graph, config, added_inputs_initializers={ if init_c is not None: graph.initializer.add().CopyFrom(numpy_helper.from_array(init_c, C)) else: - graph.input.add().CopyFrom(helper.make_tensor_value_info(C, - onnx.TensorProto.FLOAT, - input_shape_c)) + graph.input.add().CopyFrom(helper.make_tensor_value_info(C, onnx.TensorProto.FLOAT, input_shape_c)) return (a, b, c) + def generate_gemm_model(model_name, config): model = onnx.ModelProto() - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version opset = model.opset_import.add() opset.version = 11 added_inputs_initializers = {} (a, b, c) = generate_gemm_inputs_initializers(model.graph, config, added_inputs_initializers) - node_inputs = [config['A'], config['B']] - if config['withC']: - node_inputs.append(config['C']) + node_inputs = [config["A"], config["B"]] + if config["withC"]: + node_inputs.append(config["C"]) attrs = {} set_gemm_node_attrs(attrs, config) - node = helper.make_node('Gemm', node_inputs, [config['Y']], config['node_name'], **attrs) + node = helper.make_node("Gemm", node_inputs, [config["Y"]], config["node_name"], **attrs) model.graph.node.add().CopyFrom(node) - shape_output = [config['M'], config['N']] - model.graph.output.add().CopyFrom(helper.make_tensor_value_info(config['Y'], - onnx.TensorProto.FLOAT, - shape_output)) + shape_output = [config["M"], config["N"]] + model.graph.output.add().CopyFrom(helper.make_tensor_value_info(config["Y"], onnx.TensorProto.FLOAT, shape_output)) # compute reference output - y = reference_gemm(a, b, c, config['alpha'], config['beta'], config['transA'], config['transB']) + y = reference_gemm(a, b, c, config["alpha"], config["beta"], config["transA"], config["transB"]) onnx.save(model, model_name) return (a, b, c, y) + def generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config, added_inputs): - M = config['M'] - K = config['K'] - N = config['N'] + M = config["M"] + K = config["K"] + N = config["N"] - shape_a = [K, M] if config['transA'] else [M, K] - shape_b = [N, K] if config['transB'] else [K, N] + shape_a = [K, M] if config["transA"] else [M, K] + shape_b = [N, K] if config["transB"] else [K, N] - A = config['A'] - B = config['B'] + A = config["A"] + B = config["B"] gemm_node_inputs = [] # A comes from the outer graph if it's an initializer - if config['initA']: + if config["initA"]: gemm_node_inputs.append(A) else: gemm_node_inputs.append(A + postfix) if A not in added_inputs: added_inputs[A] = 1 scan_node_inputs.append(A) - scan_body.input.add().CopyFrom(helper.make_tensor_value_info(A + postfix, - onnx.TensorProto.FLOAT, - shape_a)) + scan_body.input.add().CopyFrom(helper.make_tensor_value_info(A + postfix, onnx.TensorProto.FLOAT, shape_a)) # B comes from the outer graph if it's an initializer - if config['initB']: + if config["initB"]: gemm_node_inputs.append(B) else: gemm_node_inputs.append(B + postfix) if B not in added_inputs: added_inputs[B] = 1 scan_node_inputs.append(B) - scan_body.input.add().CopyFrom(helper.make_tensor_value_info(B + postfix, - onnx.TensorProto.FLOAT, - shape_b)) + scan_body.input.add().CopyFrom(helper.make_tensor_value_info(B + postfix, onnx.TensorProto.FLOAT, shape_b)) # C comes from Scan state - if config['withC']: - gemm_node_inputs.append('in_' + config['C'] + postfix) + if config["withC"]: + gemm_node_inputs.append("in_" + config["C"] + postfix) attrs = {} set_gemm_node_attrs(attrs, config) - node = helper.make_node('Gemm', - gemm_node_inputs, - [config['Y'] + postfix], - config['node_name'], - **attrs) + node = helper.make_node("Gemm", gemm_node_inputs, [config["Y"] + postfix], config["node_name"], **attrs) scan_body.node.add().CopyFrom(node) + def generate_gemm_scan_model(model_name, config1, config2): model = onnx.ModelProto() - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version opset = model.opset_import.add() opset.version = 11 @@ -206,179 +198,201 @@ def generate_gemm_scan_model(model_name, config1, config2): # config1 and config2 configure alpha/beta/transA/transB for Gemm_1 and Gemm_2, respectively. scan_body = onnx.GraphProto() - scan_body.name = 'gemm_subgraph' + scan_body.name = "gemm_subgraph" - shape_c1 = [config1['M'], config1['N']] - shape_c2 = [config2['M'], config2['N']] + shape_c1 = [config1["M"], config1["N"]] + shape_c2 = [config2["M"], config2["N"]] assert shape_c1 == shape_c2 - C1 = config1['C'] - C2 = config2['C'] + C1 = config1["C"] + C2 = config2["C"] scan_node_inputs = [] - postfix = '_subgraph' + postfix = "_subgraph" states_cnt = 0 # make sure we create state inputs first - if config1['withC']: - assert config1['initC'] + if config1["withC"]: + assert config1["initC"] states_cnt = states_cnt + 1 scan_node_inputs.append(C1) - scan_body.input.add().CopyFrom(helper.make_tensor_value_info('in_' + C1 + postfix, - onnx.TensorProto.FLOAT, - shape_c1)) - if config2['withC'] and C1 != C2: - assert config2['initC'] + scan_body.input.add().CopyFrom( + helper.make_tensor_value_info("in_" + C1 + postfix, onnx.TensorProto.FLOAT, shape_c1) + ) + if config2["withC"] and C1 != C2: + assert config2["initC"] states_cnt = states_cnt + 1 scan_node_inputs.append(C2) - scan_body.input.add().CopyFrom(helper.make_tensor_value_info('in_' + C2 + postfix, - onnx.TensorProto.FLOAT, - shape_c2)) + scan_body.input.add().CopyFrom( + helper.make_tensor_value_info("in_" + C2 + postfix, onnx.TensorProto.FLOAT, shape_c2) + ) added_inputs_subgraph = {} - generate_gemm_node_subgraph(scan_body, - scan_node_inputs, - postfix, - config1, - added_inputs_subgraph) - generate_gemm_node_subgraph(scan_body, - scan_node_inputs, - postfix, - config2, - added_inputs_subgraph) - - sub_output = 'sub_output' + postfix + generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config1, added_inputs_subgraph) + generate_gemm_node_subgraph(scan_body, scan_node_inputs, postfix, config2, added_inputs_subgraph) + + sub_output = "sub_output" + postfix # create a Sub op instead of Add to break the MatMul-to-Gemm rewriting rule # performed by the ort optimizer - sub_node = helper.make_node('Sub', - [config1['Y'] + postfix, config2['Y'] + postfix], - [sub_output], - 'sub_node') + sub_node = helper.make_node( + "Sub", + [config1["Y"] + postfix, config2["Y"] + postfix], + [sub_output], + "sub_node", + ) scan_body.node.add().CopyFrom(sub_node) scan_node_outputs = [] # create state outputs - if config1['withC']: - id_node1 = onnx.helper.make_node('Identity', - [sub_output], - ['out_' + C1 + postfix], - 'id_node1') + if config1["withC"]: + id_node1 = onnx.helper.make_node("Identity", [sub_output], ["out_" + C1 + postfix], "id_node1") scan_body.node.add().CopyFrom(id_node1) - scan_body.output.add().CopyFrom(helper.make_tensor_value_info('out_' + C1 + postfix, - onnx.TensorProto.FLOAT, - shape_c1)) - scan_node_outputs.append('out_' + C1) - - if config2['withC'] and C1 != C2: - id_node2 = onnx.helper.make_node('Identity', - [sub_output], - ['out_' + C2 + postfix], - 'id_node2') + scan_body.output.add().CopyFrom( + helper.make_tensor_value_info("out_" + C1 + postfix, onnx.TensorProto.FLOAT, shape_c1) + ) + scan_node_outputs.append("out_" + C1) + + if config2["withC"] and C1 != C2: + id_node2 = onnx.helper.make_node("Identity", [sub_output], ["out_" + C2 + postfix], "id_node2") scan_body.node.add().CopyFrom(id_node2) - scan_body.output.add().CopyFrom(helper.make_tensor_value_info('out_' + C2 + postfix, - onnx.TensorProto.FLOAT, - shape_c2)) - scan_node_outputs.append('out_' + C2) + scan_body.output.add().CopyFrom( + helper.make_tensor_value_info("out_" + C2 + postfix, onnx.TensorProto.FLOAT, shape_c2) + ) + scan_node_outputs.append("out_" + C2) # scan subgraph output - scan_body.output.add().CopyFrom(helper.make_tensor_value_info(sub_output, - onnx.TensorProto.FLOAT, - shape_c1)) - scan_node_outputs.append('scan_output') + scan_body.output.add().CopyFrom(helper.make_tensor_value_info(sub_output, onnx.TensorProto.FLOAT, shape_c1)) + scan_node_outputs.append("scan_output") # create scan node inputs_cnt = len(scan_node_inputs) - states_cnt assert inputs_cnt > 0 - scan_node = onnx.helper.make_node('Scan', - scan_node_inputs, - scan_node_outputs, - 'scan_node', - num_scan_inputs=inputs_cnt, - body=scan_body) + scan_node = onnx.helper.make_node( + "Scan", + scan_node_inputs, + scan_node_outputs, + "scan_node", + num_scan_inputs=inputs_cnt, + body=scan_body, + ) model.graph.node.add().CopyFrom(scan_node) added_inputs_initializers = {} # main graph inputs and initializers - (a1, b1, c1) = generate_gemm_inputs_initializers(model.graph, - config1, - added_inputs_initializers, - extend=True) - (a2, b2, c2) = generate_gemm_inputs_initializers(model.graph, - config2, - added_inputs_initializers, - extend=True) - - shape_output = ['seq', config1['M'], config1['N']] + (a1, b1, c1) = generate_gemm_inputs_initializers(model.graph, config1, added_inputs_initializers, extend=True) + (a2, b2, c2) = generate_gemm_inputs_initializers(model.graph, config2, added_inputs_initializers, extend=True) + + shape_output = ["seq", config1["M"], config1["N"]] # main graph outputs - model.graph.output.add().CopyFrom(helper.make_tensor_value_info('scan_output', - onnx.TensorProto.FLOAT, - shape_output)) + model.graph.output.add().CopyFrom( + helper.make_tensor_value_info("scan_output", onnx.TensorProto.FLOAT, shape_output) + ) onnx.save(model, model_name) return (a1, b1, c1, a2, b2, c2) + def set_gemm_model_inputs(config, test_inputs, a, b, c): - if not config['initA']: - test_inputs[config['A']] = a - if not config['initB']: - test_inputs[config['B']] = b - if config['withC'] and not config['initC']: - test_inputs[config['C']] = c + if not config["initA"]: + test_inputs[config["A"]] = a + if not config["initB"]: + test_inputs[config["B"]] = b + if config["withC"] and not config["initC"]: + test_inputs[config["C"]] = c def make_providers(nuphar_settings): return [ - ('NupharExecutionProvider', { - 'nuphar_settings': nuphar_settings - }), - 'CPUExecutionProvider', + ("NupharExecutionProvider", {"nuphar_settings": nuphar_settings}), + "CPUExecutionProvider", ] class TestNuphar(unittest.TestCase): - def test_bidaf(self): cwd = os.getcwd() - bidaf_dir_src = '../models/opset9/test_bidaf' + bidaf_dir_src = "../models/opset9/test_bidaf" - bidaf_dir = os.path.join(cwd, 'bidaf') + bidaf_dir = os.path.join(cwd, "bidaf") if not os.path.exists(bidaf_dir): shutil.copytree(bidaf_dir_src, bidaf_dir) - bidaf_dir = os.path.join(cwd, 'bidaf') - bidaf_model = os.path.join(bidaf_dir, 'model.onnx') + bidaf_dir = os.path.join(cwd, "bidaf") + bidaf_model = os.path.join(bidaf_dir, "model.onnx") run_shape_inference(bidaf_model, bidaf_model) - bidaf_scan_model = os.path.join(bidaf_dir, 'bidaf_scan.onnx') - bidaf_opt_scan_model = os.path.join(bidaf_dir, 'bidaf_opt_scan.onnx') - bidaf_int8_scan_only_model = os.path.join(bidaf_dir, 'bidaf_int8_scan_only.onnx') - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', bidaf_model, '--output', - bidaf_scan_model, '--mode', 'to_scan' - ], - check=True, - cwd=cwd) - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', bidaf_scan_model, '--output', - bidaf_opt_scan_model, '--mode', 'opt_inproj' - ], - check=True, - cwd=cwd) - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_quantizer', '--input', bidaf_opt_scan_model, '--output', - bidaf_int8_scan_only_model, '--only_for_scan' - ], - check=True, - cwd=cwd) + bidaf_scan_model = os.path.join(bidaf_dir, "bidaf_scan.onnx") + bidaf_opt_scan_model = os.path.join(bidaf_dir, "bidaf_opt_scan.onnx") + bidaf_int8_scan_only_model = os.path.join(bidaf_dir, "bidaf_int8_scan_only.onnx") + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + bidaf_model, + "--output", + bidaf_scan_model, + "--mode", + "to_scan", + ], + check=True, + cwd=cwd, + ) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + bidaf_scan_model, + "--output", + bidaf_opt_scan_model, + "--mode", + "opt_inproj", + ], + check=True, + cwd=cwd, + ) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_quantizer", + "--input", + bidaf_opt_scan_model, + "--output", + bidaf_int8_scan_only_model, + "--only_for_scan", + ], + check=True, + cwd=cwd, + ) # run onnx_test_runner to verify results # use -M to disable memory pattern - onnx_test_runner = os.path.join(cwd, 'onnx_test_runner') - subprocess.run([onnx_test_runner, '-e', 'nuphar', '-M', '-c', '1', '-j', '1', '-n', 'bidaf', cwd], check=True, cwd=cwd) + onnx_test_runner = os.path.join(cwd, "onnx_test_runner") + subprocess.run( + [ + onnx_test_runner, + "-e", + "nuphar", + "-M", + "-c", + "1", + "-j", + "1", + "-n", + "bidaf", + cwd, + ], + check=True, + cwd=cwd, + ) # test AOT on the quantized model - if os.name not in ['nt', 'posix']: + if os.name not in ["nt", "posix"]: return # don't run the rest of test if AOT is not supported - cache_dir = os.path.join(cwd, 'nuphar_cache') + cache_dir = os.path.join(cwd, "nuphar_cache") if os.path.exists(cache_dir): shutil.rmtree(cache_dir) os.makedirs(cache_dir) @@ -386,34 +400,45 @@ def test_bidaf(self): # prepare feed feed = {} for i in range(4): - tp = onnx.load_tensor(os.path.join(bidaf_dir, 'test_data_set_0', 'input_{}.pb'.format(i))) + tp = onnx.load_tensor(os.path.join(bidaf_dir, "test_data_set_0", "input_{}.pb".format(i))) feed[tp.name] = numpy_helper.to_array(tp) for model in [bidaf_opt_scan_model, bidaf_int8_scan_only_model]: - nuphar_settings = 'nuphar_cache_path:{}'.format(cache_dir) - for isa in ['avx', 'avx2', 'avx512']: + nuphar_settings = "nuphar_cache_path:{}".format(cache_dir) + for isa in ["avx", "avx2", "avx512"]: # JIT cache happens when initializing session sess = onnxrt.InferenceSession( - model, providers=make_providers(nuphar_settings + ', nuphar_codegen_target:' + isa)) + model, + providers=make_providers(nuphar_settings + ", nuphar_codegen_target:" + isa), + ) cache_dir_content = os.listdir(cache_dir) assert len(cache_dir_content) == 1 cache_versioned_dir = os.path.join(cache_dir, cache_dir_content[0]) - so_name = os.path.basename(model) + '.so' - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.create_shared', '--input_dir', cache_versioned_dir, - '--output_name', so_name - ], - check=True) - - nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}'.format( - cache_dir, so_name, 'on') + so_name = os.path.basename(model) + ".so" + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.create_shared", + "--input_dir", + cache_versioned_dir, + "--output_name", + so_name, + ], + check=True, + ) + + nuphar_settings = "nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}".format( + cache_dir, so_name, "on" + ) sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings)) sess.run([], feed) # test avx - nuphar_settings = 'nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}'.format( - cache_dir, so_name, 'on', 'avx') + nuphar_settings = "nuphar_cache_path:{}, nuphar_cache_so_name:{}, nuphar_cache_force_no_jit:{}, nuphar_codegen_target:{}".format( + cache_dir, so_name, "on", "avx" + ) sess = onnxrt.InferenceSession(model, providers=make_providers(nuphar_settings)) sess.run([], feed) @@ -422,59 +447,90 @@ def test_bert_squad(self): # run symbolic shape inference on this model # set int_max to 1,000,000 to simplify symbol computes for things like min(1000000, seq_len) -> seq_len - bert_squad_dir_src = '../models/opset10/BERT_Squad' - bert_squad_dir = os.path.join(cwd, 'BERT_Squad') + bert_squad_dir_src = "../models/opset10/BERT_Squad" + bert_squad_dir = os.path.join(cwd, "BERT_Squad") if not os.path.exists(bert_squad_dir): shutil.copytree(bert_squad_dir_src, bert_squad_dir) - bert_squad_model = os.path.join(bert_squad_dir, 'bertsquad10.onnx') - subprocess.run([ - sys.executable, '-m', 'onnxruntime.tools.symbolic_shape_infer', '--input', bert_squad_model, '--output', - bert_squad_model, '--auto_merge', '--int_max=1000000' - ], - check=True, - cwd=cwd) + bert_squad_model = os.path.join(bert_squad_dir, "bertsquad10.onnx") + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.tools.symbolic_shape_infer", + "--input", + bert_squad_model, + "--output", + bert_squad_model, + "--auto_merge", + "--int_max=1000000", + ], + check=True, + cwd=cwd, + ) # run onnx_test_runner to verify results - onnx_test_runner = os.path.join(cwd, 'onnx_test_runner') - subprocess.run([onnx_test_runner, '-e', 'nuphar', '-n', 'BERT_Squad', cwd], check=True, cwd=cwd) + onnx_test_runner = os.path.join(cwd, "onnx_test_runner") + subprocess.run( + [onnx_test_runner, "-e", "nuphar", "-n", "BERT_Squad", cwd], + check=True, + cwd=cwd, + ) # run onnxruntime_perf_test, note that nuphar currently is not integrated with ORT thread pool, so set -x 1 to avoid thread confliction with OpenMP - onnxruntime_perf_test = os.path.join(cwd, 'onnxruntime_perf_test') - subprocess.run([onnxruntime_perf_test, '-e', 'nuphar', '-x', '1', '-t', '20', bert_squad_model, '1.txt'], - check=True, - cwd=cwd) + onnxruntime_perf_test = os.path.join(cwd, "onnxruntime_perf_test") + subprocess.run( + [ + onnxruntime_perf_test, + "-e", + "nuphar", + "-x", + "1", + "-t", + "20", + bert_squad_model, + "1.txt", + ], + check=True, + cwd=cwd, + ) def test_rnn_benchmark(self): # make sure benchmarking scripts works # note: quantized model requires AVX2, otherwise it might be slow - avg_rnn, avg_scan, avg_int8 = perf_test('lstm', - num_threads=1, - input_dim=128, - hidden_dim=1024, - bidirectional=True, - layers=1, - seq_len=16, - batch_size=1, - min_duration_seconds=1) - avg_rnn, avg_scan, avg_int8 = perf_test('gru', - num_threads=1, - input_dim=128, - hidden_dim=1024, - bidirectional=False, - layers=2, - seq_len=16, - batch_size=3, - min_duration_seconds=1) - avg_rnn, avg_scan, avg_int8 = perf_test('rnn', - num_threads=1, - input_dim=128, - hidden_dim=1024, - bidirectional=False, - layers=3, - seq_len=16, - batch_size=2, - min_duration_seconds=1) + avg_rnn, avg_scan, avg_int8 = perf_test( + "lstm", + num_threads=1, + input_dim=128, + hidden_dim=1024, + bidirectional=True, + layers=1, + seq_len=16, + batch_size=1, + min_duration_seconds=1, + ) + avg_rnn, avg_scan, avg_int8 = perf_test( + "gru", + num_threads=1, + input_dim=128, + hidden_dim=1024, + bidirectional=False, + layers=2, + seq_len=16, + batch_size=3, + min_duration_seconds=1, + ) + avg_rnn, avg_scan, avg_int8 = perf_test( + "rnn", + num_threads=1, + input_dim=128, + hidden_dim=1024, + bidirectional=False, + layers=3, + seq_len=16, + batch_size=2, + min_duration_seconds=1, + ) def test_batch_scan(self): input_dim = 3 @@ -482,18 +538,20 @@ def test_batch_scan(self): bidirectional = False layers = 3 - for onnx_opset_ver in [7,13]: - lstm_model_name = 'test_batch_rnn_lstm.onnx' + for onnx_opset_ver in [7, 13]: + lstm_model_name = "test_batch_rnn_lstm.onnx" # create an LSTM model for generating baseline data - generate_model('lstm', - input_dim, - hidden_dim, - bidirectional, - layers, - lstm_model_name, - batch_one=False, - has_seq_len=True, - onnx_opset_ver=onnx_opset_ver) + generate_model( + "lstm", + input_dim, + hidden_dim, + bidirectional, + layers, + lstm_model_name, + batch_one=False, + has_seq_len=True, + onnx_opset_ver=onnx_opset_ver, + ) seq_len = 8 batch_size = 2 @@ -503,81 +561,129 @@ def test_batch_scan(self): # run lstm as baseline sess = onnxrt.InferenceSession(lstm_model_name, providers=onnxrt.get_available_providers()) - first_lstm_data_output = sess.run([], {'input': data_input[:, 0:1, :], 'seq_len': data_seq_len[0:1]}) + first_lstm_data_output = sess.run([], {"input": data_input[:, 0:1, :], "seq_len": data_seq_len[0:1]}) lstm_data_output = [] lstm_data_output = first_lstm_data_output for b in range(1, batch_size): - lstm_data_output = lstm_data_output + sess.run([], { - 'input': data_input[:, b:(b + 1), :], - 'seq_len': data_seq_len[b:(b + 1)] - }) + lstm_data_output = lstm_data_output + sess.run( + [], + { + "input": data_input[:, b : (b + 1), :], + "seq_len": data_seq_len[b : (b + 1)], + }, + ) lstm_data_output = np.concatenate(lstm_data_output, axis=1) # generate a batch scan model - scan_model_name = 'test_batch_rnn_scan.onnx' - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', '--input', lstm_model_name, '--output', - scan_model_name, '--mode', 'to_scan' - ], - check=True) + scan_model_name = "test_batch_rnn_scan.onnx" + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + lstm_model_name, + "--output", + scan_model_name, + "--mode", + "to_scan", + ], + check=True, + ) # run scan_batch with batch size 1 sess = onnxrt.InferenceSession(scan_model_name, providers=onnxrt.get_available_providers()) - scan_batch_data_output = sess.run([], {'input': data_input[:, 0:1, :], 'seq_len': data_seq_len[0:1]}) + scan_batch_data_output = sess.run([], {"input": data_input[:, 0:1, :], "seq_len": data_seq_len[0:1]}) assert np.allclose(first_lstm_data_output, scan_batch_data_output) # run scan_batch with batch size 2 - scan_batch_data_output = sess.run([], {'input': data_input, 'seq_len': data_seq_len}) + scan_batch_data_output = sess.run([], {"input": data_input, "seq_len": data_seq_len}) assert np.allclose(lstm_data_output, scan_batch_data_output) # run scan_batch with batch size 1 again - scan_batch_data_output = sess.run([], {'input': data_input[:, 0:1, :], 'seq_len': data_seq_len[0:1]}) + scan_batch_data_output = sess.run([], {"input": data_input[:, 0:1, :], "seq_len": data_seq_len[0:1]}) assert np.allclose(first_lstm_data_output, scan_batch_data_output) def test_gemm_to_matmul(self): gemm_model_name_prefix = "gemm_model" matmul_model_name_prefix = "matmul_model" common_config = { - 'node_name':'GemmNode', - 'A':'inputA', 'B':'inputB', 'C':'inputC', 'Y':'output', 'M':2, 'K':3, 'N':4, 'withC':False, - 'initA':False, 'initB':False, 'initC':False, 'alpha':1.0, 'beta':1.0, 'transA':0, 'transB':0 + "node_name": "GemmNode", + "A": "inputA", + "B": "inputB", + "C": "inputC", + "Y": "output", + "M": 2, + "K": 3, + "N": 4, + "withC": False, + "initA": False, + "initB": False, + "initC": False, + "alpha": 1.0, + "beta": 1.0, + "transA": 0, + "transB": 0, } test_configs = [ {}, - {'transA':1}, - {'transB':1}, - {'transA':1, 'transB':1}, - {'withC':True}, - {'withC':True, 'initC':True}, - {'initA':True}, - {'initB':True}, - {'initA':True, 'initB':True}, - {'initA':True, 'transA':1}, - {'initB':True, 'transB':1}, - {'initA':True, 'transA':1, 'initB':True, 'transB':1}, - {'alpha':2.2}, - {'transA':1, 'alpha':2.2}, - {'initA':True, 'transA':1, 'alpha':2.2}, - {'withC':True, 'beta':3.3}, - {'withC':True, 'initC':True, 'beta':3.3}, - {'initA':True, 'transA':1, 'alpha':2.2, 'withC':True, 'initC':True, 'beta':3.3}, - {'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'beta':3.3}, - {'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'initC':True, 'beta':3.3} + {"transA": 1}, + {"transB": 1}, + {"transA": 1, "transB": 1}, + {"withC": True}, + {"withC": True, "initC": True}, + {"initA": True}, + {"initB": True}, + {"initA": True, "initB": True}, + {"initA": True, "transA": 1}, + {"initB": True, "transB": 1}, + {"initA": True, "transA": 1, "initB": True, "transB": 1}, + {"alpha": 2.2}, + {"transA": 1, "alpha": 2.2}, + {"initA": True, "transA": 1, "alpha": 2.2}, + {"withC": True, "beta": 3.3}, + {"withC": True, "initC": True, "beta": 3.3}, + { + "initA": True, + "transA": 1, + "alpha": 2.2, + "withC": True, + "initC": True, + "beta": 3.3, + }, + {"transA": 1, "transB": 1, "alpha": 2.2, "withC": True, "beta": 3.3}, + { + "transA": 1, + "transB": 1, + "alpha": 2.2, + "withC": True, + "initC": True, + "beta": 3.3, + }, ] for i, config in enumerate(test_configs): running_config = common_config.copy() running_config.update(config) - gemm_model_name = gemm_model_name_prefix+str(i)+'.onnx' - matmul_model_name = matmul_model_name_prefix+str(i)+'.onnx' + gemm_model_name = gemm_model_name_prefix + str(i) + ".onnx" + matmul_model_name = matmul_model_name_prefix + str(i) + ".onnx" a, b, c, expected_y = generate_gemm_model(gemm_model_name, running_config) - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', - '--input', gemm_model_name, - '--output', matmul_model_name, '--mode', 'gemm_to_matmul' - ], check=True) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + gemm_model_name, + "--output", + matmul_model_name, + "--mode", + "gemm_to_matmul", + ], + check=True, + ) sess = onnxrt.InferenceSession(matmul_model_name, providers=onnxrt.get_available_providers()) test_inputs = {} @@ -590,48 +696,132 @@ def test_gemm_to_matmul_with_scan(self): matmul_model_name_prefix = "matmul_scan_model" common_config = { - 'M':2, 'K':3, 'N':4, 'withC':False, 'initA':False, 'initB':False, 'initC':False, - 'alpha':1.0, 'beta':1.0, 'transA':0, 'transB':0 + "M": 2, + "K": 3, + "N": 4, + "withC": False, + "initA": False, + "initB": False, + "initC": False, + "alpha": 1.0, + "beta": 1.0, + "transA": 0, + "transB": 0, } common_config1 = common_config.copy() - common_config1.update({ - 'node_name':'GemmNode1', 'A':'input1A', 'B':'input1B', 'C':'input1C', 'Y':'output1' - }) + common_config1.update( + { + "node_name": "GemmNode1", + "A": "input1A", + "B": "input1B", + "C": "input1C", + "Y": "output1", + } + ) common_config2 = common_config.copy() - common_config2.update({ - 'node_name':'GemmNode2', 'A':'input2A', 'B':'input2B', 'C':'input2C', 'Y':'output2' - }) + common_config2.update( + { + "node_name": "GemmNode2", + "A": "input2A", + "B": "input2B", + "C": "input2C", + "Y": "output2", + } + ) test_configs = [ ({}, {}), - ({'transA':1}, {'transB':1}), - ({'transA':1, 'transB':1}, {'transA':1, 'transB':1}), - ({'alpha':2.2}, {'alpha':3.3}), - ({'transA':1, 'transB':1, 'alpha':2.2}, {'transA':1, 'transB':1, 'alpha':3.3}), - ({'withC':True, 'initC':True}, {}), - ({'withC':True, 'initC':True}, {'withC':True, 'initC':True}), - ({'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'initC':True, 'beta':1.2}, - {'transA':1, 'transB':1, 'alpha':3.3, 'withC':True, 'initC':True, 'beta':4.1}), - ({'initA':True}, {}), - ({'initA':True}, {'initB':True}), + ({"transA": 1}, {"transB": 1}), + ({"transA": 1, "transB": 1}, {"transA": 1, "transB": 1}), + ({"alpha": 2.2}, {"alpha": 3.3}), + ( + {"transA": 1, "transB": 1, "alpha": 2.2}, + {"transA": 1, "transB": 1, "alpha": 3.3}, + ), + ({"withC": True, "initC": True}, {}), + ({"withC": True, "initC": True}, {"withC": True, "initC": True}), + ( + { + "transA": 1, + "transB": 1, + "alpha": 2.2, + "withC": True, + "initC": True, + "beta": 1.2, + }, + { + "transA": 1, + "transB": 1, + "alpha": 3.3, + "withC": True, + "initC": True, + "beta": 4.1, + }, + ), + ({"initA": True}, {}), + ({"initA": True}, {"initB": True}), # FIXME: enable the test below after we fix some likely issue in graph partitioner - #({'initA':True, 'initB':True}, {}), - #({'initA':True, 'initB':True, 'transA':1}, {'initA':True, 'transB':1}), - #({'initA':True, 'transA':1, 'transB':1, 'alpha':2.2}, + # ({'initA':True, 'initB':True}, {}), + # ({'initA':True, 'initB':True, 'transA':1}, {'initA':True, 'transB':1}), + # ({'initA':True, 'transA':1, 'transB':1, 'alpha':2.2}, # {'initB':True, 'transA':1, 'transB':1, 'alpha':3.3}), - #({'initA':True, 'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'initC':True}, + # ({'initA':True, 'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'initC':True}, # {'initB':True, 'transA':1, 'transB':1, 'alpha':3.3}), - #({'initA':True, 'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'initC':True, 'beta':1.2}, + # ({'initA':True, 'transA':1, 'transB':1, 'alpha':2.2, 'withC':True, 'initC':True, 'beta':1.2}, # {'initB':True, 'transA':1, 'transB':1, 'alpha':3.3, 'withC':True, 'initC':True, 'beta':4.2}), - ({'A':'inputA', 'initA':True}, {'A':'inputA', 'initA':True}), - ({'B':'inputB', 'initB':True}, {'B':'inputB', 'initB':True}), - ({'C':'inputC', 'withC':True, 'initC':True}, {'C':'inputC', 'withC':True, 'initC':True}), - ({'transA':1, 'alpha':1.2, 'B':'inputB', 'initB':True, 'C':'inputC', 'withC':True, 'initC':True}, - {'transA':1, 'alpha':2.2, 'B':'inputB', 'initB':True, 'C':'inputC', 'withC':True, 'initC':True}), - ({'transB':1, 'alpha':1.2, 'B':'inputB'}, {'transB':1, 'alpha':2.2, 'B':'inputB'}), - ({'transA':1, 'alpha':1.2, 'A':'inputA', 'B':'inputB'}, - {'transA':1, 'alpha':2.2, 'A':'inputA', 'B':'inputB'}), - ({'transA':1, 'alpha':1.2, 'A':'inputA', 'B':'inputB', 'C':'inputC1', 'withC':True, 'initC':True}, - {'transA':1, 'alpha':2.2, 'A':'inputA', 'B':'inputB', 'C':'inputC2', 'withC':True, 'initC':True}), + ({"A": "inputA", "initA": True}, {"A": "inputA", "initA": True}), + ({"B": "inputB", "initB": True}, {"B": "inputB", "initB": True}), + ( + {"C": "inputC", "withC": True, "initC": True}, + {"C": "inputC", "withC": True, "initC": True}, + ), + ( + { + "transA": 1, + "alpha": 1.2, + "B": "inputB", + "initB": True, + "C": "inputC", + "withC": True, + "initC": True, + }, + { + "transA": 1, + "alpha": 2.2, + "B": "inputB", + "initB": True, + "C": "inputC", + "withC": True, + "initC": True, + }, + ), + ( + {"transB": 1, "alpha": 1.2, "B": "inputB"}, + {"transB": 1, "alpha": 2.2, "B": "inputB"}, + ), + ( + {"transA": 1, "alpha": 1.2, "A": "inputA", "B": "inputB"}, + {"transA": 1, "alpha": 2.2, "A": "inputA", "B": "inputB"}, + ), + ( + { + "transA": 1, + "alpha": 1.2, + "A": "inputA", + "B": "inputB", + "C": "inputC1", + "withC": True, + "initC": True, + }, + { + "transA": 1, + "alpha": 2.2, + "A": "inputA", + "B": "inputB", + "C": "inputC2", + "withC": True, + "initC": True, + }, + ), ] for i, config in enumerate(test_configs): @@ -641,16 +831,14 @@ def test_gemm_to_matmul_with_scan(self): running_config2 = common_config2.copy() running_config2.update(config2) - gemm_model_name = gemm_model_name_prefix+str(i)+'.onnx' - matmul_model_name = matmul_model_name_prefix+str(i)+'.onnx' - a1, b1, c1, a2, b2, c2 = generate_gemm_scan_model(gemm_model_name, - running_config1, - running_config2) + gemm_model_name = gemm_model_name_prefix + str(i) + ".onnx" + matmul_model_name = matmul_model_name_prefix + str(i) + ".onnx" + a1, b1, c1, a2, b2, c2 = generate_gemm_scan_model(gemm_model_name, running_config1, running_config2) - a1 = a1.reshape((1, ) + a1.shape) - b1 = b1.reshape((1, ) + b1.shape) - a2 = a2.reshape((1, ) + a2.shape) - b2 = b2.reshape((1, ) + b2.shape) + a1 = a1.reshape((1,) + a1.shape) + b1 = b1.reshape((1,) + b1.shape) + a2 = a2.reshape((1,) + a2.shape) + b2 = b2.reshape((1,) + b2.shape) sess = onnxrt.InferenceSession(gemm_model_name, providers=onnxrt.get_available_providers()) test_inputs = {} @@ -660,11 +848,20 @@ def test_gemm_to_matmul_with_scan(self): # run before model editing expected_y = sess.run([], test_inputs) - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', - '--input', gemm_model_name, - '--output', matmul_model_name, '--mode', 'gemm_to_matmul' - ], check=True) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + gemm_model_name, + "--output", + matmul_model_name, + "--mode", + "gemm_to_matmul", + ], + check=True, + ) # run after model editing sess = onnxrt.InferenceSession(matmul_model_name, providers=onnxrt.get_available_providers()) @@ -676,11 +873,20 @@ def test_gemm_to_matmul_with_scan(self): def test_loop_to_scan(self): loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx") scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx" - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', - '--input', loop_model_filename, - '--output', scan_model_filename, '--mode', 'loop_to_scan' - ], check=True) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + loop_model_filename, + "--output", + scan_model_filename, + "--mode", + "loop_to_scan", + ], + check=True, + ) validate_with_ort(loop_model_filename, scan_model_filename) @@ -690,12 +896,21 @@ def test_loop_to_scan_with_inconvertible_loop(self): # Set --keep_unconvertible_loop_ops option so conversion will not fail due to unconvertible loop ops. loop_model_filename = get_name("nuphar_onnx_test_loop11_inconvertible_loop.onnx") scan_model_filename = "nuphar_onnx_test_loop11_inconvertible_loop_unchanged.onnx" - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_editor', - '--input', loop_model_filename, - '--output', scan_model_filename, '--mode', 'loop_to_scan', - '--keep_unconvertible_loop_ops' - ], check=True) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_editor", + "--input", + loop_model_filename, + "--output", + scan_model_filename, + "--mode", + "loop_to_scan", + "--keep_unconvertible_loop_ops", + ], + check=True, + ) # onnxruntime is failing with: # onnxruntime.capi.onnxruntime_pybind11_state.Fail: [ONNXRuntimeError] : 1 : @@ -707,15 +922,25 @@ def test_loop_to_scan_with_inconvertible_loop(self): def test_loop_to_scan_tool(self): loop_model_filename = get_name("nuphar_tiny_model_with_loop_shape_infered.onnx") scan_model_filename = "nuphar_tiny_model_with_loop_shape_infered_converted_to_scan.onnx" - subprocess.run([ - sys.executable, '-m', 'onnxruntime.nuphar.model_tools', - '--input', loop_model_filename, - '--output', scan_model_filename, - '--tool', 'convert_loop_to_scan_and_validate', - '--symbolic_dims', 'sequence=30' - ], check=True) + subprocess.run( + [ + sys.executable, + "-m", + "onnxruntime.nuphar.model_tools", + "--input", + loop_model_filename, + "--output", + scan_model_filename, + "--tool", + "convert_loop_to_scan_and_validate", + "--symbolic_dims", + "sequence=30", + ], + check=True, + ) validate_with_ort(loop_model_filename, scan_model_filename) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_sparse_matmul.py b/onnxruntime/test/python/onnxruntime_test_python_sparse_matmul.py index 95d288b11862c..858672d4f08df 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_sparse_matmul.py +++ b/onnxruntime/test/python/onnxruntime_test_python_sparse_matmul.py @@ -1,31 +1,36 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import gc +import os +import sys +import threading + # -*- coding: UTF-8 -*- import unittest -import os + import numpy as np -import gc +from helper import get_name import onnxruntime as onnxrt -import threading -import sys -from helper import get_name from onnxruntime.capi.onnxruntime_pybind11_state import Fail + class TestSparseToDenseMatmul(unittest.TestCase): def testRunSparseOutputOnly(self): - ''' + """ Try running models using the new run_with_ort_values sparse_initializer_as_output.onnx - requires no inputs, but only one output that comes from the initializer - ''' + """ # The below values are a part of the model - dense_shape = [3,3] + dense_shape = [3, 3] values = np.array([1.764052391052246, 0.40015721321105957, 0.978738009929657], np.float) indices = np.array([2, 3, 5], np.int64) - sess = onnxrt.InferenceSession(get_name("sparse_initializer_as_output.onnx"), - providers=onnxrt.get_available_providers()) + sess = onnxrt.InferenceSession( + get_name("sparse_initializer_as_output.onnx"), + providers=onnxrt.get_available_providers(), + ) res = sess.run_with_ort_values(["values"], {}) self.assertEqual(len(res), 1) ort_value = res[0] @@ -37,56 +42,365 @@ def testRunSparseOutputOnly(self): self.assertTrue(np.array_equal(indices, sparse_output.as_coo_rep().indices())) def testRunContribSparseMatMul(self): - ''' + """ Mutliple sparse COO tensor to dense - ''' - common_shape = [9,9] # inputs and oputputs same shape - A_values = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, - 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, 16.0, 17.0, - 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, - 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, - 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, - 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 48.0, 49.0, - 50.0, 51.0, 52.0, 53.0], np.float32) + """ + common_shape = [9, 9] # inputs and oputputs same shape + A_values = np.array( + [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 10.0, + 11.0, + 12.0, + 13.0, + 14.0, + 15.0, + 16.0, + 17.0, + 18.0, + 19.0, + 20.0, + 21.0, + 22.0, + 23.0, + 24.0, + 25.0, + 26.0, + 27.0, + 28.0, + 29.0, + 30.0, + 31.0, + 32.0, + 33.0, + 34.0, + 35.0, + 36.0, + 37.0, + 38.0, + 39.0, + 40.0, + 41.0, + 42.0, + 43.0, + 44.0, + 45.0, + 46.0, + 47.0, + 48.0, + 49.0, + 50.0, + 51.0, + 52.0, + 53.0, + ], + np.float32, + ) # 2-D index - A_indices = np.array([0, 1, 0, 2, 0, 6, 0, 7, 0, 8, 1, 0, 1, - 1, 1, 2, 1, 6, 1, 7, 1, 8, 2, 0, 2, 1, - 2, 2, 2, 6, 2, 7, 2, 8, 3, 3, 3, 4, 3, - 5, 3, 6, 3, 7, 3, 8, 4, 3, 4, 4, 4, 5, - 4, 6, 4, 7, 4, 8, 5, 3, 5, 4, 5, 5, 5, - 6, 5, 7, 5, 8, 6, 0, 6, 1, 6, 2, 6, 3, - 6, 4, 6, 5, 7, 0, 7, 1, 7, 2, 7, 3, 7, - 4, 7, 5, 8, 0, 8, 1, 8, 2, 8, 3, 8, 4, - 8, 5], np.int64).reshape((len(A_values), 2)) + A_indices = np.array( + [ + 0, + 1, + 0, + 2, + 0, + 6, + 0, + 7, + 0, + 8, + 1, + 0, + 1, + 1, + 1, + 2, + 1, + 6, + 1, + 7, + 1, + 8, + 2, + 0, + 2, + 1, + 2, + 2, + 2, + 6, + 2, + 7, + 2, + 8, + 3, + 3, + 3, + 4, + 3, + 5, + 3, + 6, + 3, + 7, + 3, + 8, + 4, + 3, + 4, + 4, + 4, + 5, + 4, + 6, + 4, + 7, + 4, + 8, + 5, + 3, + 5, + 4, + 5, + 5, + 5, + 6, + 5, + 7, + 5, + 8, + 6, + 0, + 6, + 1, + 6, + 2, + 6, + 3, + 6, + 4, + 6, + 5, + 7, + 0, + 7, + 1, + 7, + 2, + 7, + 3, + 7, + 4, + 7, + 5, + 8, + 0, + 8, + 1, + 8, + 2, + 8, + 3, + 8, + 4, + 8, + 5, + ], + np.int64, + ).reshape((len(A_values), 2)) - cpu_device = onnxrt.OrtDevice.make('cpu', 0) + cpu_device = onnxrt.OrtDevice.make("cpu", 0) sparse_tensor = onnxrt.SparseTensor.sparse_coo_from_numpy(common_shape, A_values, A_indices, cpu_device) A_ort_value = onnxrt.OrtValue.ort_value_from_sparse_tensor(sparse_tensor) - B_data = np.array([0, 1, 2, 0, 0, 0, 3, 4, 5, - 6, 7, 8, 0, 0, 0, 9, 10, 11, - 12, 13, 14, 0, 0, 0, 15, 16, 17, - 0, 0, 0, 18, 19, 20, 21, 22, 23, - 0, 0, 0, 24, 25, 26, 27, 28, 29, - 0, 0, 0, 30, 31, 32, 33, 34, 35, - 36, 37, 38, 39, 40, 41, 0, 0, 0, - 42, 43, 44, 45, 46, 47, 0, 0, 0, - 48, 49, 50, 51, 52, 53, 0, 0, 0], np.float32).reshape(common_shape) + B_data = np.array( + [ + 0, + 1, + 2, + 0, + 0, + 0, + 3, + 4, + 5, + 6, + 7, + 8, + 0, + 0, + 0, + 9, + 10, + 11, + 12, + 13, + 14, + 0, + 0, + 0, + 15, + 16, + 17, + 0, + 0, + 0, + 18, + 19, + 20, + 21, + 22, + 23, + 0, + 0, + 0, + 24, + 25, + 26, + 27, + 28, + 29, + 0, + 0, + 0, + 30, + 31, + 32, + 33, + 34, + 35, + 36, + 37, + 38, + 39, + 40, + 41, + 0, + 0, + 0, + 42, + 43, + 44, + 45, + 46, + 47, + 0, + 0, + 0, + 48, + 49, + 50, + 51, + 52, + 53, + 0, + 0, + 0, + ], + np.float32, + ).reshape(common_shape) B_ort_value = onnxrt.OrtValue.ortvalue_from_numpy(B_data) - Y_result = np.array([546, 561, 576, 552, 564, 576, 39, 42, 45, - 1410, 1461, 1512, 1362, 1392, 1422, 201, 222, 243, - 2274, 2361, 2448, 2172, 2220, 2268, 363, 402, 441, - 2784, 2850, 2916, 4362, 4485, 4608, 1551, 1608, 1665, - 3540, 3624, 3708, 5604, 5763, 5922, 2037, 2112, 2187, - 4296, 4398, 4500, 6846, 7041, 7236, 2523, 2616, 2709, - 678, 789, 900, 2892, 3012, 3132, 4263, 4494, 4725, - 786, 915, 1044, 3324, 3462, 3600, 4911, 5178, 5445, - 894, 1041, 1188, 3756, 3912, 4068, 5559, 5862, 6165], np.float).reshape(common_shape) + Y_result = np.array( + [ + 546, + 561, + 576, + 552, + 564, + 576, + 39, + 42, + 45, + 1410, + 1461, + 1512, + 1362, + 1392, + 1422, + 201, + 222, + 243, + 2274, + 2361, + 2448, + 2172, + 2220, + 2268, + 363, + 402, + 441, + 2784, + 2850, + 2916, + 4362, + 4485, + 4608, + 1551, + 1608, + 1665, + 3540, + 3624, + 3708, + 5604, + 5763, + 5922, + 2037, + 2112, + 2187, + 4296, + 4398, + 4500, + 6846, + 7041, + 7236, + 2523, + 2616, + 2709, + 678, + 789, + 900, + 2892, + 3012, + 3132, + 4263, + 4494, + 4725, + 786, + 915, + 1044, + 3324, + 3462, + 3600, + 4911, + 5178, + 5445, + 894, + 1041, + 1188, + 3756, + 3912, + 4068, + 5559, + 5862, + 6165, + ], + np.float, + ).reshape(common_shape) - sess = onnxrt.InferenceSession(get_name("sparse_to_dense_matmul.onnx"), - providers=onnxrt.get_available_providers()) - res = sess.run_with_ort_values(["dense_Y"], { "sparse_A" : A_ort_value, "dense_B" : B_ort_value }) + sess = onnxrt.InferenceSession( + get_name("sparse_to_dense_matmul.onnx"), + providers=onnxrt.get_available_providers(), + ) + res = sess.run_with_ort_values(["dense_Y"], {"sparse_A": A_ort_value, "dense_B": B_ort_value}) self.assertEqual(len(res), 1) ort_value = res[0] self.assertTrue(isinstance(ort_value, onnxrt.OrtValue)) diff --git a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py index 5abd6fdcffbde..97ea2df3d5b9a 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py +++ b/onnxruntime/test/python/onnxruntime_test_python_symbolic_shape_infer.py @@ -1,21 +1,32 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -# -*- coding: UTF-8 -*- -import onnx -from onnx import helper, AttributeProto, TensorProto, GraphProto import os -if os.path.exists(os.path.join(os.path.dirname(__file__), '..', '..', 'python', 'tools', 'symbolic_shape_infer.py')): +# -*- coding: UTF-8 -*- +import onnx +from onnx import AttributeProto, GraphProto, TensorProto, helper + +if os.path.exists( + os.path.join( + os.path.dirname(__file__), + "..", + "..", + "python", + "tools", + "symbolic_shape_infer.py", + ) +): # Allow running this test script without installing onnxruntime package. import sys - sys.path.append(os.path.join(os.path.dirname(__file__), '..', '..', 'python', 'tools')) + + sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..", "python", "tools")) from symbolic_shape_infer import SymbolicShapeInference else: from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference -from pathlib import Path import unittest +from pathlib import Path def unique_element(lst): @@ -25,56 +36,57 @@ def unique_element(lst): class TestSymbolicShapeInference(unittest.TestCase): def test_symbolic_shape_infer(self): - + cwd = os.getcwd() - test_model_dir = os.path.join(cwd, '..', 'models') - for filename in Path(test_model_dir).rglob('*.onnx'): - if filename.name.startswith('.'): + test_model_dir = os.path.join(cwd, "..", "models") + for filename in Path(test_model_dir).rglob("*.onnx"): + if filename.name.startswith("."): continue # skip some bad model files print("Running symbolic shape inference on : " + str(filename)) - SymbolicShapeInference.infer_shapes(in_mp=onnx.load(str(filename)), - auto_merge=True, - int_max=100000, - guess_output_rank=True) + SymbolicShapeInference.infer_shapes( + in_mp=onnx.load(str(filename)), + auto_merge=True, + int_max=100000, + guess_output_rank=True, + ) def test_mismatched_types(self): graph = helper.make_graph( - [helper.make_node( - "If", - ["x"], - ["out"], - name="if_node", - then_branch=helper.make_graph( - [helper.make_node( - "Constant", + [ + helper.make_node( + "If", + ["x"], + ["out"], + name="if_node", + then_branch=helper.make_graph( + [ + helper.make_node( + "Constant", + [], + ["one_float"], + value=helper.make_tensor("one_float_value", TensorProto.FLOAT, [], [1]), + ) + ], + "then", [], - ["one_float"], - value=helper.make_tensor( - "one_float_value", - TensorProto.FLOAT, - [], - [1]), - )], - "then", - [], - [helper.make_tensor_value_info("one_float", TensorProto.FLOAT, [])], - ), - else_branch=helper.make_graph( - [helper.make_node( - "Constant", + [helper.make_tensor_value_info("one_float", TensorProto.FLOAT, [])], + ), + else_branch=helper.make_graph( + [ + helper.make_node( + "Constant", + [], + ["one_double"], + value=helper.make_tensor("one_double", TensorProto.DOUBLE, [], [1]), + ) + ], + "else", [], - ["one_double"], - value=helper.make_tensor( - "one_double", - TensorProto.DOUBLE, - [], - [1]), - )], - "else", - [], - [helper.make_tensor_value_info("one_double", TensorProto.DOUBLE, [])], - ))], + [helper.make_tensor_value_info("one_double", TensorProto.DOUBLE, [])], + ), + ) + ], "graph", [helper.make_tensor_value_info("x", TensorProto.BOOL, [])], [helper.make_tensor_value_info("out", TensorProto.FLOAT, [])], @@ -99,98 +111,140 @@ def _check_shapes(self, graph, inferred_graph, vis): # type: (GraphProto, Graph inferred_vis_names = set(x.name for x in inferred_vis) assert vis_names == inferred_vis_names, (vis_names, inferred_vis_names) for vi, inferred_vi in zip(vis, inferred_vis): - assert vi == inferred_vi, '\n%s\n%s\n' % (vi, inferred_vi) + assert vi == inferred_vi, "\n%s\n%s\n" % (vi, inferred_vi) assert False def test_unsqueeze_opset_11(self): - graph = helper.make_graph([ - helper.make_node("Unsqueeze", ["input"], ["temp"], axes=[0]), - helper.make_node("Identity", ["temp"], ["output"]), - ], "Unsqueeze_Test", [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['b', 's']), - ], [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 'b', 's']), - ]) - model = helper.make_model(graph, producer_name='Unsqueeze_Test_Model') + graph = helper.make_graph( + [ + helper.make_node("Unsqueeze", ["input"], ["temp"], axes=[0]), + helper.make_node("Identity", ["temp"], ["output"]), + ], + "Unsqueeze_Test", + [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, ["b", "s"]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, "b", "s"]), + ], + ) + model = helper.make_model(graph, producer_name="Unsqueeze_Test_Model") model.opset_import[0].version = 11 inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) expected_shapes = [ - helper.make_tensor_value_info('temp', TensorProto.FLOAT, [1, 'b', 's']), - helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 'b', 's']) + helper.make_tensor_value_info("temp", TensorProto.FLOAT, [1, "b", "s"]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, "b", "s"]), ] self._check_shapes(graph, inferred.graph, expected_shapes) def test_unsqueeze_opset_13(self): - graph = helper.make_graph([ - helper.make_node("Unsqueeze", ["input", "axes"], ["temp"]), - helper.make_node("Identity", ["temp"], ["output"]), - ], "Unsqueeze_Test", [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['b', 's']), - ], [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['b', 's', 1]), - ], [ - helper.make_tensor('axes', TensorProto.INT64, [1], [-1]), - ]) - model = helper.make_model(graph, producer_name='Unsqueeze_Test_Model') + graph = helper.make_graph( + [ + helper.make_node("Unsqueeze", ["input", "axes"], ["temp"]), + helper.make_node("Identity", ["temp"], ["output"]), + ], + "Unsqueeze_Test", + [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, ["b", "s"]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["b", "s", 1]), + ], + [ + helper.make_tensor("axes", TensorProto.INT64, [1], [-1]), + ], + ) + model = helper.make_model(graph, producer_name="Unsqueeze_Test_Model") model.opset_import[0].version = 13 inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) expected_shapes = [ - helper.make_tensor_value_info('temp', TensorProto.FLOAT, ['b', 's', 1]), - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['b', 's', 1]) + helper.make_tensor_value_info("temp", TensorProto.FLOAT, ["b", "s", 1]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["b", "s", 1]), ] self._check_shapes(graph, inferred.graph, expected_shapes) def test_gather_indices(self): - graph = helper.make_graph([ - helper.make_node("Constant", [], ["data"], "constant", - value=helper.make_tensor('input', TensorProto.FLOAT, - [5], [0.0, 1.0, 2.0, 3.0, 4.0])), - helper.make_node("Gather", ["data", "indices"], ["output"], axis=0), - ], "Gather_Test", [ - helper.make_tensor_value_info('indices', TensorProto.INT64, ['b']), - ], [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['b']), - ]) - model = helper.make_model(graph, producer_name='Gather_Test_Model') + graph = helper.make_graph( + [ + helper.make_node( + "Constant", + [], + ["data"], + "constant", + value=helper.make_tensor("input", TensorProto.FLOAT, [5], [0.0, 1.0, 2.0, 3.0, 4.0]), + ), + helper.make_node("Gather", ["data", "indices"], ["output"], axis=0), + ], + "Gather_Test", + [ + helper.make_tensor_value_info("indices", TensorProto.INT64, ["b"]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["b"]), + ], + ) + model = helper.make_model(graph, producer_name="Gather_Test_Model") model.opset_import[0].version = 13 inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) expected_shapes = [ - helper.make_tensor_value_info('data', TensorProto.FLOAT, [5]), - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['b']) + helper.make_tensor_value_info("data", TensorProto.FLOAT, [5]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["b"]), ] self._check_shapes(graph, inferred.graph, expected_shapes) def test_embed_layer_norm(self): hidden_size = 32 initializers = [ - helper.make_tensor('word_embedding', TensorProto.FLOAT, [100, hidden_size], [1.0] * (100 * hidden_size)), - helper.make_tensor('position_embedding', TensorProto.FLOAT, [20, hidden_size], [1.0] * (20 * hidden_size)), - helper.make_tensor('segment_embedding', TensorProto.FLOAT, [2, hidden_size], [1.0] * (2 * hidden_size)), - helper.make_tensor('gamma', TensorProto.FLOAT, [hidden_size], [1.0] * hidden_size), - helper.make_tensor('beta', TensorProto.FLOAT, [hidden_size], [1.0] * hidden_size) + helper.make_tensor( + "word_embedding", + TensorProto.FLOAT, + [100, hidden_size], + [1.0] * (100 * hidden_size), + ), + helper.make_tensor( + "position_embedding", + TensorProto.FLOAT, + [20, hidden_size], + [1.0] * (20 * hidden_size), + ), + helper.make_tensor( + "segment_embedding", + TensorProto.FLOAT, + [2, hidden_size], + [1.0] * (2 * hidden_size), + ), + helper.make_tensor("gamma", TensorProto.FLOAT, [hidden_size], [1.0] * hidden_size), + helper.make_tensor("beta", TensorProto.FLOAT, [hidden_size], [1.0] * hidden_size), ] nodes = [ - helper.make_node("EmbedLayerNormalization", - inputs=[ - "input_ids", "segment_ids", "word_embedding", "position_embedding", - "segment_embedding", "gamma", "beta" - ], - outputs=["output", "mask_index"], - domain="com.microsoft"), + helper.make_node( + "EmbedLayerNormalization", + inputs=[ + "input_ids", + "segment_ids", + "word_embedding", + "position_embedding", + "segment_embedding", + "gamma", + "beta", + ], + outputs=["output", "mask_index"], + domain="com.microsoft", + ), ] inputs = [ - helper.make_tensor_value_info('input_ids', TensorProto.FLOAT, ['b', 's']), - helper.make_tensor_value_info('segment_ids', TensorProto.FLOAT, ['b', 's']), + helper.make_tensor_value_info("input_ids", TensorProto.FLOAT, ["b", "s"]), + helper.make_tensor_value_info("segment_ids", TensorProto.FLOAT, ["b", "s"]), ] outputs = [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, None), - helper.make_tensor_value_info('mask_index', TensorProto.INT32, None), + helper.make_tensor_value_info("output", TensorProto.FLOAT, None), + helper.make_tensor_value_info("mask_index", TensorProto.INT32, None), ] graph = helper.make_graph(nodes, "Unsqueeze_Test", inputs, outputs, initializers) @@ -198,8 +252,8 @@ def test_embed_layer_norm(self): inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) expected_shapes = [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['b', 's', hidden_size]), - helper.make_tensor_value_info('mask_index', TensorProto.INT32, ['b']) + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["b", "s", hidden_size]), + helper.make_tensor_value_info("mask_index", TensorProto.INT32, ["b"]), ] self._check_shapes(graph, inferred.graph, expected_shapes) @@ -207,27 +261,23 @@ def test_softmax_cross_entropy_loss(self): hidden_size = 1024 nodes = [ - helper.make_node("SoftmaxCrossEntropyLoss", - inputs=["logits", "labels"], - outputs=["loss"]), + helper.make_node("SoftmaxCrossEntropyLoss", inputs=["logits", "labels"], outputs=["loss"]), ] inputs = [ - helper.make_tensor_value_info('logits', TensorProto.FLOAT, ['b', 's', hidden_size]), - helper.make_tensor_value_info('labels', TensorProto.INT32, ['b', 's']), + helper.make_tensor_value_info("logits", TensorProto.FLOAT, ["b", "s", hidden_size]), + helper.make_tensor_value_info("labels", TensorProto.INT32, ["b", "s"]), ] outputs = [ - helper.make_tensor_value_info('loss', TensorProto.FLOAT, None), + helper.make_tensor_value_info("loss", TensorProto.FLOAT, None), ] graph = helper.make_graph(nodes, "SoftmaxCrossEntropyLoss_Test", inputs, outputs, []) model = helper.make_model(graph) inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) - expected_shapes = [ - helper.make_tensor_value_info('loss', TensorProto.FLOAT, []) - ] + expected_shapes = [helper.make_tensor_value_info("loss", TensorProto.FLOAT, [])] self._check_shapes(graph, inferred.graph, expected_shapes) def _test_einsum_one_input_impl(self, input_0_shape, output_0_shape, eqn): @@ -235,18 +285,16 @@ def _test_einsum_one_input_impl(self, input_0_shape, output_0_shape, eqn): helper.make_node("Einsum", ["input_0"], ["output_0"], "einsum_0", equation=eqn), ] inputs = [ - helper.make_tensor_value_info('input_0', TensorProto.FLOAT, input_0_shape), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, input_0_shape), ] outputs = [ - helper.make_tensor_value_info('output_0', TensorProto.FLOAT, None), + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, None), ] graph = helper.make_graph(nodes, "Einsum_Test", inputs, outputs, []) model = helper.make_model(graph) inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) - expected_shapes = [ - helper.make_tensor_value_info('output_0', TensorProto.FLOAT, output_0_shape) - ] + expected_shapes = [helper.make_tensor_value_info("output_0", TensorProto.FLOAT, output_0_shape)] self._check_shapes(graph, inferred.graph, expected_shapes) def _test_einsum_two_inputs_impl(self, input_0_shape, input_1_shape, output_0_shape, eqn): @@ -254,23 +302,21 @@ def _test_einsum_two_inputs_impl(self, input_0_shape, input_1_shape, output_0_sh helper.make_node("Einsum", ["input_0", "input_1"], ["output_0"], "einsum_0", equation=eqn), ] inputs = [ - helper.make_tensor_value_info('input_0', TensorProto.FLOAT, input_0_shape), - helper.make_tensor_value_info('input_1', TensorProto.FLOAT, input_1_shape), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, input_0_shape), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, input_1_shape), ] outputs = [ - helper.make_tensor_value_info('output_0', TensorProto.FLOAT, None), + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, None), ] graph = helper.make_graph(nodes, "Einsum_Test", inputs, outputs, []) model = helper.make_model(graph) inferred = SymbolicShapeInference.infer_shapes(model, auto_merge=True) - expected_shapes = [ - helper.make_tensor_value_info('output_0', TensorProto.FLOAT, output_0_shape) - ] + expected_shapes = [helper.make_tensor_value_info("output_0", TensorProto.FLOAT, output_0_shape)] self._check_shapes(graph, inferred.graph, expected_shapes) def test_einsum_matmul(self): - self._test_einsum_two_inputs_impl([1, 'b', 8], [2, 12, 'n'], [1, 'b', 12, 'n'], "abc, cde -> abde") + self._test_einsum_two_inputs_impl([1, "b", 8], [2, 12, "n"], [1, "b", 12, "n"], "abc, cde -> abde") def test_einsum_batch_matmul(self): self._test_einsum_two_inputs_impl([5, 2, 3], [5, 3, 4], [5, 2, 4], "bij, bjk -> bik") @@ -279,13 +325,13 @@ def test_einsum_inner_prod(self): self._test_einsum_two_inputs_impl([5], [5], [], "i, i") def test_einsum_batch_diagonal(self): - self._test_einsum_one_input_impl([3, 5, 5], [3, 5], "...ii ->...i") + self._test_einsum_one_input_impl([3, 5, 5], [3, 5], "...ii ->...i") def test_einsum_sum(self): - self._test_einsum_one_input_impl(['a', 'b'], ['a'], "ij -> i") + self._test_einsum_one_input_impl(["a", "b"], ["a"], "ij -> i") def test_einsum_transpose(self): - self._test_einsum_one_input_impl(['a', 'b'], ['b', 'a'], "ij -> ji") + self._test_einsum_one_input_impl(["a", "b"], ["b", "a"], "ij -> ji") class TestSymbolicShapeInferenceForSlice(unittest.TestCase): @@ -302,30 +348,43 @@ def get_initializer(name): initializers = [ get_initializer(name) - for name in ["zero", "one", "two", "ten", "intmax", "neg_intmax", "neg_one", "neg_ten"] + for name in [ + "zero", + "one", + "two", + "ten", + "intmax", + "neg_intmax", + "neg_one", + "neg_ten", + ] ] inputs = [] nodes = [] for i, dim in enumerate(input_dims): inputs.append(onnx.helper.make_tensor_value_info(f"t{i}", TensorProto.FLOAT, ["B", dim])) - nodes.extend([ - onnx.helper.make_node("Shape", [f"t{i}"], [f"shape{i}"]), - onnx.helper.make_node("Slice", [f"shape{i}", "one", "two", "zero", "one"], [f"dim{i}"]), - onnx.helper.make_node("Neg", [f"dim{i}"], [f"neg_dim{i}"]) - ]) + nodes.extend( + [ + onnx.helper.make_node("Shape", [f"t{i}"], [f"shape{i}"]), + onnx.helper.make_node("Slice", [f"shape{i}", "one", "two", "zero", "one"], [f"dim{i}"]), + onnx.helper.make_node("Neg", [f"dim{i}"], [f"neg_dim{i}"]), + ] + ) def make_concat_dims(concat_name, dims): dims = [f"neg_{dimstrmap(dim[1:])}" if dim.startswith("-") else dimstrmap(dim) for dim in dims] return onnx.helper.make_node("Concat", dims, [concat_name], axis=0) - nodes.extend([ - onnx.helper.make_node("Concat", [inp.name for inp in inputs], ["concat"], axis=1), - make_concat_dims("starts", ["zero", start]), - make_concat_dims("ends", ["intmax", end]), - make_concat_dims("axes", ["zero", "one"]), - make_concat_dims("steps", ["one", step]), - onnx.helper.make_node("Slice", ["concat", "starts", "ends", "axes", "steps"], ["output"]) - ]) + nodes.extend( + [ + onnx.helper.make_node("Concat", [inp.name for inp in inputs], ["concat"], axis=1), + make_concat_dims("starts", ["zero", start]), + make_concat_dims("ends", ["intmax", end]), + make_concat_dims("axes", ["zero", "one"]), + make_concat_dims("steps", ["one", step]), + onnx.helper.make_node("Slice", ["concat", "starts", "ends", "axes", "steps"], ["output"]), + ] + ) output = onnx.helper.make_tensor_value_info("output", TensorProto.FLOAT, ["d1", "d2"]) graph_def = onnx.helper.make_graph(nodes, "graph", inputs, [output], initializer=initializers) model = SymbolicShapeInference.infer_shapes(onnx.helper.make_model(graph_def)) @@ -361,5 +420,5 @@ def test_flip_of_concat(self): self.check_slice_of_concat(["N", "N", "N"], "-one", "-intmax", "-one", "3*N") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/onnxruntime_test_python_tvm.py b/onnxruntime/test/python/onnxruntime_test_python_tvm.py index 945bf2604f924..b081a97b123c9 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_tvm.py +++ b/onnxruntime/test/python/onnxruntime_test_python_tvm.py @@ -1,31 +1,28 @@ import numpy from numpy.testing import assert_almost_equal -from onnx import numpy_helper, TensorProto -from onnx.helper import ( - make_model, make_node, set_model_props, make_tensor, - make_graph, make_tensor_value_info) +from onnx import TensorProto, numpy_helper +from onnx.helper import make_graph, make_model, make_node, make_tensor, make_tensor_value_info, set_model_props + import onnxruntime if "TvmExecutionProvider" not in onnxruntime.get_available_providers(): - raise AssertionError( - "Unable to find 'TvmExecutionProvider' in %r." % onnxruntime.get_available_providers()) - -X = make_tensor_value_info('X', TensorProto.FLOAT, [None, None]) -A = make_tensor_value_info('A', TensorProto.FLOAT, [None, None]) -B = make_tensor_value_info('B', TensorProto.FLOAT, [None, None]) -Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None]) -node1 = make_node('MatMul', ['X', 'A'], ['XA']) -node2 = make_node('Add', ['XA', 'B'], ['Y']) -graph = make_graph([node1, node2], 'lr', [X, A, B], [Y]) + raise AssertionError("Unable to find 'TvmExecutionProvider' in %r." % onnxruntime.get_available_providers()) + +X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) +A = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) +B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) +Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) +node1 = make_node("MatMul", ["X", "A"], ["XA"]) +node2 = make_node("Add", ["XA", "B"], ["Y"]) +graph = make_graph([node1, node2], "lr", [X, A, B], [Y]) onnx_model = make_model(graph) a = numpy.random.randn(2, 2).astype(numpy.float32) b = numpy.random.randn(1, 2).astype(numpy.float32) x = numpy.random.randn(1, 2).astype(numpy.float32) -data = {'A': a, 'B': b, 'X': x} +data = {"A": a, "B": b, "X": x} -sess = onnxruntime.InferenceSession( - onnx_model.SerializeToString(), providers=['CPUExecutionProvider']) +sess = onnxruntime.InferenceSession(onnx_model.SerializeToString(), providers=["CPUExecutionProvider"]) y = sess.run(None, data)[0] @@ -37,8 +34,8 @@ tuning_file_path="", tuning_type="Ansor", input_names=" ".join(i.name for i in sess.get_inputs()), - input_shapes=" ".join(str(numpy.array(data[i.name].shape)) - for i in sess.get_inputs())) + input_shapes=" ".join(str(numpy.array(data[i.name].shape)) for i in sess.get_inputs()), +) so = onnxruntime.SessionOptions() so.log_severity_level = 0 @@ -46,9 +43,11 @@ so.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL sess = onnxruntime.InferenceSession( - onnx_model.SerializeToString(), so, + onnx_model.SerializeToString(), + so, providers=["TvmExecutionProvider"], - provider_options=[provider_options]) + provider_options=[provider_options], +) y_tvm = sess.run(None, data)[0] diff --git a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py index 0f6e210ad02e5..f89459cfd9750 100644 --- a/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py +++ b/onnxruntime/test/python/onnxruntime_test_training_unit_tests.py @@ -2,19 +2,20 @@ # Licensed under the MIT License. import unittest -from numpy.testing import assert_allclose import torch import torch.nn as nn - +from numpy.testing import assert_allclose from onnxruntime_test_ort_trainer import map_optimizer_attributes, ort_trainer_learning_rate_description -import onnxruntime from onnxruntime_test_training_unittest_utils import process_dropout -from onnxruntime.capi.ort_trainer import ORTTrainer, IODescription, ModelDescription + +import onnxruntime +from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer torch.manual_seed(1) onnxruntime.set_seed(1) + class TestTrainingDropout(unittest.TestCase): def testTrainingAndEvalDropout(self): # Temporarily disable this test. @@ -22,49 +23,73 @@ def testTrainingAndEvalDropout(self): # to sort backward graph before forward graph which gives incorrect result. # TODO Re-enable when that is fixed. return + class TwoDropoutNet(nn.Module): def __init__(self, drop_prb_1, drop_prb_2, dim_size): super(TwoDropoutNet, self).__init__() self.drop_1 = nn.Dropout(drop_prb_1) self.drop_2 = nn.Dropout(drop_prb_2) self.weight_1 = torch.nn.Parameter(torch.zeros(dim_size, dtype=torch.float32)) + def forward(self, x): x = x + self.weight_1 x = self.drop_1(x) x = self.drop_2(x) output = x return output[0] + dim_size = 3 device = torch.device("cuda", 0) # This will drop all values, therefore expecting all 0 in output tensor model = TwoDropoutNet(0.999, 0.999, dim_size) - input_desc = IODescription('input', [dim_size], torch.float32) - output_desc = IODescription('output', [], torch.float32) + input_desc = IODescription("input", [dim_size], torch.float32) + output_desc = IODescription("output", [], torch.float32) model_desc = ModelDescription([input_desc], [output_desc]) lr_desc = ort_trainer_learning_rate_description() - model = ORTTrainer(model, None, model_desc, "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - postprocess_model=process_dropout, - world_rank=0, world_size=1) + model = ORTTrainer( + model, + None, + model_desc, + "LambOptimizer", + map_optimizer_attributes, + lr_desc, + device, + postprocess_model=process_dropout, + world_rank=0, + world_size=1, + ) input = torch.ones(dim_size, dtype=torch.float32).to(device) expected_training_output = [0.0] expected_eval_output = [1.0] - learning_rate = torch.tensor([1.0000000e+00]).to(device) - input_args=[input, learning_rate] + learning_rate = torch.tensor([1.0000000e00]).to(device) + input_args = [input, learning_rate] train_output = model.train_step(*input_args) rtol = 1e-04 - assert_allclose(expected_training_output, train_output.item(), rtol=rtol, err_msg="dropout training loss mismatch") + assert_allclose( + expected_training_output, + train_output.item(), + rtol=rtol, + err_msg="dropout training loss mismatch", + ) eval_output = model.eval_step(input) - assert_allclose(expected_eval_output, eval_output.item(), rtol=rtol, err_msg="dropout eval loss mismatch") + assert_allclose( + expected_eval_output, + eval_output.item(), + rtol=rtol, + err_msg="dropout eval loss mismatch", + ) # Do another train step to make sure it's using original ratios train_output_2 = model.train_step(*input_args) - assert_allclose(expected_training_output, train_output_2.item(), rtol=rtol, err_msg="dropout training loss 2 mismatch") + assert_allclose( + expected_training_output, + train_output_2.item(), + rtol=rtol, + err_msg="dropout training loss 2 mismatch", + ) -if __name__ == '__main__': - unittest.main(module=__name__, buffer=True) +if __name__ == "__main__": + unittest.main(module=__name__, buffer=True) diff --git a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py index 19cf2d98a0a27..3d3feca06a99b 100644 --- a/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py +++ b/onnxruntime/test/python/onnxruntime_test_training_unittest_utils.py @@ -1,6 +1,7 @@ import numpy as np from onnx import numpy_helper + def get_node_index(model, node): i = 0 while i < len(model.graph.node): @@ -9,13 +10,14 @@ def get_node_index(model, node): i += 1 return i if i < len(model.graph.node) else None -def add_const(model, name, output, t_value = None, f_value = None): + +def add_const(model, name, output, t_value=None, f_value=None): const_node = model.graph.node.add() - const_node.op_type = 'Constant' + const_node.op_type = "Constant" const_node.name = name const_node.output.extend([output]) attr = const_node.attribute.add() - attr.name = 'value' + attr.name = "value" if t_value is not None: attr.type = 4 attr.t.CopyFrom(t_value) @@ -24,20 +26,26 @@ def add_const(model, name, output, t_value = None, f_value = None): attr.f = f_value return const_node + def process_dropout(model): dropouts = [] index = 0 for node in model.graph.node: - if node.op_type == 'Dropout': + if node.op_type == "Dropout": new_dropout = model.graph.node.add() - new_dropout.op_type = 'TrainableDropout' - new_dropout.name = 'TrainableDropout_%d' % index - #make ratio node + new_dropout.op_type = "TrainableDropout" + new_dropout.name = "TrainableDropout_%d" % index + # make ratio node ratio = np.asarray([node.attribute[0].f], dtype=np.float32) print(ratio.shape) ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const(model, 'dropout_node_ratio_%d' % index, 'dropout_node_ratio_%d' % index, t_value=ratio_value) - print (ratio_node) + ratio_node = add_const( + model, + "dropout_node_ratio_%d" % index, + "dropout_node_ratio_%d" % index, + t_value=ratio_value, + ) + print(ratio_node) new_dropout.input.extend([node.input[0], ratio_node.output[0]]) new_dropout.output.extend(node.output) dropouts.append(get_node_index(model, node)) diff --git a/onnxruntime/test/python/quantization/op_test_utils.py b/onnxruntime/test/python/quantization/op_test_utils.py index 1764f675bb2c4..efb4c8fbc3453 100644 --- a/onnxruntime/test/python/quantization/op_test_utils.py +++ b/onnxruntime/test/python/quantization/op_test_utils.py @@ -1,15 +1,17 @@ +from pathlib import Path -import onnx import numpy as np +import onnx + import onnxruntime -from pathlib import Path from onnxruntime.quantization import CalibrationDataReader + class TestDataFeeds(CalibrationDataReader): def __init__(self, data_feeds): - ''' + """ parameter data_feeds: list of input feed, each input feed is diction of {input_name: np_array} - ''' + """ self.data_feeds = data_feeds self.iter_next = iter(self.data_feeds) @@ -21,9 +23,9 @@ def rewind(self): def InputFeedsNegOneZeroOne(n, name2shape): - ''' + """ randomize n feed according to shape, its values are from -1, 0, and 1 - ''' + """ input_data_list = [] for i in range(n): inputs = {} @@ -33,18 +35,20 @@ def InputFeedsNegOneZeroOne(n, name2shape): dr = TestDataFeeds(input_data_list) return dr + def check_op_type_order(testcase, model_to_check, ops): if isinstance(model_to_check, str): model = onnx.load(model_to_check) elif isinstance(model_to_check, onnx.ModelProto): model = model_to_check - testcase.assertEqual(len(ops), len(model.graph.node), 'op count is not same') + testcase.assertEqual(len(ops), len(model.graph.node), "op count is not same") for node_idx, node in enumerate(model.graph.node): testcase.assertEqual( ops[node_idx], node.op_type, - 'op {} is not in order. Expected: {}, Actual: {}'.format(node_idx, ops[node_idx], node.op_type)) + "op {} is not in order. Expected: {}, Actual: {}".format(node_idx, ops[node_idx], node.op_type), + ) def check_op_type_count(testcase, model_path, **kwargs): @@ -56,17 +60,27 @@ def check_op_type_count(testcase, model_path, **kwargs): if node.op_type in optype2count: optype2count[node.op_type] += 1 for op_type in kwargs: - testcase.assertEqual(kwargs[op_type], optype2count[op_type], 'op_type {} count not same'.format(op_type)) + testcase.assertEqual( + kwargs[op_type], + optype2count[op_type], + "op_type {} count not same".format(op_type), + ) def check_model_correctness(testcase, model_path_origin, model_path_to_check, inputs, rtol=1e-2, atol=0.05): sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - origin_sess = onnxruntime.InferenceSession(model_path_origin, sess_options=sess_options, providers=["CPUExecutionProvider"]) + origin_sess = onnxruntime.InferenceSession( + model_path_origin, sess_options=sess_options, providers=["CPUExecutionProvider"] + ) origin_results = origin_sess.run([], inputs) - target_sess = onnxruntime.InferenceSession(model_path_to_check, sess_options=sess_options, providers=["CPUExecutionProvider"]) + target_sess = onnxruntime.InferenceSession( + model_path_to_check, + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) target_results = target_sess.run([], inputs) - testcase.assertEqual(len(origin_results), len(target_results), 'result count are different') + testcase.assertEqual(len(origin_results), len(target_results), "result count are different") for idx, ref_output in enumerate(origin_results): output = target_results[idx] np.testing.assert_allclose(ref_output, output, rtol=rtol, atol=atol) @@ -77,6 +91,7 @@ def check_op_nodes(testcase, model_path, node_checker): for node in model.graph.node: testcase.assertTrue(node_checker(node)) + def check_qtype_by_node_type(testcase, model_to_check, check_list): if isinstance(model_to_check, str): model = onnx.load(model_to_check) @@ -86,18 +101,18 @@ def check_qtype_by_node_type(testcase, model_to_check, check_list): value_infos = {vi.name: vi for vi in model.graph.value_info} value_infos.update({ot.name: ot for ot in model.graph.output}) value_infos.update({it.name: it for it in model.graph.input}) - initializers = {init.name : init for init in model.graph.initializer} + initializers = {init.name: init for init in model.graph.initializer} for node in model.graph.node: if node.op_type in check_list: input_output_check_list = check_list[node.op_type] for check_item in input_output_check_list: - tensor_name = node.input[check_item[1]] if check_item[0] == 'i' else node.output[check_item[1]] + tensor_name = node.input[check_item[1]] if check_item[0] == "i" else node.output[check_item[1]] testcase.assertTrue((tensor_name in value_infos) or (tensor_name in initializers)) if tensor_name in value_infos: vi = value_infos[tensor_name] - testcase.assertTrue(vi.type.HasField('tensor_type')) + testcase.assertTrue(vi.type.HasField("tensor_type")) testcase.assertTrue(vi.type.tensor_type.elem_type == check_item[2]) - else: #if (tensor_name in initializers): + else: # if (tensor_name in initializers): init = initializers[tensor_name] testcase.assertTrue(init.data_type == check_item[2]) diff --git a/onnxruntime/test/python/quantization/test_calibration.py b/onnxruntime/test/python/quantization/test_calibration.py index 398922519c5ca..9a7bd657cc067 100644 --- a/onnxruntime/test/python/quantization/test_calibration.py +++ b/onnxruntime/test/python/quantization/test_calibration.py @@ -7,24 +7,27 @@ # -------------------------------------------------------------------------- import unittest + +import numpy as np import onnx +from onnx import TensorProto, helper, numpy_helper + import onnxruntime -import numpy as np -from onnx import helper, TensorProto, numpy_helper from onnxruntime.quantization.calibrate import CalibrationDataReader, MinMaxCalibrater def generate_input_initializer(tensor_shape, tensor_dtype, input_name): - ''' - Helper function to generate initializers for test inputs - ''' + """ + Helper function to generate initializers for test inputs + """ tensor = np.random.normal(0, 0.3, tensor_shape).astype(tensor_dtype) init = numpy_helper.from_array(tensor, input_name) return init class TestDataReader(CalibrationDataReader): - '''for test purpose''' + """for test purpose""" + def __init__(self): self.preprocess_flag = True self.enum_data_dicts = [] @@ -36,16 +39,17 @@ def __init__(self): def get_next(self): if self.preprocess_flag: self.preprocess_flag = False - input_name = 'input' + input_name = "input" self.enum_data_dicts = iter([{input_name: input_data} for input_data in self.input_data_list]) return next(self.enum_data_dicts, None) def rewind(self): self.preprocess_flag = True + class TestCalibrate(unittest.TestCase): def test_augment_graph_config_1(self): - ''' TEST_CONFIG_1''' + """TEST_CONFIG_1""" # Conv # | @@ -53,32 +57,50 @@ def test_augment_graph_config_1(self): # | # MatMul - A = helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 1, 5, 5]) - B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 1, 3, 3]) - E = helper.make_tensor_value_info('E', TensorProto.FLOAT, [1, 1, 5, 1]) - F = helper.make_tensor_value_info('F', TensorProto.FLOAT, [1, 1, 5, 1]) - conv_node = onnx.helper.make_node('Conv', ['A', 'B'], ['C'], - name='Conv', - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]) - clip_node = onnx.helper.make_node('Clip', ['C'], ['D'], name='Clip') - matmul_node = onnx.helper.make_node('MatMul', ['D', 'E'], ['F'], name='MatMul') - graph = helper.make_graph([conv_node, clip_node, matmul_node], 'test_graph_1', [A, B, E], [F]) + A = helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 1, 5, 5]) + B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 1, 3, 3]) + E = helper.make_tensor_value_info("E", TensorProto.FLOAT, [1, 1, 5, 1]) + F = helper.make_tensor_value_info("F", TensorProto.FLOAT, [1, 1, 5, 1]) + conv_node = onnx.helper.make_node( + "Conv", + ["A", "B"], + ["C"], + name="Conv", + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + ) + clip_node = onnx.helper.make_node("Clip", ["C"], ["D"], name="Clip") + matmul_node = onnx.helper.make_node("MatMul", ["D", "E"], ["F"], name="MatMul") + graph = helper.make_graph([conv_node, clip_node, matmul_node], "test_graph_1", [A, B, E], [F]) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - test_model_path = './test_model_1.onnx' + test_model_path = "./test_model_1.onnx" onnx.save(model, test_model_path) # Augmenting graph - augmented_model_path = './augmented_test_model_1.onnx' - calibrater = MinMaxCalibrater(test_model_path, ['Conv', 'MatMul'], augmented_model_path) + augmented_model_path = "./augmented_test_model_1.onnx" + calibrater = MinMaxCalibrater(test_model_path, ["Conv", "MatMul"], augmented_model_path) augmented_model = calibrater.get_augment_model() # Checking if each added ReduceMin and ReduceMax node and its output exists augmented_model_node_names = [node.name for node in augmented_model.graph.node] augmented_model_outputs = [output.name for output in augmented_model.graph.output] - added_node_names = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax'] - added_outputs = ['C_ReduceMin', 'C_ReduceMax', 'D_ReduceMin', 'D_ReduceMax', 'F_ReduceMin', 'F_ReduceMax'] + added_node_names = [ + "C_ReduceMin", + "C_ReduceMax", + "D_ReduceMin", + "D_ReduceMax", + "F_ReduceMin", + "F_ReduceMax", + ] + added_outputs = [ + "C_ReduceMin", + "C_ReduceMax", + "D_ReduceMin", + "D_ReduceMax", + "F_ReduceMin", + "F_ReduceMax", + ] # Original 3 nodes + added ReduceMin/Max nodes self.assertEqual(len(augmented_model_node_names), 15) # Original 1 graph output + added outputs * 6 @@ -89,36 +111,44 @@ def test_augment_graph_config_1(self): self.assertTrue(output in augmented_model_outputs) def test_augment_graph_config_2(self): - '''TEST_CONFIG_2''' + """TEST_CONFIG_2""" # Conv # | # Conv - G = helper.make_tensor_value_info('G', TensorProto.FLOAT, [1, 1, 5, 5]) - H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 1, 3, 3]) - J = helper.make_tensor_value_info('J', TensorProto.FLOAT, [1, 1, 3, 3]) - K = helper.make_tensor_value_info('K', TensorProto.FLOAT, [1, 1, 5, 5]) - conv_node_1 = onnx.helper.make_node('Conv', ['G', 'H'], ['I'], - name='Conv1', - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]) - conv_node_2 = onnx.helper.make_node('Conv', ['I', 'J'], ['K'], - name='Conv2', - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]) - graph = helper.make_graph([conv_node_1, conv_node_2], 'test_graph_2', [G, H, J], [K]) + G = helper.make_tensor_value_info("G", TensorProto.FLOAT, [1, 1, 5, 5]) + H = helper.make_tensor_value_info("H", TensorProto.FLOAT, [1, 1, 3, 3]) + J = helper.make_tensor_value_info("J", TensorProto.FLOAT, [1, 1, 3, 3]) + K = helper.make_tensor_value_info("K", TensorProto.FLOAT, [1, 1, 5, 5]) + conv_node_1 = onnx.helper.make_node( + "Conv", + ["G", "H"], + ["I"], + name="Conv1", + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + ) + conv_node_2 = onnx.helper.make_node( + "Conv", + ["I", "J"], + ["K"], + name="Conv2", + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + ) + graph = helper.make_graph([conv_node_1, conv_node_2], "test_graph_2", [G, H, J], [K]) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - test_model_path = './test_model_2.onnx' + test_model_path = "./test_model_2.onnx" onnx.save(model, test_model_path) - augmented_model_path = './augmented_test_model_2.onnx' - calibrater = MinMaxCalibrater(test_model_path, ['Conv', 'MatMul'], augmented_model_path) + augmented_model_path = "./augmented_test_model_2.onnx" + calibrater = MinMaxCalibrater(test_model_path, ["Conv", "MatMul"], augmented_model_path) augmented_model = calibrater.get_augment_model() augmented_model_node_names = [node.name for node in augmented_model.graph.node] augmented_model_outputs = [output.name for output in augmented_model.graph.output] - added_node_names = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax'] - added_outputs = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax'] + added_node_names = ["I_ReduceMin", "I_ReduceMax", "K_ReduceMin", "K_ReduceMax"] + added_outputs = ["I_ReduceMin", "I_ReduceMax", "K_ReduceMin", "K_ReduceMax"] # Original 2 nodes + added ReduceMin/Max nodes * 4 self.assertEqual(len(augmented_model_node_names), 12) # Original 1 graph output + added outputs * 4 @@ -129,7 +159,7 @@ def test_augment_graph_config_2(self): self.assertTrue(output in augmented_model_outputs) def test_augment_graph_config_3(self): - '''TEST_CONFIG_3''' + """TEST_CONFIG_3""" # (input) # | @@ -143,35 +173,51 @@ def test_augment_graph_config_3(self): # | # (output) - L = helper.make_tensor_value_info('L', TensorProto.FLOAT, [1, 1, 5, 5]) - N = helper.make_tensor_value_info('N', TensorProto.FLOAT, [1, 1, 3, 3]) - Q = helper.make_tensor_value_info('Q', TensorProto.FLOAT, [1, 1, 5, 5]) - relu_node = onnx.helper.make_node('Relu', ['L'], ['M'], name='Relu') - conv_node = onnx.helper.make_node('Conv', ['M', 'N'], ['O'], - name='Conv', - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]) - clip_node = onnx.helper.make_node('Clip', ['O'], ['P'], name='Clip') - matmul_node = onnx.helper.make_node('MatMul', ['P', 'M'], ['Q'], name='MatMul') - graph = helper.make_graph([relu_node, conv_node, clip_node, matmul_node], 'test_graph_3', [L, N], [Q]) + L = helper.make_tensor_value_info("L", TensorProto.FLOAT, [1, 1, 5, 5]) + N = helper.make_tensor_value_info("N", TensorProto.FLOAT, [1, 1, 3, 3]) + Q = helper.make_tensor_value_info("Q", TensorProto.FLOAT, [1, 1, 5, 5]) + relu_node = onnx.helper.make_node("Relu", ["L"], ["M"], name="Relu") + conv_node = onnx.helper.make_node( + "Conv", + ["M", "N"], + ["O"], + name="Conv", + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + ) + clip_node = onnx.helper.make_node("Clip", ["O"], ["P"], name="Clip") + matmul_node = onnx.helper.make_node("MatMul", ["P", "M"], ["Q"], name="MatMul") + graph = helper.make_graph([relu_node, conv_node, clip_node, matmul_node], "test_graph_3", [L, N], [Q]) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - test_model_path = './test_model_3.onnx' + test_model_path = "./test_model_3.onnx" onnx.save(model, test_model_path) # Augmenting graph - augmented_model_path = './augmented_test_model_3.onnx' - calibrater = MinMaxCalibrater(test_model_path, ['Conv', 'MatMul'], augmented_model_path) + augmented_model_path = "./augmented_test_model_3.onnx" + calibrater = MinMaxCalibrater(test_model_path, ["Conv", "MatMul"], augmented_model_path) augmented_model = calibrater.get_augment_model() augmented_model_node_names = [node.name for node in augmented_model.graph.node] augmented_model_outputs = [output.name for output in augmented_model.graph.output] added_node_names = [ - 'M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', - 'Q_ReduceMax' + "M_ReduceMin", + "M_ReduceMax", + "O_ReduceMin", + "O_ReduceMax", + "P_ReduceMin", + "P_ReduceMax", + "Q_ReduceMin", + "Q_ReduceMax", ] added_outputs = [ - 'M_ReduceMin', 'M_ReduceMax', 'O_ReduceMin', 'O_ReduceMax', 'P_ReduceMin', 'P_ReduceMax', 'Q_ReduceMin', - 'Q_ReduceMax' + "M_ReduceMin", + "M_ReduceMax", + "O_ReduceMin", + "O_ReduceMax", + "P_ReduceMin", + "P_ReduceMax", + "Q_ReduceMin", + "Q_ReduceMax", ] # Original 4 nodes + added ReduceMin/Max nodes self.assertEqual(len(augmented_model_node_names), 14) @@ -196,27 +242,31 @@ def construct_test_compute_range_model(self, test_model_path): # Add # | # (X6) - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 3, 1, 3]) - X1_output = helper.make_tensor_value_info('X1', TensorProto.FLOAT, [1, 3, 1, 3]) - X2_output = helper.make_tensor_value_info('X2', TensorProto.FLOAT, [1, 3, 1, 3]) - X3_output = helper.make_tensor_value_info('X3', TensorProto.FLOAT, [1, 3, 1, 3]) - X4_output = helper.make_tensor_value_info('X4', TensorProto.FLOAT, [1, 3, 1, 3]) - X5_output = helper.make_tensor_value_info('X5', TensorProto.FLOAT, [1, 3, 1, 3]) - X6_output = helper.make_tensor_value_info('X6', TensorProto.FLOAT, [1, 3, 1, 3]) - W1 = generate_input_initializer([3, 3, 1, 1], np.float32, 'W1') - B1 = generate_input_initializer([3], np.float32, 'B1') - W3 = generate_input_initializer([3, 3, 1, 1], np.float32, 'W3') - B3 = generate_input_initializer([3], np.float32, 'B3') - W5 = generate_input_initializer([3, 3, 1, 1], np.float32, 'W5') - B5 = generate_input_initializer([3], np.float32, 'B5') - relu_node_1 = onnx.helper.make_node('Relu', ['input'], ['X1'], name='Relu1') - conv_node_1 = onnx.helper.make_node('Conv', ['X1', 'W1', 'B1'], ['X2'], name='Conv1') - relu_node_2 = onnx.helper.make_node('Relu', ['X2'], ['X3'], name='Relu2') - conv_node_2 = onnx.helper.make_node('Conv', ['X3', 'W3', 'B3'], ['X4'], name='Conv2') - conv_node_3 = onnx.helper.make_node('Conv', ['X1', 'W5', 'B5'], ['X5'], name='Conv3') - add_node = onnx.helper.make_node('Add', ['X4', 'X5'], ['X6'], name='Add') - graph = helper.make_graph([relu_node_1, conv_node_1, relu_node_2, conv_node_2, conv_node_3, add_node], - 'test_graph_4', [input], [X1_output, X2_output, X3_output, X4_output, X5_output, X6_output]) + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 1, 3]) + X1_output = helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 3, 1, 3]) + X2_output = helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 3, 1, 3]) + X3_output = helper.make_tensor_value_info("X3", TensorProto.FLOAT, [1, 3, 1, 3]) + X4_output = helper.make_tensor_value_info("X4", TensorProto.FLOAT, [1, 3, 1, 3]) + X5_output = helper.make_tensor_value_info("X5", TensorProto.FLOAT, [1, 3, 1, 3]) + X6_output = helper.make_tensor_value_info("X6", TensorProto.FLOAT, [1, 3, 1, 3]) + W1 = generate_input_initializer([3, 3, 1, 1], np.float32, "W1") + B1 = generate_input_initializer([3], np.float32, "B1") + W3 = generate_input_initializer([3, 3, 1, 1], np.float32, "W3") + B3 = generate_input_initializer([3], np.float32, "B3") + W5 = generate_input_initializer([3, 3, 1, 1], np.float32, "W5") + B5 = generate_input_initializer([3], np.float32, "B5") + relu_node_1 = onnx.helper.make_node("Relu", ["input"], ["X1"], name="Relu1") + conv_node_1 = onnx.helper.make_node("Conv", ["X1", "W1", "B1"], ["X2"], name="Conv1") + relu_node_2 = onnx.helper.make_node("Relu", ["X2"], ["X3"], name="Relu2") + conv_node_2 = onnx.helper.make_node("Conv", ["X3", "W3", "B3"], ["X4"], name="Conv2") + conv_node_3 = onnx.helper.make_node("Conv", ["X1", "W5", "B5"], ["X5"], name="Conv3") + add_node = onnx.helper.make_node("Add", ["X4", "X5"], ["X6"], name="Add") + graph = helper.make_graph( + [relu_node_1, conv_node_1, relu_node_2, conv_node_2, conv_node_3, add_node], + "test_graph_4", + [input], + [X1_output, X2_output, X3_output, X4_output, X5_output, X6_output], + ) graph.initializer.add().CopyFrom(W1) graph.initializer.add().CopyFrom(B1) graph.initializer.add().CopyFrom(W3) @@ -227,10 +277,10 @@ def construct_test_compute_range_model(self, test_model_path): onnx.save(model, test_model_path) def test_compute_range(self): - test_model_path = './test_model_4.onnx' + test_model_path = "./test_model_4.onnx" self.construct_test_compute_range_model(test_model_path) - augmented_model_path = './augmented_test_model_4.onnx' + augmented_model_path = "./augmented_test_model_4.onnx" calibrater = MinMaxCalibrater(test_model_path, augmented_model_path=augmented_model_path) data_reader = TestDataReader() calibrater.collect_data(data_reader) @@ -238,9 +288,11 @@ def test_compute_range(self): sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - infer_session = onnxruntime.InferenceSession(test_model_path, - sess_options=sess_options, - providers=['CPUExecutionProvider']) + infer_session = onnxruntime.InferenceSession( + test_model_path, + sess_options=sess_options, + providers=["CPUExecutionProvider"], + ) data_reader.rewind() rmin = np.array([np.inf, np.inf, np.inf, np.inf, np.inf, np.inf], dtype=np.float32) rmax = -1.0 * rmin @@ -249,8 +301,8 @@ def test_compute_range(self): if not input: break output = np.asarray(infer_session.run(None, input)).reshape(6, -1) - rmin=np.minimum(rmin, np.amin(output, axis=1)) - rmax=np.maximum(rmax, np.amax(output, axis=1)) + rmin = np.minimum(rmin, np.amin(output, axis=1)) + rmax = np.maximum(rmax, np.amax(output, axis=1)) min_max_pairs = list(zip(rmin, rmax)) output_names = [infer_session.get_outputs()[i].name for i in range(len(infer_session.get_outputs()))] @@ -259,46 +311,72 @@ def test_compute_range(self): self.assertEqual(output_min_max_dict[output_name], tensors_range[output_name]) def test_augment_graph_with_zero_value_dimension(self): - '''TEST_CONFIG_5''' + """TEST_CONFIG_5""" # Conv # | # Conv # | # Resize - G = helper.make_tensor_value_info('G', TensorProto.FLOAT, [1, 1, 5, 5]) - H = helper.make_tensor_value_info('H', TensorProto.FLOAT, [1, 1, 3, 3]) - J = helper.make_tensor_value_info('J', TensorProto.FLOAT, [1, 1, 3, 3]) - M = helper.make_tensor_value_info('M', TensorProto.FLOAT, [0]) - N = helper.make_tensor_value_info('N', TensorProto.FLOAT, [0]) - O = helper.make_tensor_value_info('O', TensorProto.FLOAT, [1,1,5,5]) + G = helper.make_tensor_value_info("G", TensorProto.FLOAT, [1, 1, 5, 5]) + H = helper.make_tensor_value_info("H", TensorProto.FLOAT, [1, 1, 3, 3]) + J = helper.make_tensor_value_info("J", TensorProto.FLOAT, [1, 1, 3, 3]) + M = helper.make_tensor_value_info("M", TensorProto.FLOAT, [0]) + N = helper.make_tensor_value_info("N", TensorProto.FLOAT, [0]) + O = helper.make_tensor_value_info("O", TensorProto.FLOAT, [1, 1, 5, 5]) # O = helper.make_tensor_value_info('O', TensorProto.FLOAT, None) - conv_node_1 = onnx.helper.make_node('Conv', ['G', 'H'], ['I'], - name='Conv1', - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]) - conv_node_2 = onnx.helper.make_node('Conv', ['I', 'J'], ['K'], - name='Conv2', - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]) - resize_node_1 = onnx.helper.make_node('Resize', ['K', 'M', 'N'], ['O'], - name='Reize1') - graph = helper.make_graph([conv_node_1, conv_node_2, resize_node_1], 'test_graph_5', [G, H, J, M, N], [O]) + conv_node_1 = onnx.helper.make_node( + "Conv", + ["G", "H"], + ["I"], + name="Conv1", + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + ) + conv_node_2 = onnx.helper.make_node( + "Conv", + ["I", "J"], + ["K"], + name="Conv2", + kernel_shape=[3, 3], + pads=[1, 1, 1, 1], + ) + resize_node_1 = onnx.helper.make_node("Resize", ["K", "M", "N"], ["O"], name="Reize1") + graph = helper.make_graph( + [conv_node_1, conv_node_2, resize_node_1], + "test_graph_5", + [G, H, J, M, N], + [O], + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - test_model_path = './test_model_5.onnx' + test_model_path = "./test_model_5.onnx" onnx.save(model, test_model_path) - augmented_model_path = './augmented_test_model_5.onnx' + augmented_model_path = "./augmented_test_model_5.onnx" calibrater = MinMaxCalibrater(test_model_path, [], augmented_model_path) augmented_model = calibrater.get_augment_model() augmented_model_node_names = [node.name for node in augmented_model.graph.node] augmented_model_outputs = [output.name for output in augmented_model.graph.output] - added_node_names = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax', 'O_ReduceMin', 'O_ReduceMax'] - added_outputs = ['I_ReduceMin', 'I_ReduceMax', 'K_ReduceMin', 'K_ReduceMax', 'O_ReduceMin', 'O_ReduceMax'] - # Original 3 nodes + added ReduceMin/Max nodes * 8 + added_node_names = [ + "I_ReduceMin", + "I_ReduceMax", + "K_ReduceMin", + "K_ReduceMax", + "O_ReduceMin", + "O_ReduceMax", + ] + added_outputs = [ + "I_ReduceMin", + "I_ReduceMax", + "K_ReduceMin", + "K_ReduceMax", + "O_ReduceMin", + "O_ReduceMax", + ] + # Original 3 nodes + added ReduceMin/Max nodes * 8 self.assertEqual(len(augmented_model_node_names), 19) - # Original 1 graph output + added outputs * 8 + # Original 1 graph output + added outputs * 8 self.assertEqual(len(augmented_model_outputs), 17) for name in added_node_names: self.assertTrue(name in augmented_model_node_names) @@ -306,6 +384,5 @@ def test_augment_graph_with_zero_value_dimension(self): self.assertTrue(output in augmented_model_outputs) - -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_conv_dynamic.py b/onnxruntime/test/python/quantization/test_conv_dynamic.py index 5958d93ea0004..3bb9275916d83 100644 --- a/onnxruntime/test/python/quantization/test_conv_dynamic.py +++ b/onnxruntime/test/python/quantization/test_conv_dynamic.py @@ -7,18 +7,26 @@ # -------------------------------------------------------------------------- import unittest + +import numpy as np import onnx +from onnx import TensorProto, helper, numpy_helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_type_count, + check_op_type_order, + check_qtype_by_node_type, +) + import onnxruntime -import numpy as np -from onnx import helper, TensorProto, numpy_helper -from onnxruntime.quantization import quantize_dynamic, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order, check_qtype_by_node_type +from onnxruntime.quantization import QuantType, quantize_dynamic def generate_input_initializer(tensor_shape, tensor_dtype, input_name): - ''' + """ Helper function to generate initializers for test inputs - ''' + """ tensor = np.random.normal(0, 0.3, tensor_shape).astype(tensor_dtype) init = numpy_helper.from_array(tensor, input_name) return init @@ -38,38 +46,52 @@ def construct_model(self, model_path): # | # (output) initializers = [] - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 2, 8, 8]) - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 2, 8, 8]) + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 2, 8, 8]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 2, 8, 8]) - initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, 'W1')) - initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, 'W2')) - initializers.append(generate_input_initializer([2], np.float32, 'B')) - conv_node_1 = onnx.helper.make_node('Conv', ['input', 'W1', 'B'], ['Conv1_O'], name='Conv1') - conv_node_2 = onnx.helper.make_node('Conv', ['input', 'W2', 'B'], ['Conv2_O'], name='Conv2') - relu_node = onnx.helper.make_node('Relu', ['Conv1_O'], ['Relu_O'], name='Relu') - add_node = onnx.helper.make_node('Add', ['Relu_O', 'Conv2_O'], ['output'], name='Add') - graph = helper.make_graph([conv_node_1, relu_node, conv_node_2, add_node], - 'onnx_model_test', [input], [output], initializer=initializers) + initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, "W1")) + initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, "W2")) + initializers.append(generate_input_initializer([2], np.float32, "B")) + conv_node_1 = onnx.helper.make_node("Conv", ["input", "W1", "B"], ["Conv1_O"], name="Conv1") + conv_node_2 = onnx.helper.make_node("Conv", ["input", "W2", "B"], ["Conv2_O"], name="Conv2") + relu_node = onnx.helper.make_node("Relu", ["Conv1_O"], ["Relu_O"], name="Relu") + add_node = onnx.helper.make_node("Add", ["Relu_O", "Conv2_O"], ["output"], name="Add") + graph = helper.make_graph( + [conv_node_1, relu_node, conv_node_2, add_node], + "onnx_model_test", + [input], + [output], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) onnx.save(model, model_path) def dynamic_quant_conv_test(self, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'conv_bias.fp32.onnx' + model_fp32_path = "conv_bias.fp32.onnx" self.construct_model(model_fp32_path) activation_proto_qtype = TensorProto.UINT8 - activation_type_str = 'u8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_int8_path = 'conv_bias.quant.{}{}.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_int8_path = "conv_bias.quant.{}{}.onnx".format(activation_type_str, weight_type_str) - quantize_dynamic(model_fp32_path, model_int8_path, - weight_type=weight_type, extra_options=extra_options) - quant_nodes = {'ConvInteger': 2} + quantize_dynamic( + model_fp32_path, + model_int8_path, + weight_type=weight_type, + extra_options=extra_options, + ) + quant_nodes = {"ConvInteger": 2} check_op_type_count(self, model_int8_path, **quant_nodes) - qnode_io_qtypes = {'ConvInteger': [['i', 2, activation_proto_qtype]]} + qnode_io_qtypes = {"ConvInteger": [["i", 2, activation_proto_qtype]]} check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes) - check_model_correctness(self, model_fp32_path, model_int8_path, {'input': np.random.rand(4, 2, 8, 8).astype(np.float32)}) + check_model_correctness( + self, + model_fp32_path, + model_int8_path, + {"input": np.random.rand(4, 2, 8, 8).astype(np.float32)}, + ) def test_quant_conv(self): self.dynamic_quant_conv_test(QuantType.QUInt8, extra_options={}) @@ -79,5 +101,5 @@ def test_quant_conv(self): # self.dynamic_quant_conv_test(QuantType.QInt8, extra_options={'ActivationSymmetric': True}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_onnx_model.py b/onnxruntime/test/python/quantization/test_onnx_model.py index b1d1736639979..fc29810e9b97d 100644 --- a/onnxruntime/test/python/quantization/test_onnx_model.py +++ b/onnxruntime/test/python/quantization/test_onnx_model.py @@ -7,21 +7,24 @@ # -------------------------------------------------------------------------- import unittest + +import numpy as np import onnx +from onnx import TensorProto, helper, numpy_helper +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order + import onnxruntime -import numpy as np -from onnx import helper, TensorProto, numpy_helper from onnxruntime.quantization.onnx_model import ONNXModel -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order def generate_input_initializer(tensor_shape, tensor_dtype, input_name): - ''' - Helper function to generate initializers for test inputs - ''' - tensor = np.random.normal(0, 0.3, tensor_shape).astype(tensor_dtype) - init = numpy_helper.from_array(tensor, input_name) - return init + """ + Helper function to generate initializers for test inputs + """ + tensor = np.random.normal(0, 0.3, tensor_shape).astype(tensor_dtype) + init = numpy_helper.from_array(tensor, input_name) + return init + class TestONNXModel(unittest.TestCase): def construct_model(self, model_path): @@ -38,30 +41,36 @@ def construct_model(self, model_path): # | # (output) initializers = [] - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 8, 12]) - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 2, 8, 8]) + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8, 12]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 2, 8, 8]) # make GRU - initializers.append(generate_input_initializer([2, 24, 12], np.float32, 'W_GRU')) - initializers.append(generate_input_initializer([2, 24, 8], np.float32, 'R_GRU')) - initializers.append(generate_input_initializer([2, 8, 8], np.float32, 'H_GRU')) + initializers.append(generate_input_initializer([2, 24, 12], np.float32, "W_GRU")) + initializers.append(generate_input_initializer([2, 24, 8], np.float32, "R_GRU")) + initializers.append(generate_input_initializer([2, 8, 8], np.float32, "H_GRU")) gru_node = onnx.helper.make_node( - 'GRU', - ['input', 'W_GRU', 'R_GRU', '', '', 'H_GRU'], - ['GRU_O'], - hidden_size = 8, - direction = 'bidirectional') - - initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, 'W1')) - initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, 'W2')) - initializers.append(generate_input_initializer([2], np.float32, 'B1')) - initializers.append(generate_input_initializer([2], np.float32, 'B2')) - conv_node_1 = onnx.helper.make_node('Conv', ['GRU_O', 'W1', 'B1'], ['Conv1_O'], name='Conv1') - conv_node_2 = onnx.helper.make_node('Conv', ['GRU_O', 'W2', 'B2'], ['Conv2_O'], name='Conv2') - relu_node = onnx.helper.make_node('Relu', ['Conv1_O'], ['Relu_O'], name='Relu') - add_node = onnx.helper.make_node('Add', ['Relu_O', 'Conv2_O'], ['output'], name='Add') - graph = helper.make_graph([conv_node_1, relu_node, conv_node_2, gru_node, add_node], - 'onnx_model_test', [input], [output], initializer=initializers) + "GRU", + ["input", "W_GRU", "R_GRU", "", "", "H_GRU"], + ["GRU_O"], + hidden_size=8, + direction="bidirectional", + ) + + initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, "W1")) + initializers.append(generate_input_initializer([2, 2, 1, 1], np.float32, "W2")) + initializers.append(generate_input_initializer([2], np.float32, "B1")) + initializers.append(generate_input_initializer([2], np.float32, "B2")) + conv_node_1 = onnx.helper.make_node("Conv", ["GRU_O", "W1", "B1"], ["Conv1_O"], name="Conv1") + conv_node_2 = onnx.helper.make_node("Conv", ["GRU_O", "W2", "B2"], ["Conv2_O"], name="Conv2") + relu_node = onnx.helper.make_node("Relu", ["Conv1_O"], ["Relu_O"], name="Relu") + add_node = onnx.helper.make_node("Add", ["Relu_O", "Conv2_O"], ["output"], name="Add") + graph = helper.make_graph( + [conv_node_1, relu_node, conv_node_2, gru_node, add_node], + "onnx_model_test", + [input], + [output], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) onnx.save(model, model_path) @@ -76,32 +85,38 @@ def construct_model_Constant(self, model_path): # (output) initializers = [] - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [4, 8, 12]) - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [4, 8, 12]) + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [4, 8, 12]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [4, 8, 12]) # make nodes - constant_node = onnx.helper.make_node('Constant', [], ['const_output'], value_float=42.0) - add_node = onnx.helper.make_node('Add', ['input', 'const_output'], ['output'], name='Add') - graph = helper.make_graph([add_node, constant_node], - 'onnx_model_test', [input], [output], initializer=initializers) + constant_node = onnx.helper.make_node("Constant", [], ["const_output"], value_float=42.0) + add_node = onnx.helper.make_node("Add", ["input", "const_output"], ["output"], name="Add") + graph = helper.make_graph( + [add_node, constant_node], + "onnx_model_test", + [input], + [output], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) onnx.save(model, model_path) def test_topo_sort(self): - test_model_path = 'onnx_model_topo_sort.onnx' + test_model_path = "onnx_model_topo_sort.onnx" self.construct_model(test_model_path) onnx_model = ONNXModel(onnx.load(test_model_path)) - check_op_type_order(self, onnx_model.model, ['Conv', 'Relu', 'Conv', 'GRU', 'Add']) + check_op_type_order(self, onnx_model.model, ["Conv", "Relu", "Conv", "GRU", "Add"]) onnx_model.topological_sort() - check_op_type_order(self, onnx_model.model, ['GRU', 'Conv', 'Conv', 'Relu', 'Add']) + check_op_type_order(self, onnx_model.model, ["GRU", "Conv", "Conv", "Relu", "Add"]) def test_topo_sort_constant(self): - test_model_path = 'onnx_model_topo_sort_constant.onnx' + test_model_path = "onnx_model_topo_sort_constant.onnx" self.construct_model_Constant(test_model_path) onnx_model = ONNXModel(onnx.load(test_model_path)) - check_op_type_order(self, onnx_model.model, ['Add', 'Constant']) + check_op_type_order(self, onnx_model.model, ["Add", "Constant"]) onnx_model.topological_sort() - check_op_type_order(self, onnx_model.model, ['Constant', 'Add']) + check_op_type_order(self, onnx_model.model, ["Constant", "Add"]) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_argmax.py b/onnxruntime/test/python/quantization/test_op_argmax.py index d5f6e3de410fd..e73bb9093d99c 100644 --- a/onnxruntime/test/python/quantization/test_op_argmax.py +++ b/onnxruntime/test/python/quantization/test_op_argmax.py @@ -7,11 +7,19 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes, check_qtype_by_node_type +import onnx +from onnx import TensorProto, helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_nodes, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestOpArgMax(unittest.TestCase): @@ -33,87 +41,145 @@ def construct_model_argmax(self, output_model_path, input_shape, output_shape): # ArgMax # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] # make Conv node - conv_weight_name = 'conv_weight' + conv_weight_name = "conv_weight" conv_weight_arr = np.random.randint(-1, 2, [32, 256, 1, 1]).astype(np.float32) conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name=conv_weight_name) - conv_output_name = 'conv_output' + conv_output_name = "conv_output" conv_inputs = [input_name, conv_weight_name] conv_outputs = [conv_output_name] - conv_name = 'conv_node' - conv_node = onnx.helper.make_node('Conv', conv_inputs, conv_outputs, dilations=[1, 1], kernel_shape=[1, 1], - pads=[0, 0, 0, 0], strides=[1, 1], name=conv_name) + conv_name = "conv_node" + conv_node = onnx.helper.make_node( + "Conv", + conv_inputs, + conv_outputs, + dilations=[1, 1], + kernel_shape=[1, 1], + pads=[0, 0, 0, 0], + strides=[1, 1], + name=conv_name, + ) # make ArgMax node argmax_inputs = [conv_output_name] argmax_outputs = [output_name] - argmax_name = 'argmax_node' - argmax_node = onnx.helper.make_node('ArgMax', argmax_inputs, argmax_outputs, axis=3, keepdims=0, name=argmax_name) + argmax_name = "argmax_node" + argmax_node = onnx.helper.make_node( + "ArgMax", + argmax_inputs, + argmax_outputs, + axis=3, + keepdims=0, + name=argmax_name, + ) initializers = [conv_weight_initializer] # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.INT64, output_shape) - graph_name = 'ArgMax_Quant_Test' - graph = helper.make_graph([conv_node, argmax_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "ArgMax_Quant_Test" + graph = helper.make_graph( + [conv_node, argmax_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_argmax_test(self, activation_type, weight_type, extra_options = {}): + def quantize_argmax_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'argmax_fp32.onnx' + model_fp32_path = "argmax_fp32.onnx" - self.construct_model_argmax(model_fp32_path, - [1, 256, 128, 128], - [1, 32, 128]) + self.construct_model_argmax(model_fp32_path, [1, 256, 128, 128], [1, 32, 128]) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_uint8_path = 'argmax_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_uint8_qdq_path = 'argmax_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str) - model_uint8_qdq_trt_path = 'argmax_{}{}_qdq_trt.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_uint8_path = "argmax_{}{}.onnx".format(activation_type_str, weight_type_str) + model_uint8_qdq_path = "argmax_{}{}_qdq.onnx".format(activation_type_str, weight_type_str) + model_uint8_qdq_trt_path = "argmax_{}{}_qdq_trt.onnx".format(activation_type_str, weight_type_str) # Verify QOperator mode - data_reader = self.input_feeds(1, {'input': [1, 256, 128, 128]}) - quantize_static(model_fp32_path, model_uint8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) + data_reader = self.input_feeds(1, {"input": [1, 256, 128, 128]}) + quantize_static( + model_fp32_path, + model_uint8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # make sure argmax become xint8 operator, its input name could tell that - check_op_nodes(self, model_uint8_path, lambda node: not(node.name == "argmax_node" and node.input[0] == 'conv_output')) - qnode_counts = {'QuantizeLinear': 1, 'QLinearConv': 1, 'ArgMax': 1} + check_op_nodes( + self, + model_uint8_path, + lambda node: not (node.name == "argmax_node" and node.input[0] == "conv_output"), + ) + qnode_counts = {"QuantizeLinear": 1, "QLinearConv": 1, "ArgMax": 1} check_op_type_count(self, model_uint8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) - qdqnode_counts = {'QuantizeLinear': 2, 'DequantizeLinear': 3, 'ArgMax': 1} + quantize_static( + model_fp32_path, + model_uint8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = {"QuantizeLinear": 2, "DequantizeLinear": 3, "ArgMax": 1} check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) # Verify QDQ mode for TensorRT data_reader.rewind() - quantize_static(model_fp32_path, model_uint8_qdq_trt_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options, - op_types_to_quantize=['ArgMax']) - qdqnode_counts = {'QuantizeLinear': 1, 'DequantizeLinear': 1, 'ArgMax': 1} + quantize_static( + model_fp32_path, + model_uint8_qdq_trt_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + op_types_to_quantize=["ArgMax"], + ) + qdqnode_counts = {"QuantizeLinear": 1, "DequantizeLinear": 1, "ArgMax": 1} check_op_type_count(self, model_uint8_qdq_trt_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_qdq_trt_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_qdq_trt_path, data_reader.get_next()) @@ -122,7 +188,12 @@ def test_quantize_argmax(self): self.quantize_argmax_test(QuantType.QUInt8, QuantType.QUInt8) def test_quantize_argmax_s8s8(self): - self.quantize_argmax_test(QuantType.QInt8, QuantType.QInt8, extra_options = {'ActivationSymmetric' : True}) + self.quantize_argmax_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_concat.py b/onnxruntime/test/python/quantization/test_op_concat.py index 1025e4740fa22..515b8ad7db6bb 100644 --- a/onnxruntime/test/python/quantization/test_op_concat.py +++ b/onnxruntime/test/python/quantization/test_op_concat.py @@ -5,10 +5,17 @@ # -------------------------------------------------------------------------- import unittest + import numpy as np -from onnx import helper, TensorProto, numpy_helper, save -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import InputFeedsNegOneZeroOne, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from onnx import TensorProto, helper, numpy_helper, save +from op_test_utils import ( + InputFeedsNegOneZeroOne, + check_model_correctness, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestONNXModel(unittest.TestCase): @@ -28,68 +35,129 @@ def construct_model(self, model_path): # | # (output) initializers = [] - input = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1, 3, 15, 15]) - output = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 13, 13, 13]) + input = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 15, 15]) + output = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 13, 13, 13]) # Conv1 output [1, 2, 13, 13] conv1_weight_initializer = numpy_helper.from_array( - np.random.randint(-1, 2, [2, 3, 3, 3]).astype(np.float32), name='conv1_weight') - conv1_node = helper.make_node('Conv', ['input', 'conv1_weight'], ['conv1_output'], name='conv1_node') + np.random.randint(-1, 2, [2, 3, 3, 3]).astype(np.float32), + name="conv1_weight", + ) + conv1_node = helper.make_node("Conv", ["input", "conv1_weight"], ["conv1_output"], name="conv1_node") # Conv2 output [1, 5, 13, 13] conv2_weight_initializer = numpy_helper.from_array( - np.random.randint(-1, 2, [5, 3, 3, 3]).astype(np.float32), name='conv2_weight') - conv2_node = helper.make_node('Conv', ['input', 'conv2_weight'], ['conv2_output'], name='conv2_node') + np.random.randint(-1, 2, [5, 3, 3, 3]).astype(np.float32), + name="conv2_weight", + ) + conv2_node = helper.make_node("Conv", ["input", "conv2_weight"], ["conv2_output"], name="conv2_node") # Conv3 output [1, 6, 13, 13] conv3_weight_initializer = numpy_helper.from_array( - np.random.randint(-1, 2, [6, 3, 3, 3]).astype(np.float32), name='conv3_weight') - conv3_node = helper.make_node('Conv', ['input', 'conv3_weight'], ['conv3_output'], name='conv3_node') - - concat_node = helper.make_node('Concat', ['conv1_output', 'conv2_output', 'conv3_output'], [ - 'concat_output'], name='concat_node', axis=1) - - identity_node = helper.make_node('Identity', ['concat_output'], ['output'], name='identity_node') - - initializers = [conv1_weight_initializer, conv2_weight_initializer, conv3_weight_initializer] - graph = helper.make_graph([conv1_node, conv2_node, conv3_node, concat_node, identity_node], - 'qlinear_concat_op_test', [input], [output], initializer=initializers) + np.random.randint(-1, 2, [6, 3, 3, 3]).astype(np.float32), + name="conv3_weight", + ) + conv3_node = helper.make_node("Conv", ["input", "conv3_weight"], ["conv3_output"], name="conv3_node") + + concat_node = helper.make_node( + "Concat", + ["conv1_output", "conv2_output", "conv3_output"], + ["concat_output"], + name="concat_node", + axis=1, + ) + + identity_node = helper.make_node("Identity", ["concat_output"], ["output"], name="identity_node") + + initializers = [ + conv1_weight_initializer, + conv2_weight_initializer, + conv3_weight_initializer, + ] + graph = helper.make_graph( + [conv1_node, conv2_node, conv3_node, concat_node, identity_node], + "qlinear_concat_op_test", + [input], + [output], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) save(model, model_path) def quantize_concat_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'concat_fp32.onnx' + model_fp32_path = "concat_fp32.onnx" self.construct_model(model_fp32_path) - data_reader = InputFeedsNegOneZeroOne(1, {'input': [1, 3, 15, 15]}) + data_reader = InputFeedsNegOneZeroOne(1, {"input": [1, 3, 15, 15]}) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_q8_path = 'concat_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_q8_qdq_path = 'concat_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_q8_path = "concat_{}{}.onnx".format(activation_type_str, weight_type_str) + model_q8_qdq_path = "concat_{}{}_qdq.onnx".format(activation_type_str, weight_type_str) # Verify QOperator mode data_reader.rewind() - quantize_static(model_fp32_path, model_q8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - - qnode_counts = {'QLinearConv': 3, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'QLinearConcat': 1} + quantize_static( + model_fp32_path, + model_q8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + + qnode_counts = { + "QLinearConv": 3, + "QuantizeLinear": 1, + "DequantizeLinear": 1, + "QLinearConcat": 1, + } check_op_type_count(self, model_q8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'QLinearConcat': [['i', 1, activation_proto_qtype], [ - 'i', 4, activation_proto_qtype], ['i', 7, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update( + { + "QLinearConcat": [ + ["i", 1, activation_proto_qtype], + ["i", 4, activation_proto_qtype], + ["i", 7, activation_proto_qtype], + ] + } + ) check_qtype_by_node_type(self, model_q8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next()) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_q8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - qdqnode_counts = {'Conv': 3, 'QuantizeLinear': 5, 'DequantizeLinear': 8, 'Concat': 1} + quantize_static( + model_fp32_path, + model_q8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = { + "Conv": 3, + "QuantizeLinear": 5, + "DequantizeLinear": 8, + "Concat": 1, + } check_op_type_count(self, model_q8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next()) @@ -98,8 +166,12 @@ def test_quantize_concat(self): self.quantize_concat_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={}) def test_quantize_concat_s8s8(self): - self.quantize_concat_test(QuantType.QInt8, QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.quantize_concat_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_embed_layernorm.py b/onnxruntime/test/python/quantization/test_op_embed_layernorm.py index 07861b80fd183..297d1c1af6c06 100644 --- a/onnxruntime/test/python/quantization/test_op_embed_layernorm.py +++ b/onnxruntime/test/python/quantization/test_op_embed_layernorm.py @@ -7,12 +7,14 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_dynamic +import onnx +from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count +from onnxruntime.quantization import quantize_dynamic + class TestOpEmbedLayerNormalization(unittest.TestCase): def input_feeds_int32(self, n, name2shape): @@ -35,61 +37,73 @@ def construct_model(self, batch, hidden_size, sequence_length, model_path): # Inputs to EmbedLayerNormalizationNode input_ids_shape = [batch, sequence_length] - input_ids_tensor = helper.make_tensor_value_info('input_ids', TensorProto.INT32, input_ids_shape) + input_ids_tensor = helper.make_tensor_value_info("input_ids", TensorProto.INT32, input_ids_shape) segment_ids_shape = [batch, sequence_length] - segment_ids_tensor = helper.make_tensor_value_info('segment_ids', TensorProto.INT32, segment_ids_shape) + segment_ids_tensor = helper.make_tensor_value_info("segment_ids", TensorProto.INT32, segment_ids_shape) # EmbedLayerNormalization Node Constants and Weights: word_embed_shape = [32, hidden_size] - word_embed_weights = np.random.random_sample(word_embed_shape).astype(dtype='float32') - word_embed_initializer = onnx.numpy_helper.from_array(word_embed_weights, name='word_embed') + word_embed_weights = np.random.random_sample(word_embed_shape).astype(dtype="float32") + word_embed_initializer = onnx.numpy_helper.from_array(word_embed_weights, name="word_embed") pos_embed_shape = [16, hidden_size] - pos_embed_weights = np.random.random_sample(pos_embed_shape).astype(dtype='float32') - pos_embed_initializer = onnx.numpy_helper.from_array(pos_embed_weights, name='pos_embed') + pos_embed_weights = np.random.random_sample(pos_embed_shape).astype(dtype="float32") + pos_embed_initializer = onnx.numpy_helper.from_array(pos_embed_weights, name="pos_embed") seg_embed_shape = [2, hidden_size] - seg_embed_weights = np.random.random_sample(seg_embed_shape).astype(dtype='float32') - seg_embed_initializer = onnx.numpy_helper.from_array(seg_embed_weights, name='seg_embed') + seg_embed_weights = np.random.random_sample(seg_embed_shape).astype(dtype="float32") + seg_embed_initializer = onnx.numpy_helper.from_array(seg_embed_weights, name="seg_embed") gamma_shape = [hidden_size] - gamma = np.random.random_sample(gamma_shape).astype(dtype='float32') - gamma_initializer = onnx.numpy_helper.from_array(gamma, name='gamma') + gamma = np.random.random_sample(gamma_shape).astype(dtype="float32") + gamma_initializer = onnx.numpy_helper.from_array(gamma, name="gamma") beta_shape = [hidden_size] - beta = np.random.random_sample(beta_shape).astype(dtype='float32') - beta_initializer = onnx.numpy_helper.from_array(beta, name='beta') + beta = np.random.random_sample(beta_shape).astype(dtype="float32") + beta_initializer = onnx.numpy_helper.from_array(beta, name="beta") # EmbedLayerNormalization Outputs: layernorm_out_shape = [batch, sequence_length, hidden_size] - layernorm_out_tensor = helper.make_tensor_value_info('layernorm_out', TensorProto.FLOAT, layernorm_out_shape) + layernorm_out_tensor = helper.make_tensor_value_info("layernorm_out", TensorProto.FLOAT, layernorm_out_shape) mask_index_out_shape = [batch] - mask_index_out_tensor = helper.make_tensor_value_info('mask_index_out', TensorProto.INT32, mask_index_out_shape) + mask_index_out_tensor = helper.make_tensor_value_info("mask_index_out", TensorProto.INT32, mask_index_out_shape) # EmbedLayerNormalization Node: embed_layer_norm_inputs = [ - 'input_ids', 'segment_ids', 'word_embed', 'pos_embed', 'seg_embed', 'gamma', 'beta' + "input_ids", + "segment_ids", + "word_embed", + "pos_embed", + "seg_embed", + "gamma", + "beta", ] - embed_layer_norm_outputs = ['layernorm_out', 'mask_index_out'] - embed_layer_norm_node = helper.make_node('EmbedLayerNormalization', - embed_layer_norm_inputs, - embed_layer_norm_outputs, - domain='com.microsoft') + embed_layer_norm_outputs = ["layernorm_out", "mask_index_out"] + embed_layer_norm_node = helper.make_node( + "EmbedLayerNormalization", + embed_layer_norm_inputs, + embed_layer_norm_outputs, + domain="com.microsoft", + ) # Construct the Graph and Model: nodes = [embed_layer_norm_node] - graph_name = 'embed_layernorm_graph' + graph_name = "embed_layernorm_graph" inputs = [input_ids_tensor, segment_ids_tensor] outputs = [layernorm_out_tensor, mask_index_out_tensor] initializers = [ - word_embed_initializer, pos_embed_initializer, seg_embed_initializer, gamma_initializer, beta_initializer + word_embed_initializer, + pos_embed_initializer, + seg_embed_initializer, + gamma_initializer, + beta_initializer, ] graph = helper.make_graph(nodes, graph_name, inputs, outputs, initializer=initializers) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, model_path) def test_quantize_batch_size_1(self): @@ -97,20 +111,23 @@ def test_quantize_batch_size_1(self): hidden_size = 4 sequence_length = 4 - model_f32_path = 'test_embed_layer_norm_unit_test_batch1.onnx' - model_uint8_path = 'test_embed_layer_norm_unit_test_batch1_uint8.onnx' + model_f32_path = "test_embed_layer_norm_unit_test_batch1.onnx" + model_uint8_path = "test_embed_layer_norm_unit_test_batch1_uint8.onnx" self.construct_model(batch, hidden_size, sequence_length, model_f32_path) - data_reader = self.input_feeds_int32(1, { - 'input_ids': [batch, sequence_length], - 'segment_ids': [batch, sequence_length] - }) + data_reader = self.input_feeds_int32( + 1, + { + "input_ids": [batch, sequence_length], + "segment_ids": [batch, sequence_length], + }, + ) quantize_dynamic(model_f32_path, model_uint8_path) # Quantization should not have any DequantizeLinear nodes: - qnode_counts = {'DequantizeLinear': 0, 'QEmbedLayerNormalization': 1} + qnode_counts = {"DequantizeLinear": 0, "QEmbedLayerNormalization": 1} check_op_type_count(self, model_uint8_path, **qnode_counts) data_reader.rewind() @@ -121,25 +138,28 @@ def test_quantize_batch_size_2(self): hidden_size = 4 sequence_length = 4 - model_f32_path = 'test_embed_layer_norm_unit_test_batch2.onnx' - model_uint8_path = 'test_embed_layer_norm_unit_test_batch2_uint8.onnx' + model_f32_path = "test_embed_layer_norm_unit_test_batch2.onnx" + model_uint8_path = "test_embed_layer_norm_unit_test_batch2_uint8.onnx" self.construct_model(batch, hidden_size, sequence_length, model_f32_path) - data_reader = self.input_feeds_int32(1, { - 'input_ids': [batch, sequence_length], - 'segment_ids': [batch, sequence_length] - }) + data_reader = self.input_feeds_int32( + 1, + { + "input_ids": [batch, sequence_length], + "segment_ids": [batch, sequence_length], + }, + ) quantize_dynamic(model_f32_path, model_uint8_path) # Quantization should not have any DequantizeLinear nodes: - qnode_counts = {'DequantizeLinear': 0, 'QEmbedLayerNormalization': 1} + qnode_counts = {"DequantizeLinear": 0, "QEmbedLayerNormalization": 1} check_op_type_count(self, model_uint8_path, **qnode_counts) data_reader.rewind() check_model_correctness(self, model_f32_path, model_uint8_path, data_reader.get_next()) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_gavgpool.py b/onnxruntime/test/python/quantization/test_op_gavgpool.py index 198fca361d0ff..a34c52f912ced 100644 --- a/onnxruntime/test/python/quantization/test_op_gavgpool.py +++ b/onnxruntime/test/python/quantization/test_op_gavgpool.py @@ -7,12 +7,14 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, quantize_dynamic, QuantType, QuantFormat +import onnx +from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static + class TestOpGlobalAveragePool(unittest.TestCase): def input_feeds(self, n, name2shape): @@ -37,37 +39,42 @@ def construct_model_gavgpool(self, output_model_path, input_shape, weight_shape, # GlobalAveragePool # | # (output) - input_name = 'input' - expand_input = 'expand_input' - conv_input = 'conv_input' - gavgpool_input_2nd = 'gavgpool_input' - output_name = 'output' + input_name = "input" + expand_input = "expand_input" + conv_input = "conv_input" + gavgpool_input_2nd = "gavgpool_input" + output_name = "output" initializers = [] # make 1st GlobalAveragePool node - gavgpool_node_1 = onnx.helper.make_node('GlobalAveragePool', [input_name], [expand_input]) + gavgpool_node_1 = onnx.helper.make_node("GlobalAveragePool", [input_name], [expand_input]) # make Expand node - expand_shape_name = 'expand_shape' + expand_shape_name = "expand_shape" initializers.append(onnx.numpy_helper.from_array(np.array(input_shape, dtype=np.int64), name=expand_shape_name)) - expand_node = onnx.helper.make_node('Expand', [expand_input, expand_shape_name], [conv_input]) + expand_node = onnx.helper.make_node("Expand", [expand_input, expand_shape_name], [conv_input]) # make Conv node - weight_name = 'conv_weight' - conv_name = 'conv_node' + weight_name = "conv_weight" + conv_name = "conv_node" conv_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(conv_weight_data, name=weight_name)) - conv_node = onnx.helper.make_node('Conv', [conv_input, weight_name], [gavgpool_input_2nd], name=conv_name) + conv_node = onnx.helper.make_node("Conv", [conv_input, weight_name], [gavgpool_input_2nd], name=conv_name) # make 1st GlobalAveragePool node - gavgpool_node_2 = onnx.helper.make_node('GlobalAveragePool', [gavgpool_input_2nd], [output_name]) + gavgpool_node_2 = onnx.helper.make_node("GlobalAveragePool", [gavgpool_input_2nd], [output_name]) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) - graph_name = 'GAveragePool_test' - graph = helper.make_graph([gavgpool_node_1, expand_node, conv_node, gavgpool_node_2], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "GAveragePool_test" + graph = helper.make_graph( + [gavgpool_node_1, expand_node, conv_node, gavgpool_node_2], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = 7 # use stable onnx ir version @@ -75,27 +82,48 @@ def construct_model_gavgpool(self, output_model_path, input_shape, weight_shape, def quantize_gavgpool_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'gavg_pool_fp32.onnx' - data_reader = self.input_feeds(1, {'input': [1, 8, 33, 33]}) - self.construct_model_gavgpool(model_fp32_path, - [1, 8, 33, 33], - [16, 8, 3, 3], - [1, 16, 1, 1]) + model_fp32_path = "gavg_pool_fp32.onnx" + data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) + self.construct_model_gavgpool(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 1, 1]) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_q8_path = 'gavg_pool_{}{}.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_q8_path = "gavg_pool_{}{}.onnx".format(activation_type_str, weight_type_str) data_reader.rewind() - quantize_static(model_fp32_path, model_q8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - - quant_nodes = {'QLinearConv': 1, 'GlobalAveragePool': 1, 'QLinearGlobalAveragePool': 1, - 'QuantizeLinear': 1, 'DequantizeLinear': 1} + quantize_static( + model_fp32_path, + model_q8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + + quant_nodes = { + "QLinearConv": 1, + "GlobalAveragePool": 1, + "QLinearGlobalAveragePool": 1, + "QuantizeLinear": 1, + "DequantizeLinear": 1, + } check_op_type_count(self, model_q8_path, **quant_nodes) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'QLinearGlobalAveragePool': [['i', 2, activation_proto_qtype], ['i', 4, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update( + { + "QLinearGlobalAveragePool": [ + ["i", 2, activation_proto_qtype], + ["i", 4, activation_proto_qtype], + ] + } + ) check_qtype_by_node_type(self, model_q8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next()) @@ -104,8 +132,12 @@ def test_quantize_gavgpool(self): self.quantize_gavgpool_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={}) def test_quantize_gavgpool_s8s8(self): - self.quantize_gavgpool_test(QuantType.QInt8, QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.quantize_gavgpool_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_gemm.py b/onnxruntime/test/python/quantization/test_op_gemm.py index b969f437db7f5..692e4ebe9fb6c 100644 --- a/onnxruntime/test/python/quantization/test_op_gemm.py +++ b/onnxruntime/test/python/quantization/test_op_gemm.py @@ -7,12 +7,14 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, quantize_dynamic, QuantFormat, QuantType +import onnx +from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static + class TestOpGEMM(unittest.TestCase): def input_feeds(self, n, name2shape): @@ -35,8 +37,8 @@ def construct_model_gemm(self, output_model_path): # GEMM # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] def make_gemm(input_name, weight_shape, weight_name, bias_shape, bias_name, output_name): @@ -46,30 +48,57 @@ def make_gemm(input_name, weight_shape, weight_name, bias_shape, bias_name, outp bias_data = np.random.normal(0, 0.1, bias_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(bias_data, name=bias_name)) - return onnx.helper.make_node('Gemm', [input_name, weight_name, bias_name], [output_name], alpha=1.0, beta=1.0, transB=1) + return onnx.helper.make_node( + "Gemm", + [input_name, weight_name, bias_name], + [output_name], + alpha=1.0, + beta=1.0, + transB=1, + ) + # make gemm1 node gemm1_output_name = "gemm1_output" - gemm1_node = make_gemm(input_name, [100, 10], 'linear1.weight', [100], 'linear1.bias', gemm1_output_name) + gemm1_node = make_gemm( + input_name, + [100, 10], + "linear1.weight", + [100], + "linear1.bias", + gemm1_output_name, + ) # make Clip - clip_min_name = 'clip_min' - clip_max_name = 'clip_max' - clip_output_name = 'clip_output' + clip_min_name = "clip_min" + clip_max_name = "clip_max" + clip_output_name = "clip_output" clip_inputs = [gemm1_output_name, clip_min_name, clip_max_name] clip_outputs = [clip_output_name] initializers.append(onnx.numpy_helper.from_array(np.array(-1.0, dtype=np.float32), name=clip_min_name)) initializers.append(onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), name=clip_max_name)) - clip_node = onnx.helper.make_node('Clip', clip_inputs, clip_outputs) + clip_node = onnx.helper.make_node("Clip", clip_inputs, clip_outputs) # make gemm2 node - gemm2_node = make_gemm(clip_output_name, [10, 100], 'linear2.weight', [10], 'linear2.bias', output_name) + gemm2_node = make_gemm( + clip_output_name, + [10, 100], + "linear2.weight", + [10], + "linear2.bias", + output_name, + ) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, 10]) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, 10]) - graph_name = 'gemm_test' - graph = helper.make_graph([gemm1_node, clip_node, gemm2_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "gemm_test" + graph = helper.make_graph( + [gemm1_node, clip_node, gemm2_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = 7 # use stable onnx ir version @@ -83,8 +112,8 @@ def construct_model_attention_and_matmul(self, output_model_path): # MatMul # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] def make_attention_node(input_name, weight_shape, weight_name, bias_shape, bias_name, output_name): @@ -94,116 +123,213 @@ def make_attention_node(input_name, weight_shape, weight_name, bias_shape, bias_ bias_data = np.random.normal(0, 0.1, bias_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(bias_data, name=bias_name)) - return onnx.helper.make_node('Attention', [input_name, weight_name, bias_name], [output_name]) + return onnx.helper.make_node("Attention", [input_name, weight_name, bias_name], [output_name]) def make_matmul_node(input_name, weight_shape, weight_name, output_name): weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) - return onnx.helper.make_node('MatMul', [input_name, weight_name], [output_name]) + return onnx.helper.make_node("MatMul", [input_name, weight_name], [output_name]) + # make attention node attention_output_name = "attention_output" - attention_node = make_attention_node(input_name, [10, 30], 'qkv.weight', [30], 'qkv.bias', attention_output_name) + attention_node = make_attention_node( + input_name, [10, 30], "qkv.weight", [30], "qkv.bias", attention_output_name + ) attention_node.domain = "com.microsoft" attention_node.attribute.extend([helper.make_attribute("num_heads", 5)]) # make matmul node - matmul_node = make_matmul_node(attention_output_name, [10, 10], 'matmul.weight', output_name) + matmul_node = make_matmul_node(attention_output_name, [10, 10], "matmul.weight", output_name) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [1, -1, 10]) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [1, -1, 10]) - graph_name = 'attention_test' - graph = helper.make_graph([attention_node, matmul_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "attention_test" + graph = helper.make_graph( + [attention_node, matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = onnx.IR_VERSION onnx.save(model, output_model_path) - def static_quant_test(self, model_fp32_path, data_reader, activation_type, weight_type, extra_options={}): + def static_quant_test( + self, + model_fp32_path, + data_reader, + activation_type, + weight_type, + extra_options={}, + ): activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_int8_path = 'gemm_fp32.quant_{}{}.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_int8_path = "gemm_fp32.quant_{}{}.onnx".format(activation_type_str, weight_type_str) data_reader.rewind() - quantize_static(model_fp32_path, model_int8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - quant_nodes = {'QGemm': 2, 'QuantizeLinear': 1, 'DequantizeLinear': 1} + quantize_static( + model_fp32_path, + model_int8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + quant_nodes = {"QGemm": 2, "QuantizeLinear": 1, "DequantizeLinear": 1} check_op_type_count(self, model_int8_path, **quant_nodes) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_int8_path, data_reader.get_next()) - def static_quant_test_qdq(self, model_fp32_path, data_reader, activation_type, weight_type, extra_options={}): + def static_quant_test_qdq( + self, + model_fp32_path, + data_reader, + activation_type, + weight_type, + extra_options={}, + ): activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_int8_path = 'gemm_fp32.quant_dqd_{}{}.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_int8_path = "gemm_fp32.quant_dqd_{}{}.onnx".format(activation_type_str, weight_type_str) data_reader.rewind() - quantize_static(model_fp32_path, model_int8_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - quant_nodes = {'Gemm': 2, 'QuantizeLinear': 3, 'DequantizeLinear': 7} + quantize_static( + model_fp32_path, + model_int8_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + quant_nodes = {"Gemm": 2, "QuantizeLinear": 3, "DequantizeLinear": 7} check_op_type_count(self, model_int8_path, **quant_nodes) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_int8_path, data_reader.get_next()) - def dynamic_quant_test(self, model_fp32_path, data_reader, activation_type, weight_type, extra_options={}): + def dynamic_quant_test( + self, + model_fp32_path, + data_reader, + activation_type, + weight_type, + extra_options={}, + ): activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_int8_path = 'gemm_fp32.quant_dynamic_{}{}.onnx'.format(activation_type_str, weight_type_str) - - quantize_dynamic(model_fp32_path, model_int8_path, - weight_type=weight_type, extra_options=extra_options) - quant_nodes = {'MatMulInteger': 2} + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_int8_path = "gemm_fp32.quant_dynamic_{}{}.onnx".format(activation_type_str, weight_type_str) + + quantize_dynamic( + model_fp32_path, + model_int8_path, + weight_type=weight_type, + extra_options=extra_options, + ) + quant_nodes = {"MatMulInteger": 2} check_op_type_count(self, model_int8_path, **quant_nodes) - qnode_io_qtypes = {'MatMulInteger': [['i', 2, activation_proto_qtype]]} + qnode_io_qtypes = {"MatMulInteger": [["i", 2, activation_proto_qtype]]} check_qtype_by_node_type(self, model_int8_path, qnode_io_qtypes) data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_int8_path, {'input': np.random.rand(5, 10).astype(np.float32)}) + check_model_correctness( + self, + model_fp32_path, + model_int8_path, + {"input": np.random.rand(5, 10).astype(np.float32)}, + ) def dynamic_attention_quant_test(self, model_fp32_path, model_int8_path, per_channel, reduce_range): - quantize_dynamic(model_fp32_path, model_int8_path, per_channel=per_channel, reduce_range=reduce_range) - quant_nodes = {'QAttention': 1, 'MatMulInteger': 1} + quantize_dynamic( + model_fp32_path, + model_int8_path, + per_channel=per_channel, + reduce_range=reduce_range, + ) + quant_nodes = {"QAttention": 1, "MatMulInteger": 1} check_op_type_count(self, model_int8_path, **quant_nodes) - check_model_correctness(self, model_fp32_path, model_int8_path, {'input': np.random.rand(1, 5, 10).astype(np.float32)}) + check_model_correctness( + self, + model_fp32_path, + model_int8_path, + {"input": np.random.rand(1, 5, 10).astype(np.float32)}, + ) def test_quantize_gemm(self): np.random.seed(1) - model_fp32_path = 'gemm_fp32.onnx' + model_fp32_path = "gemm_fp32.onnx" self.construct_model_gemm(model_fp32_path) - data_reader = self.input_feeds(1, {'input': [5, 10]}) - - self.static_quant_test(model_fp32_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8) - self.static_quant_test_qdq(model_fp32_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8) - self.dynamic_quant_test(model_fp32_path, data_reader, activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8) + data_reader = self.input_feeds(1, {"input": [5, 10]}) + + self.static_quant_test( + model_fp32_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + ) + self.static_quant_test_qdq( + model_fp32_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + ) + self.dynamic_quant_test( + model_fp32_path, + data_reader, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + ) def test_quantize_gemm_s8s8(self): np.random.seed(1) - model_fp32_path = 'gemm_fp32.onnx' + model_fp32_path = "gemm_fp32.onnx" self.construct_model_gemm(model_fp32_path) - data_reader = self.input_feeds(1, {'input': [5, 10]}) - - self.static_quant_test(model_fp32_path, data_reader, activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, - extra_options={'ActivationSymmetric': True}) - self.static_quant_test_qdq(model_fp32_path, data_reader, activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, - extra_options={'ActivationSymmetric': True}) + data_reader = self.input_feeds(1, {"input": [5, 10]}) + + self.static_quant_test( + model_fp32_path, + data_reader, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + self.static_quant_test_qdq( + model_fp32_path, + data_reader, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) # dynamic quantization doesn't support activation:int8 - #self.dynamic_quant_test(model_fp32_path, data_reader, activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, + # self.dynamic_quant_test(model_fp32_path, data_reader, activation_type=QuantType.QInt8, weight_type=QuantType.QInt8, # extra_options={'ActivationSymmetric': True}) def test_quantize_attention(self): np.random.seed(1) - model_fp32_path = 'attention_fp32.onnx' - model_int8_path = 'attention_fp32.quant.onnx' + model_fp32_path = "attention_fp32.onnx" + model_int8_path = "attention_fp32.quant.onnx" self.construct_model_attention_and_matmul(model_fp32_path) self.dynamic_attention_quant_test(model_fp32_path, model_int8_path, True, True) @@ -212,5 +338,5 @@ def test_quantize_attention(self): self.dynamic_attention_quant_test(model_fp32_path, model_int8_path, False, False) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_maxpool.py b/onnxruntime/test/python/quantization/test_op_maxpool.py index 6e50bbbc2ae25..043aa2bca48bf 100644 --- a/onnxruntime/test/python/quantization/test_op_maxpool.py +++ b/onnxruntime/test/python/quantization/test_op_maxpool.py @@ -7,11 +7,19 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes, check_qtype_by_node_type +import onnx +from onnx import TensorProto, helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_nodes, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestOpMaxPool(unittest.TestCase): @@ -25,11 +33,15 @@ def input_feeds(self, n, name2shape): dr = TestDataFeeds(input_data_list) return dr - def construct_model_conv_maxpool(self, output_model_path, - conv_input_shape, conv_weight_shape, - maxpool_input_shape, maxpool_attributes, - output_shape, - ): + def construct_model_conv_maxpool( + self, + output_model_path, + conv_input_shape, + conv_weight_shape, + maxpool_input_shape, + maxpool_attributes, + output_shape, + ): # (input) # \ # Conv @@ -37,72 +49,126 @@ def construct_model_conv_maxpool(self, output_model_path, # Identity MaxPool # / \ # (identity_out) (output) - input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, conv_input_shape) conv_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name='conv1_weight') - conv_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv_output'], name='conv_node') + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name="conv1_weight") + conv_node = onnx.helper.make_node("Conv", ["input", "conv1_weight"], ["conv_output"], name="conv_node") - identity_out = helper.make_tensor_value_info('identity_out', TensorProto.FLOAT, maxpool_input_shape) - identity_node = helper.make_node('Identity', ['conv_output'], ['identity_out'], name='IdentityNode') + identity_out = helper.make_tensor_value_info("identity_out", TensorProto.FLOAT, maxpool_input_shape) + identity_node = helper.make_node("Identity", ["conv_output"], ["identity_out"], name="IdentityNode") initializers = [conv_weight_initializer] - output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape) - maxpool_node = helper.make_node('MaxPool', ['conv_output'], ['output'], name='maxpool_node', **maxpool_attributes) - - graph = helper.make_graph([conv_node, identity_node, maxpool_node], 'TestOpQuantizerMaxPool_test_model', - [input_tensor], [identity_out, output_tensor], initializer=initializers) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) + maxpool_node = helper.make_node( + "MaxPool", ["conv_output"], ["output"], name="maxpool_node", **maxpool_attributes + ) + + graph = helper.make_graph( + [conv_node, identity_node, maxpool_node], + "TestOpQuantizerMaxPool_test_model", + [input_tensor], + [identity_out, output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 14)]) model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) def quantize_maxpool_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'maxpool_fp32.onnx' - self.construct_model_conv_maxpool(model_fp32_path, - [1, 2, 26, 42], [3, 2, 3, 3], - [1, 3, 24, 40], {'kernel_shape': [3, 3]}, - [1, 3, 22, 38]) - data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]}) + model_fp32_path = "maxpool_fp32.onnx" + self.construct_model_conv_maxpool( + model_fp32_path, + [1, 2, 26, 42], + [3, 2, 3, 3], + [1, 3, 24, 40], + {"kernel_shape": [3, 3]}, + [1, 3, 22, 38], + ) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_q8_path = 'maxpool_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_q8_qdq_path = 'maxpool_dqd_{}{}.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_q8_path = "maxpool_{}{}.onnx".format(activation_type_str, weight_type_str) + model_q8_qdq_path = "maxpool_dqd_{}{}.onnx".format(activation_type_str, weight_type_str) # Verify QOperator mode data_reader.rewind() - quantize_static(model_fp32_path, model_q8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) + quantize_static( + model_fp32_path, + model_q8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # make sure maxpool become xint8 operator, its input name could tell that - check_op_nodes(self, model_q8_path, lambda node: (node.name != "maxpool_node" or node.input[0] != 'conv_output')) - qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'MaxPool': 1} + check_op_nodes( + self, + model_q8_path, + lambda node: (node.name != "maxpool_node" or node.input[0] != "conv_output"), + ) + qnode_counts = { + "QLinearConv": 1, + "QuantizeLinear": 1, + "DequantizeLinear": 2, + "MaxPool": 1, + } check_op_type_count(self, model_q8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_q8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next()) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_q8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'MaxPool': 1} + quantize_static( + model_fp32_path, + model_q8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = { + "Conv": 1, + "QuantizeLinear": 3, + "DequantizeLinear": 4, + "MaxPool": 1, + } check_op_type_count(self, model_q8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next()) def test_quantize_maxpool(self): - self.quantize_maxpool_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={ }) + self.quantize_maxpool_test(QuantType.QUInt8, QuantType.QUInt8, extra_options={}) def test_quantize_maxpool_s8s8(self): - self.quantize_maxpool_test(QuantType.QInt8, QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.quantize_maxpool_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_pad.py b/onnxruntime/test/python/quantization/test_op_pad.py index 4e807582d37d2..df288826bed51 100644 --- a/onnxruntime/test/python/quantization/test_op_pad.py +++ b/onnxruntime/test/python/quantization/test_op_pad.py @@ -7,12 +7,14 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, quantize_dynamic, QuantType, QuantFormat +import onnx +from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from onnxruntime.quantization import QuantFormat, QuantType, quantize_dynamic, quantize_static + class TestOpQuatizerPad(unittest.TestCase): def input_feeds(self, n, name2shape): @@ -25,7 +27,14 @@ def input_feeds(self, n, name2shape): dr = TestDataFeeds(input_data_list) return dr - def construct_model_pad(self, output_model_path, pad_mode, pad_input_shape, pad_dims, constant_value=None): + def construct_model_pad( + self, + output_model_path, + pad_mode, + pad_input_shape, + pad_dims, + constant_value=None, + ): # (input) # | # Pad @@ -34,29 +43,42 @@ def construct_model_pad(self, output_model_path, pad_mode, pad_input_shape, pad_ rank = len(pad_input_shape) self.assertEqual(rank * 2, len(pad_dims)) - input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, pad_input_shape) - pad_dims_initializer = helper.make_tensor('pad_dims', TensorProto.INT64, [2 * rank], pad_dims) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, pad_input_shape) + pad_dims_initializer = helper.make_tensor("pad_dims", TensorProto.INT64, [2 * rank], pad_dims) output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:]))] - output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) - inputs = ['input', 'pad_dims'] + inputs = ["input", "pad_dims"] initializers = [pad_dims_initializer] - if (constant_value is not None) and (pad_mode is None or pad_mode == 'constant'): - constant_value_tensor = helper.make_tensor('padding_value', TensorProto.FLOAT, [], [constant_value]) - inputs.extend(['padding_value']) + if (constant_value is not None) and (pad_mode is None or pad_mode == "constant"): + constant_value_tensor = helper.make_tensor("padding_value", TensorProto.FLOAT, [], [constant_value]) + inputs.extend(["padding_value"]) initializers.extend([constant_value_tensor]) - kwargs = {'mode': pad_mode} if pad_mode is not None else {} - pad_node = helper.make_node('Pad', inputs, ['output'], name='PadNode', **kwargs) - - graph = helper.make_graph([pad_node], 'TestOpQuantizerPad_test_model', - [input_tensor], [output_tensor], initializer=initializers) + kwargs = {"mode": pad_mode} if pad_mode is not None else {} + pad_node = helper.make_node("Pad", inputs, ["output"], name="PadNode", **kwargs) + + graph = helper.make_graph( + [pad_node], + "TestOpQuantizerPad_test_model", + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def construct_model_conv_pad(self, output_model_path, conv_input_shape, conv_weight_shape, - pad_input_shape, pad_mode, pad_dims, constant_value=None): + def construct_model_conv_pad( + self, + output_model_path, + conv_input_shape, + conv_weight_shape, + pad_input_shape, + pad_mode, + pad_dims, + constant_value=None, + ): # (input) # \ # Conv @@ -67,133 +89,235 @@ def construct_model_conv_pad(self, output_model_path, conv_input_shape, conv_wei rank = len(pad_input_shape) self.assertEqual(rank * 2, len(pad_dims)) - input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, conv_input_shape) conv_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name='conv1_weight') - conv_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv_output'], name='conv_node') + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name="conv1_weight") + conv_node = onnx.helper.make_node("Conv", ["input", "conv1_weight"], ["conv_output"], name="conv_node") - identity_out = helper.make_tensor_value_info('identity_out', TensorProto.FLOAT, pad_input_shape) - identity_node = helper.make_node('Identity', ['conv_output'], ['identity_out'], name='IdentityNode') + identity_out = helper.make_tensor_value_info("identity_out", TensorProto.FLOAT, pad_input_shape) + identity_node = helper.make_node("Identity", ["conv_output"], ["identity_out"], name="IdentityNode") - pad_dims_initializer = helper.make_tensor('pad_dims', TensorProto.INT64, [2 * rank], pad_dims) + pad_dims_initializer = helper.make_tensor("pad_dims", TensorProto.INT64, [2 * rank], pad_dims) output_shape = [sum(e) for e in list(zip(pad_input_shape, pad_dims[:rank], pad_dims[rank:]))] - output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape) - pad_inputs = ['conv_output', 'pad_dims'] + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) + pad_inputs = ["conv_output", "pad_dims"] initializers = [conv_weight_initializer, pad_dims_initializer] - if (constant_value is not None) and (pad_mode is None or pad_mode == 'constant'): - constant_value_tensor = helper.make_tensor('padding_value', TensorProto.FLOAT, [], [constant_value]) - pad_inputs.extend(['padding_value']) + if (constant_value is not None) and (pad_mode is None or pad_mode == "constant"): + constant_value_tensor = helper.make_tensor("padding_value", TensorProto.FLOAT, [], [constant_value]) + pad_inputs.extend(["padding_value"]) initializers.extend([constant_value_tensor]) - kwargs = {'mode': pad_mode} if pad_mode is not None else {} - pad_node = helper.make_node('Pad', pad_inputs, ['output'], name='pad_node', **kwargs) - - graph = helper.make_graph([conv_node, identity_node, pad_node], 'TestOpQuantizerPad_test_model', - [input_tensor], [identity_out, output_tensor], initializer=initializers) + kwargs = {"mode": pad_mode} if pad_mode is not None else {} + pad_node = helper.make_node("Pad", pad_inputs, ["output"], name="pad_node", **kwargs) + + graph = helper.make_graph( + [conv_node, identity_node, pad_node], + "TestOpQuantizerPad_test_model", + [input_tensor], + [identity_out, output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_model(self, model_fp32_path, model_i8_path, data_reader=None, - activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, extra_options={}): + def quantize_model( + self, + model_fp32_path, + model_i8_path, + data_reader=None, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + extra_options={}, + ): if data_reader is not None: - quantize_static(model_fp32_path, model_i8_path, data_reader, reduce_range=True, quant_format=QuantFormat.QOperator, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) + quantize_static( + model_fp32_path, + model_i8_path, + data_reader, + reduce_range=True, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) else: - quantize_dynamic(model_fp32_path, model_i8_path, reduce_range=True, - weight_type=weight_type, extra_options=extra_options) - - def verify_should_not_trigger(self, quantize_mode='static'): + quantize_dynamic( + model_fp32_path, + model_i8_path, + reduce_range=True, + weight_type=weight_type, + extra_options=extra_options, + ) + + def verify_should_not_trigger(self, quantize_mode="static"): np.random.seed(108) - model_fp32_path = 'qop_pad_notrigger_fp32_{}.onnx'.format(quantize_mode) - model_i8_path = 'qop_pad_notrigger_i8_{}.onnx'.format(quantize_mode) - data_reader = self.input_feeds(1, {'input': [1, 16, 31, 31]}) - self.construct_model_pad(model_fp32_path, 'constant', [1, 16, 31, 31], [0, 0, 1, 2, 0, 0, 3, 4]) - self.quantize_model(model_fp32_path, model_i8_path, None if quantize_mode != 'static' else data_reader) + model_fp32_path = "qop_pad_notrigger_fp32_{}.onnx".format(quantize_mode) + model_i8_path = "qop_pad_notrigger_i8_{}.onnx".format(quantize_mode) + data_reader = self.input_feeds(1, {"input": [1, 16, 31, 31]}) + self.construct_model_pad(model_fp32_path, "constant", [1, 16, 31, 31], [0, 0, 1, 2, 0, 0, 3, 4]) + self.quantize_model( + model_fp32_path, + model_i8_path, + None if quantize_mode != "static" else data_reader, + ) data_reader.rewind() # DequantizeLinear=0 pad node is not been quantized as input is not quantized. - check_op_type_count(self, model_i8_path, DynamicQuantizeLinear=0, QuantizeLinear=0, DequantizeLinear=0) + check_op_type_count( + self, + model_i8_path, + DynamicQuantizeLinear=0, + QuantizeLinear=0, + DequantizeLinear=0, + ) check_model_correctness(self, model_fp32_path, model_i8_path, data_reader.get_next()) def test_static_quantize_no_trigger(self): - self.verify_should_not_trigger(quantize_mode='static') + self.verify_should_not_trigger(quantize_mode="static") def test_dynamic_quantize_no_trigger(self): - self.verify_should_not_trigger(quantize_mode='dynamic') - - def verify_quantize_with_pad_mode(self, pad_mode, constant_value=None, quantize_mode='static', rtol=0.01, atol=0.05, - activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, extra_options={}): + self.verify_should_not_trigger(quantize_mode="dynamic") + + def verify_quantize_with_pad_mode( + self, + pad_mode, + constant_value=None, + quantize_mode="static", + rtol=0.01, + atol=0.05, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + extra_options={}, + ): np.random.seed(108) - tag_pad_mode = pad_mode if pad_mode is not None else 'none' - tag_constant_value = '' if constant_value is None else '_value' - model_fp32_path = 'qop_pad_{}_fp32_{}{}.onnx'.format(quantize_mode, tag_pad_mode, tag_constant_value) - data_reader = self.input_feeds(1, {'input': [1, 8, 33, 33]}) - self.construct_model_conv_pad(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31], - pad_mode, [0, 0, 1, 2, 0, 0, 3, 4], constant_value=constant_value) + tag_pad_mode = pad_mode if pad_mode is not None else "none" + tag_constant_value = "" if constant_value is None else "_value" + model_fp32_path = "qop_pad_{}_fp32_{}{}.onnx".format(quantize_mode, tag_pad_mode, tag_constant_value) + data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) + self.construct_model_conv_pad( + model_fp32_path, + [1, 8, 33, 33], + [16, 8, 3, 3], + [1, 16, 31, 31], + pad_mode, + [0, 0, 1, 2, 0, 0, 3, 4], + constant_value=constant_value, + ) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_i8_path = 'qop_pad_{}_i8_{}{}_{}{}.onnx'.format( - quantize_mode, tag_pad_mode, tag_constant_value, activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_i8_path = "qop_pad_{}_i8_{}{}_{}{}.onnx".format( + quantize_mode, + tag_pad_mode, + tag_constant_value, + activation_type_str, + weight_type_str, + ) data_reader.rewind() - self.quantize_model(model_fp32_path, model_i8_path, None if quantize_mode != 'static' else data_reader, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) + self.quantize_model( + model_fp32_path, + model_i8_path, + None if quantize_mode != "static" else data_reader, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # DequantizeLinear=2 means there are one DequantizeLinear Node aftr both conv and pad, # which means pad node is running in quantized semantic. # In dynamic quantize mode, pad operator in fact not quantized as input is fp32. - if quantize_mode != 'static': - kwargs = {'DynamicQuantizeLinear': 1} if activation_type == QuantType.QUInt8 else {'QuantizeLinear': 1} + if quantize_mode != "static": + kwargs = {"DynamicQuantizeLinear": 1} if activation_type == QuantType.QUInt8 else {"QuantizeLinear": 1} else: - kwargs = {'DequantizeLinear': 2, 'QuantizeLinear': 1} + kwargs = {"DequantizeLinear": 2, "QuantizeLinear": 1} check_op_type_count(self, model_i8_path, **kwargs) # check node input/output type if such node exists in the graph - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]}) - qnode_io_qtypes.update({'ConvInteger': [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) + qnode_io_qtypes.update({"ConvInteger": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_i8_path, qnode_io_qtypes) data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_i8_path, data_reader.get_next(), rtol=rtol, atol=atol) + check_model_correctness( + self, + model_fp32_path, + model_i8_path, + data_reader.get_next(), + rtol=rtol, + atol=atol, + ) def test_static_mode_edge(self): - self.verify_quantize_with_pad_mode('edge', constant_value=None) + self.verify_quantize_with_pad_mode("edge", constant_value=None) def test_static_mode_reflect(self): - self.verify_quantize_with_pad_mode('reflect', constant_value=None) + self.verify_quantize_with_pad_mode("reflect", constant_value=None) def test_static_mode_constant_default(self): - self.verify_quantize_with_pad_mode('constant', constant_value=None) + self.verify_quantize_with_pad_mode("constant", constant_value=None) def test_static_mode_constant_value(self): - self.verify_quantize_with_pad_mode('constant', constant_value=3.75) + self.verify_quantize_with_pad_mode("constant", constant_value=3.75) def test_static_mode_edge_s8s8(self): - self.verify_quantize_with_pad_mode('edge', constant_value=None, rtol=0.1, atol=0.1, activation_type=QuantType.QInt8, - weight_type=QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.verify_quantize_with_pad_mode( + "edge", + constant_value=None, + rtol=0.1, + atol=0.1, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) def test_static_mode_reflect_s8s8(self): - self.verify_quantize_with_pad_mode('reflect', constant_value=None, rtol=0.1, atol=0.1, activation_type=QuantType.QInt8, - weight_type=QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.verify_quantize_with_pad_mode( + "reflect", + constant_value=None, + rtol=0.1, + atol=0.1, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) def test_static_mode_constant_default_s8s8(self): - self.verify_quantize_with_pad_mode('constant', constant_value=None, rtol=0.1, atol=0.1, activation_type=QuantType.QInt8, - weight_type=QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.verify_quantize_with_pad_mode( + "constant", + constant_value=None, + rtol=0.1, + atol=0.1, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) def test_static_mode_constant_value_s8s8(self): - self.verify_quantize_with_pad_mode('constant', constant_value=3.75, rtol=0.1, atol=0.1, activation_type=QuantType.QInt8, - weight_type=QuantType.QInt8, extra_options={'ActivationSymmetric': True}) + self.verify_quantize_with_pad_mode( + "constant", + constant_value=3.75, + rtol=0.1, + atol=0.1, + activation_type=QuantType.QInt8, + weight_type=QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) def test_dynamic_mode_edge(self): - self.verify_quantize_with_pad_mode('edge', constant_value=None, quantize_mode='dynamic') + self.verify_quantize_with_pad_mode("edge", constant_value=None, quantize_mode="dynamic") def test_dynamic_mode_reflect(self): - self.verify_quantize_with_pad_mode('reflect', constant_value=None, quantize_mode='dynamic') + self.verify_quantize_with_pad_mode("reflect", constant_value=None, quantize_mode="dynamic") def test_dynamic_mode_constant_default(self): - self.verify_quantize_with_pad_mode('constant', constant_value=None, quantize_mode='dynamic') + self.verify_quantize_with_pad_mode("constant", constant_value=None, quantize_mode="dynamic") def test_dynamic_mode_constant_value(self): - self.verify_quantize_with_pad_mode('constant', constant_value=3.75, quantize_mode='dynamic') + self.verify_quantize_with_pad_mode("constant", constant_value=3.75, quantize_mode="dynamic") # TODO: uncomment following after ConvInteger s8 supported # def test_dynamic_mode_edge_s8s8(self): @@ -213,5 +337,5 @@ def test_dynamic_mode_constant_value(self): # weight_type=QuantType.QInt8, extra_options={'ActivationSymmetric': True}) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_pooling.py b/onnxruntime/test/python/quantization/test_op_pooling.py index fdc66f945da78..b0561bd79f8e1 100644 --- a/onnxruntime/test/python/quantization/test_op_pooling.py +++ b/onnxruntime/test/python/quantization/test_op_pooling.py @@ -7,11 +7,19 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes, check_qtype_by_node_type +import onnx +from onnx import TensorProto, helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_nodes, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestOpAveragePool(unittest.TestCase): @@ -25,11 +33,15 @@ def input_feeds(self, n, name2shape): dr = TestDataFeeds(input_data_list) return dr - def construct_model_conv_avgpool(self, output_model_path, - conv_input_shape, conv_weight_shape, - avgpool_input_shape, avgpool_attributes, - output_shape, - ): + def construct_model_conv_avgpool( + self, + output_model_path, + conv_input_shape, + conv_weight_shape, + avgpool_input_shape, + avgpool_attributes, + output_shape, + ): # (input) # \ # Conv @@ -37,61 +49,116 @@ def construct_model_conv_avgpool(self, output_model_path, # Identity AveragePool # / \ # (identity_out) (output) - input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, conv_input_shape) conv_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name='conv1_weight') - conv_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv_output'], name='conv_node') + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name="conv1_weight") + conv_node = onnx.helper.make_node("Conv", ["input", "conv1_weight"], ["conv_output"], name="conv_node") - identity_out = helper.make_tensor_value_info('identity_out', TensorProto.FLOAT, avgpool_input_shape) - identity_node = helper.make_node('Identity', ['conv_output'], ['identity_out'], name='IdentityNode') + identity_out = helper.make_tensor_value_info("identity_out", TensorProto.FLOAT, avgpool_input_shape) + identity_node = helper.make_node("Identity", ["conv_output"], ["identity_out"], name="IdentityNode") initializers = [conv_weight_initializer] - output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, output_shape) - avgpool_node = helper.make_node('AveragePool', ['conv_output'], ['output'], name='avgpool_node', **avgpool_attributes) - - graph = helper.make_graph([conv_node, identity_node, avgpool_node], 'TestOpQuantizerAveragePool_test_model', - [input_tensor], [identity_out, output_tensor], initializer=initializers) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape) + avgpool_node = helper.make_node( + "AveragePool", ["conv_output"], ["output"], name="avgpool_node", **avgpool_attributes + ) + + graph = helper.make_graph( + [conv_node, identity_node, avgpool_node], + "TestOpQuantizerAveragePool_test_model", + [input_tensor], + [identity_out, output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 12)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_avgpool_test(self, activation_type, weight_type, extra_options = {}): + def quantize_avgpool_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'avgpool_fp32.onnx' - self.construct_model_conv_avgpool(model_fp32_path, - [1, 2, 26, 42], [3, 2, 3, 3], - [1, 3, 24, 40], {'kernel_shape': [3, 3]}, - [1, 3, 22, 38]) - data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]}) + model_fp32_path = "avgpool_fp32.onnx" + self.construct_model_conv_avgpool( + model_fp32_path, + [1, 2, 26, 42], + [3, 2, 3, 3], + [1, 3, 24, 40], + {"kernel_shape": [3, 3]}, + [1, 3, 22, 38], + ) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_q8_path = 'avgpool_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_q8_qdq_path = 'avgpool_qdq_{}{}.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_q8_path = "avgpool_{}{}.onnx".format(activation_type_str, weight_type_str) + model_q8_qdq_path = "avgpool_qdq_{}{}.onnx".format(activation_type_str, weight_type_str) # Verify QOperator mode data_reader.rewind() - quantize_static(model_fp32_path, model_q8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) - qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'QLinearAveragePool': 1} + quantize_static( + model_fp32_path, + model_q8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qnode_counts = { + "QLinearConv": 1, + "QuantizeLinear": 1, + "DequantizeLinear": 2, + "QLinearAveragePool": 1, + } check_op_type_count(self, model_q8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'QLinearConv' : [['i', 2, activation_proto_qtype], ['i', 7, activation_proto_qtype], ['o', 0, activation_proto_qtype]]}) - qnode_io_qtypes.update({'QLinearAveragePool' : [['i', 4, activation_proto_qtype]]}) # shape info note workig on custome ops + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update( + { + "QLinearConv": [ + ["i", 2, activation_proto_qtype], + ["i", 7, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + ) + qnode_io_qtypes.update( + {"QLinearAveragePool": [["i", 4, activation_proto_qtype]]} + ) # shape info note workig on custome ops check_qtype_by_node_type(self, model_q8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_path, data_reader.get_next()) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_q8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) - qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'AveragePool': 1} + quantize_static( + model_fp32_path, + model_q8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = { + "Conv": 1, + "QuantizeLinear": 3, + "DequantizeLinear": 4, + "AveragePool": 1, + } check_op_type_count(self, model_q8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_q8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_q8_qdq_path, data_reader.get_next()) @@ -100,7 +167,12 @@ def test_quantize_avgpool(self): self.quantize_avgpool_test(QuantType.QUInt8, QuantType.QUInt8) def test_quantize_avgpool_s8s8(self): - self.quantize_avgpool_test(QuantType.QInt8, QuantType.QInt8, extra_options = {'ActivationSymmetric' : True}) + self.quantize_avgpool_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_reshape.py b/onnxruntime/test/python/quantization/test_op_reshape.py index 9d05ad1ef3d21..aeedda6821245 100644 --- a/onnxruntime/test/python/quantization/test_op_reshape.py +++ b/onnxruntime/test/python/quantization/test_op_reshape.py @@ -7,11 +7,19 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes, check_qtype_by_node_type +import onnx +from onnx import TensorProto, helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_nodes, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestOpReshape(unittest.TestCase): @@ -33,76 +41,116 @@ def construct_model_matmul_reshape(self, output_model_path, input_shape, weight_ # Reshape # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] # make MatMul node - weight_name = 'matmul_weight' - matmul_output_name = 'matmul_output' + weight_name = "matmul_weight" + matmul_output_name = "matmul_output" matmul_inputs = [input_name, weight_name] matmul_outputs = [matmul_output_name] - matmul_name = 'matmul_node' + matmul_name = "matmul_node" matmul_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(matmul_weight_data, name=weight_name)) - matmul_node = onnx.helper.make_node('MatMul', matmul_inputs, matmul_outputs, name=matmul_name) + matmul_node = onnx.helper.make_node("MatMul", matmul_inputs, matmul_outputs, name=matmul_name) # make Reshape node - reshape_shape = 'reshape_shape' + reshape_shape = "reshape_shape" reshape_inputs = [matmul_output_name, reshape_shape] reshape_output = [output_name] - reshape_name = 'reshape_node' + reshape_name = "reshape_node" initializers.append(onnx.numpy_helper.from_array(np.array(output_shape, dtype=np.int64), name=reshape_shape)) - reshape_node = onnx.helper.make_node('Reshape', reshape_inputs, reshape_output, name=reshape_name) + reshape_node = onnx.helper.make_node("Reshape", reshape_inputs, reshape_output, name=reshape_name) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) - graph_name = 'Reshape_Quant_Test' - graph = helper.make_graph([matmul_node, reshape_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "Reshape_Quant_Test" + graph = helper.make_graph( + [matmul_node, reshape_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_reshape_test(self, activation_type, weight_type, extra_options = {}): + def quantize_reshape_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'reshape_fp32.onnx' + model_fp32_path = "reshape_fp32.onnx" - self.construct_model_matmul_reshape(model_fp32_path, - [3, 7], - [7, 3], - [1, 9]) + self.construct_model_matmul_reshape(model_fp32_path, [3, 7], [7, 3], [1, 9]) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_uint8_path = 'reshape_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_uint8_qdq_path = 'reshape_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_uint8_path = "reshape_{}{}.onnx".format(activation_type_str, weight_type_str) + model_uint8_qdq_path = "reshape_{}{}_qdq.onnx".format(activation_type_str, weight_type_str) # Verify QOperator mode - data_reader = self.input_feeds(1, {'input': [3, 7]}) - quantize_static(model_fp32_path, model_uint8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) + data_reader = self.input_feeds(1, {"input": [3, 7]}) + quantize_static( + model_fp32_path, + model_uint8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # make sure transpose become xint8 operator, its input name could tell that - check_op_nodes(self, model_uint8_path, lambda node: (node.name != "reshape_node" or node.input[0] != 'matmul_output')) - qnode_counts = {'QLinearMatMul': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'Reshape': 1} + check_op_nodes( + self, + model_uint8_path, + lambda node: (node.name != "reshape_node" or node.input[0] != "matmul_output"), + ) + qnode_counts = { + "QLinearMatMul": 1, + "QuantizeLinear": 1, + "DequantizeLinear": 1, + "Reshape": 1, + } check_op_type_count(self, model_uint8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear' : [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) - qdqnode_counts = {'MatMul': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'Reshape': 1} + quantize_static( + model_fp32_path, + model_uint8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = { + "MatMul": 1, + "QuantizeLinear": 3, + "DequantizeLinear": 4, + "Reshape": 1, + } check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) @@ -111,7 +159,12 @@ def test_quantize_reshape(self): self.quantize_reshape_test(QuantType.QUInt8, QuantType.QUInt8) def test_quantize_reshape_s8s8(self): - self.quantize_reshape_test(QuantType.QInt8, QuantType.QInt8, extra_options = {'ActivationSymmetric' : True}) + self.quantize_reshape_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_resize.py b/onnxruntime/test/python/quantization/test_op_resize.py index 987bf8cb5984b..1f9703b26f0fe 100644 --- a/onnxruntime/test/python/quantization/test_op_resize.py +++ b/onnxruntime/test/python/quantization/test_op_resize.py @@ -7,11 +7,19 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes, check_qtype_by_node_type +import onnx +from onnx import TensorProto, helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_nodes, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestOpResize(unittest.TestCase): @@ -25,11 +33,18 @@ def input_feeds(self, n, name2shape): dr = TestDataFeeds(input_data_list) return dr - def construct_model_conv_resize(self, output_model_path, - conv_input_shape, conv_weight_shape, - resize_input_shape, resize_output_shape, - resize_attrs, - resize_roi, resize_scales, resize_sizes): + def construct_model_conv_resize( + self, + output_model_path, + conv_input_shape, + conv_weight_shape, + resize_input_shape, + resize_output_shape, + resize_attrs, + resize_roi, + resize_scales, + resize_sizes, + ): # (input) # \ # Conv @@ -37,88 +52,149 @@ def construct_model_conv_resize(self, output_model_path, # Identity Resize # / \ # (identity_out) (output) - input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, conv_input_shape) conv_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name='conv1_weight') - conv_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv_output'], name='conv_node') + conv_weight_initializer = onnx.numpy_helper.from_array(conv_weight_arr, name="conv1_weight") + conv_node = onnx.helper.make_node("Conv", ["input", "conv1_weight"], ["conv_output"], name="conv_node") - identity_out = helper.make_tensor_value_info('identity_out', TensorProto.FLOAT, resize_input_shape) - identity_node = helper.make_node('Identity', ['conv_output'], ['identity_out'], name='IdentityNode') + identity_out = helper.make_tensor_value_info("identity_out", TensorProto.FLOAT, resize_input_shape) + identity_node = helper.make_node("Identity", ["conv_output"], ["identity_out"], name="IdentityNode") initializers = [conv_weight_initializer] - output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, resize_output_shape) - resize_inputs = ['conv_output'] # resize_roi_name, resize_scales_name, resize_sizes_name] - resize_node = helper.make_node('Resize', resize_inputs, ['output'], name='resize_node', **resize_attrs) + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, resize_output_shape) + resize_inputs = ["conv_output"] # resize_roi_name, resize_scales_name, resize_sizes_name] + resize_node = helper.make_node("Resize", resize_inputs, ["output"], name="resize_node", **resize_attrs) - if (resize_roi is not None): - resize_roi_name = 'resize_roi' - resize_roi_initializer = helper.make_tensor(resize_roi_name, TensorProto.FLOAT, [len(resize_roi)], resize_roi) + if resize_roi is not None: + resize_roi_name = "resize_roi" + resize_roi_initializer = helper.make_tensor( + resize_roi_name, TensorProto.FLOAT, [len(resize_roi)], resize_roi + ) initializers.extend([resize_roi_initializer]) resize_node.input.extend([resize_roi_name]) else: - resize_node.input.extend(['']) - - if (resize_scales is not None): - resize_scales_name = 'resize_scales' - resize_scales_initializer = helper.make_tensor(resize_scales_name, TensorProto.FLOAT, [ - len(resize_scales)], resize_scales) + resize_node.input.extend([""]) + + if resize_scales is not None: + resize_scales_name = "resize_scales" + resize_scales_initializer = helper.make_tensor( + resize_scales_name, + TensorProto.FLOAT, + [len(resize_scales)], + resize_scales, + ) initializers.extend([resize_scales_initializer]) resize_node.input.extend([resize_scales_name]) else: - resize_node.input.extend(['']) + resize_node.input.extend([""]) - if (resize_sizes is not None): - resize_sizes_name = 'resize_sizes' - resize_sizes_initializer = helper.make_tensor(resize_sizes_name, TensorProto.INT64, [len(resize_sizes)], resize_sizes) + if resize_sizes is not None: + resize_sizes_name = "resize_sizes" + resize_sizes_initializer = helper.make_tensor( + resize_sizes_name, TensorProto.INT64, [len(resize_sizes)], resize_sizes + ) initializers.extend([resize_sizes_initializer]) resize_node.input.extend([resize_sizes_name]) - graph = helper.make_graph([conv_node, identity_node, resize_node], 'TestOpQuantizerResize_test_model', - [input_tensor], [identity_out, output_tensor], initializer=initializers) + graph = helper.make_graph( + [conv_node, identity_node, resize_node], + "TestOpQuantizerResize_test_model", + [input_tensor], + [identity_out, output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_resize_test(self, activation_type, weight_type, extra_options = {}): + def quantize_resize_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'resize_fp32.onnx' - - kwargs = {'coordinate_transformation_mode': 'asymmetric', 'mode': 'nearest', 'nearest_mode': 'floor'} - self.construct_model_conv_resize(model_fp32_path, - [1, 2, 26, 42], [3, 2, 3, 3], - [1, 3, 24, 40], [1, 3, 48, 80], - kwargs, - [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 2.0, 2.0], None) + model_fp32_path = "resize_fp32.onnx" + + kwargs = { + "coordinate_transformation_mode": "asymmetric", + "mode": "nearest", + "nearest_mode": "floor", + } + self.construct_model_conv_resize( + model_fp32_path, + [1, 2, 26, 42], + [3, 2, 3, 3], + [1, 3, 24, 40], + [1, 3, 48, 80], + kwargs, + [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0], + [1.0, 1.0, 2.0, 2.0], + None, + ) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_uint8_path = 'resize_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_uint8_qdq_path = 'resize_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_uint8_path = "resize_{}{}.onnx".format(activation_type_str, weight_type_str) + model_uint8_qdq_path = "resize_{}{}_qdq.onnx".format(activation_type_str, weight_type_str) # Verify QOperator mode - data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]}) - quantize_static(model_fp32_path, model_uint8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) + quantize_static( + model_fp32_path, + model_uint8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # make sure resize become xint8 operator, its input name could tell that - check_op_nodes(self, model_uint8_path, lambda node: (node.name != "resize_node" or node.input[0] != 'conv_output')) - qnode_counts = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 2, 'Resize': 1} + check_op_nodes( + self, + model_uint8_path, + lambda node: (node.name != "resize_node" or node.input[0] != "conv_output"), + ) + qnode_counts = { + "QLinearConv": 1, + "QuantizeLinear": 1, + "DequantizeLinear": 2, + "Resize": 1, + } check_op_type_count(self, model_uint8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear' : [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) - qdqnode_counts = {'Conv': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'Resize': 1} + quantize_static( + model_fp32_path, + model_uint8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = { + "Conv": 1, + "QuantizeLinear": 3, + "DequantizeLinear": 4, + "Resize": 1, + } check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) @@ -130,5 +206,6 @@ def test_quantize_resize(self): # def test_quantize_resize_s8s8(self): # self.quantize_resize_test(QuantType.QInt8, QuantType.QInt8, extra_options = {'ActivationSymmetric' : True}) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_squeeze_unsqueeze.py b/onnxruntime/test/python/quantization/test_op_squeeze_unsqueeze.py index a4759cb628c05..6794ef4acc787 100644 --- a/onnxruntime/test/python/quantization/test_op_squeeze_unsqueeze.py +++ b/onnxruntime/test/python/quantization/test_op_squeeze_unsqueeze.py @@ -7,12 +7,14 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType +import onnx +from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_qtype_by_node_type +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static + class TestOpSqueezeUnsqueeze(unittest.TestCase): def input_feeds(self, n, name2shape): @@ -25,9 +27,14 @@ def input_feeds(self, n, name2shape): dr = TestDataFeeds(input_data_list) return dr - def construct_model_conv_squeezes(self, output_model_path, - conv_input_shape, conv_weight_shape, conv_output_shape, - opset=13): + def construct_model_conv_squeezes( + self, + output_model_path, + conv_input_shape, + conv_weight_shape, + conv_output_shape, + opset=13, + ): # (input) # / | \ # Conv1 conv2 conv3 @@ -41,92 +48,201 @@ def construct_model_conv_squeezes(self, output_model_path, # add2 # | # (output) - input_tensor = helper.make_tensor_value_info('input', TensorProto.FLOAT, conv_input_shape) + input_tensor = helper.make_tensor_value_info("input", TensorProto.FLOAT, conv_input_shape) conv1_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv1_weight_initializer = onnx.numpy_helper.from_array(conv1_weight_arr, name='conv1_weight') - conv1_node = onnx.helper.make_node('Conv', ['input', 'conv1_weight'], ['conv1_output'], name='conv1_node') + conv1_weight_initializer = onnx.numpy_helper.from_array(conv1_weight_arr, name="conv1_weight") + conv1_node = onnx.helper.make_node("Conv", ["input", "conv1_weight"], ["conv1_output"], name="conv1_node") conv2_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv2_weight_initializer = onnx.numpy_helper.from_array(conv2_weight_arr, name='conv2_weight') - conv2_node = onnx.helper.make_node('Conv', ['input', 'conv2_weight'], ['conv2_output'], name='conv2_node') + conv2_weight_initializer = onnx.numpy_helper.from_array(conv2_weight_arr, name="conv2_weight") + conv2_node = onnx.helper.make_node("Conv", ["input", "conv2_weight"], ["conv2_output"], name="conv2_node") conv3_weight_arr = np.random.randint(-1, 2, conv_weight_shape).astype(np.float32) - conv3_weight_initializer = onnx.numpy_helper.from_array(conv3_weight_arr, name='conv3_weight') - conv3_node = onnx.helper.make_node('Conv', ['input', 'conv3_weight'], ['conv3_output'], name='conv3_node') - - if (opset >= 13): - squeeze_axes_initializer = onnx.numpy_helper.from_array(np.array([0], dtype=np.int64), name='squeeze_axes') - squeeze1_node = helper.make_node('Squeeze', ['conv1_output', 'squeeze_axes'], ['squeeze1_output'], name='suqeeze1_node') - squeeze2_node = helper.make_node('Squeeze', ['conv2_output', 'squeeze_axes'], ['squeeze2_output'], name='suqeeze2_node') + conv3_weight_initializer = onnx.numpy_helper.from_array(conv3_weight_arr, name="conv3_weight") + conv3_node = onnx.helper.make_node("Conv", ["input", "conv3_weight"], ["conv3_output"], name="conv3_node") + + if opset >= 13: + squeeze_axes_initializer = onnx.numpy_helper.from_array(np.array([0], dtype=np.int64), name="squeeze_axes") + squeeze1_node = helper.make_node( + "Squeeze", + ["conv1_output", "squeeze_axes"], + ["squeeze1_output"], + name="suqeeze1_node", + ) + squeeze2_node = helper.make_node( + "Squeeze", + ["conv2_output", "squeeze_axes"], + ["squeeze2_output"], + name="suqeeze2_node", + ) else: - squeeze1_node = helper.make_node('Squeeze', ['conv1_output'], ['squeeze1_output'], name='suqeeze1_node', axes=[0]) - squeeze2_node = helper.make_node('Squeeze', ['conv2_output'], ['squeeze2_output'], name='suqeeze2_node', axes=[0]) - - add1_node = helper.make_node('Add', ['squeeze1_output', 'squeeze2_output'], ['add1_output'], name='add1_node') - if (opset >= 13): - unsqueeze_node = helper.make_node('Unsqueeze', ['add1_output', 'squeeze_axes'], [ - 'unsqueeze_output'], name='unsqueeze_node') + squeeze1_node = helper.make_node( + "Squeeze", + ["conv1_output"], + ["squeeze1_output"], + name="suqeeze1_node", + axes=[0], + ) + squeeze2_node = helper.make_node( + "Squeeze", + ["conv2_output"], + ["squeeze2_output"], + name="suqeeze2_node", + axes=[0], + ) + + add1_node = helper.make_node( + "Add", + ["squeeze1_output", "squeeze2_output"], + ["add1_output"], + name="add1_node", + ) + if opset >= 13: + unsqueeze_node = helper.make_node( + "Unsqueeze", + ["add1_output", "squeeze_axes"], + ["unsqueeze_output"], + name="unsqueeze_node", + ) else: - unsqueeze_node = helper.make_node('Unsqueeze', ['add1_output'], ['unsqueeze_output'], name='unsqueeze_node', axes=[0]) - - output_tensor = helper.make_tensor_value_info('output', TensorProto.FLOAT, conv_output_shape) - add2_node = helper.make_node('Add', ['unsqueeze_output', 'conv3_output'], ['output'], name='add2_node') - - initializers = [conv1_weight_initializer, conv2_weight_initializer, conv3_weight_initializer] - if (opset >= 13): + unsqueeze_node = helper.make_node( + "Unsqueeze", + ["add1_output"], + ["unsqueeze_output"], + name="unsqueeze_node", + axes=[0], + ) + + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, conv_output_shape) + add2_node = helper.make_node("Add", ["unsqueeze_output", "conv3_output"], ["output"], name="add2_node") + + initializers = [ + conv1_weight_initializer, + conv2_weight_initializer, + conv3_weight_initializer, + ] + if opset >= 13: initializers.append(squeeze_axes_initializer) - graph = helper.make_graph([conv1_node, conv2_node, conv3_node, squeeze1_node, squeeze2_node, add1_node, unsqueeze_node, add2_node], - 'TestOpSuqeezes_test_model', [input_tensor], [output_tensor], initializer=initializers) + graph = helper.make_graph( + [ + conv1_node, + conv2_node, + conv3_node, + squeeze1_node, + squeeze2_node, + add1_node, + unsqueeze_node, + add2_node, + ], + "TestOpSuqeezes_test_model", + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", opset)]) model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def run_quantize_squeezes_of_opset(self, opset=13, activation_type=QuantType.QUInt8, weight_type=QuantType.QUInt8, extra_options={}): + def run_quantize_squeezes_of_opset( + self, + opset=13, + activation_type=QuantType.QUInt8, + weight_type=QuantType.QUInt8, + extra_options={}, + ): np.random.seed(1) - model_fp32_path = 'squeezes_opset{}_fp32.onnx'.format(opset) + model_fp32_path = "squeezes_opset{}_fp32.onnx".format(opset) self.construct_model_conv_squeezes(model_fp32_path, [1, 2, 26, 42], [3, 2, 3, 3], [1, 3, 24, 40], opset=opset) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_uint8_path = 'squeezes_opset{}_{}{}.onnx'.format(opset, activation_type_str, weight_type_str) - model_uint8_qdq_path = 'squeezes_opset{}_{}{}_qdq.onnx'.format(opset, activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_uint8_path = "squeezes_opset{}_{}{}.onnx".format(opset, activation_type_str, weight_type_str) + model_uint8_qdq_path = "squeezes_opset{}_{}{}_qdq.onnx".format(opset, activation_type_str, weight_type_str) # Verify QOperator mode - data_reader = self.input_feeds(1, {'input': [1, 2, 26, 42]}) - quantize_static(model_fp32_path, model_uint8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) + data_reader = self.input_feeds(1, {"input": [1, 2, 26, 42]}) + quantize_static( + model_fp32_path, + model_uint8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # make sure squeezes become xint8 operator, its input name could tell that - qnode_counts = {'QuantizeLinear': 1, 'DequantizeLinear': 1} + qnode_counts = {"QuantizeLinear": 1, "DequantizeLinear": 1} check_op_type_count(self, model_uint8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear': [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes) data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next(), rtol=0.01, atol=0.5) + check_model_correctness( + self, + model_fp32_path, + model_uint8_path, + data_reader.get_next(), + rtol=0.01, + atol=0.5, + ) # Verify QDQ mode data_reader.rewind() - quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type=activation_type, weight_type=weight_type, extra_options=extra_options) - qdqnode_counts = {'Conv': 3, 'QuantizeLinear': 9, 'DequantizeLinear': 12} + quantize_static( + model_fp32_path, + model_uint8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = {"Conv": 3, "QuantizeLinear": 9, "DequantizeLinear": 12} check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear': [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes) data_reader.rewind() - check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next(), rtol=0.01, atol=0.5) + check_model_correctness( + self, + model_fp32_path, + model_uint8_qdq_path, + data_reader.get_next(), + rtol=0.01, + atol=0.5, + ) def test_quantize_squeeze_unsqueeze(self): self.run_quantize_squeezes_of_opset(11) self.run_quantize_squeezes_of_opset(13) def test_quantize_squeeze_unsqueeze_s8s8(self): - self.run_quantize_squeezes_of_opset(11, QuantType.QInt8, QuantType.QInt8, extra_options={'ActivationSymmetric': True}) - self.run_quantize_squeezes_of_opset(13, QuantType.QInt8, QuantType.QInt8, extra_options={'ActivationSymmetric': True}) - - -if __name__ == '__main__': + self.run_quantize_squeezes_of_opset( + 11, + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + self.run_quantize_squeezes_of_opset( + 13, + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + + +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_op_transpose.py b/onnxruntime/test/python/quantization/test_op_transpose.py index 6428bdabec079..e34df1072976d 100644 --- a/onnxruntime/test/python/quantization/test_op_transpose.py +++ b/onnxruntime/test/python/quantization/test_op_transpose.py @@ -7,11 +7,19 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantFormat, QuantType -from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_nodes, check_qtype_by_node_type +import onnx +from onnx import TensorProto, helper +from op_test_utils import ( + TestDataFeeds, + check_model_correctness, + check_op_nodes, + check_op_type_count, + check_qtype_by_node_type, +) + +from onnxruntime.quantization import QuantFormat, QuantType, quantize_static class TestOpTranspose(unittest.TestCase): @@ -33,67 +41,112 @@ def construct_model_matmul_transpose(self, output_model_path, input_shape, weigh # Transpose # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] # make MatMul node - weight_name = 'matmul_weight' - matmul_output_name = 'matmul_output' + weight_name = "matmul_weight" + matmul_output_name = "matmul_output" matmul_inputs = [input_name, weight_name] matmul_outputs = [matmul_output_name] - matmul_name = 'matmul_node' + matmul_name = "matmul_node" matmul_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(matmul_weight_data, name=weight_name)) - matmul_node = onnx.helper.make_node('MatMul', matmul_inputs, matmul_outputs, name=matmul_name) + matmul_node = onnx.helper.make_node("MatMul", matmul_inputs, matmul_outputs, name=matmul_name) # make Transpose node - kwargs = {'perm': (1, 0)} - transpose_node = onnx.helper.make_node('Transpose', [matmul_output_name], [output_name], name="transpose_node", **kwargs) + kwargs = {"perm": (1, 0)} + transpose_node = onnx.helper.make_node( + "Transpose", [matmul_output_name], [output_name], name="transpose_node", **kwargs + ) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) - graph_name = 'Transpose_Quant_Test' - graph = helper.make_graph([matmul_node, transpose_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "Transpose_Quant_Test" + graph = helper.make_graph( + [matmul_node, transpose_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def quantize_transpose_test(self, activation_type, weight_type, extra_options = {}): + def quantize_transpose_test(self, activation_type, weight_type, extra_options={}): np.random.seed(1) - model_fp32_path = 'transpose_fp32.onnx' + model_fp32_path = "transpose_fp32.onnx" self.construct_model_matmul_transpose(model_fp32_path, [3, 7], [7, 5], [5, 3]) activation_proto_qtype = TensorProto.UINT8 if activation_type == QuantType.QUInt8 else TensorProto.INT8 - activation_type_str = 'u8' if (activation_type == QuantType.QUInt8) else 's8' - weight_type_str = 'u8' if (weight_type == QuantType.QUInt8) else 's8' - model_uint8_path = 'transpose_{}{}.onnx'.format(activation_type_str, weight_type_str) - model_uint8_qdq_path = 'transpose_{}{}_qdq.onnx'.format(activation_type_str, weight_type_str) + activation_type_str = "u8" if (activation_type == QuantType.QUInt8) else "s8" + weight_type_str = "u8" if (weight_type == QuantType.QUInt8) else "s8" + model_uint8_path = "transpose_{}{}.onnx".format(activation_type_str, weight_type_str) + model_uint8_qdq_path = "transpose_{}{}_qdq.onnx".format(activation_type_str, weight_type_str) # Verify QOperator model - data_reader = self.input_feeds(1, {'input': [3, 7]}) - quantize_static(model_fp32_path, model_uint8_path, data_reader, quant_format=QuantFormat.QOperator, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) + data_reader = self.input_feeds(1, {"input": [3, 7]}) + quantize_static( + model_fp32_path, + model_uint8_path, + data_reader, + quant_format=QuantFormat.QOperator, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) # make sure transpose become xint8 operator, its input name could tell that - check_op_nodes(self, model_uint8_path, lambda node: (node.name != "transpose_node" or node.input[0] != 'matmul_output')) - qnode_counts = {'QLinearMatMul': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1, 'Transpose': 1} + check_op_nodes( + self, + model_uint8_path, + lambda node: (node.name != "transpose_node" or node.input[0] != "matmul_output"), + ) + qnode_counts = { + "QLinearMatMul": 1, + "QuantizeLinear": 1, + "DequantizeLinear": 1, + "Transpose": 1, + } check_op_type_count(self, model_uint8_path, **qnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} - qnode_io_qtypes.update({'DequantizeLinear' : [['i', 2, activation_proto_qtype]]}) + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } + qnode_io_qtypes.update({"DequantizeLinear": [["i", 2, activation_proto_qtype]]}) check_qtype_by_node_type(self, model_uint8_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_path, data_reader.get_next()) # Verify QDQ model data_reader.rewind() - quantize_static(model_fp32_path, model_uint8_qdq_path, data_reader, quant_format=QuantFormat.QDQ, - activation_type = activation_type, weight_type = weight_type, extra_options = extra_options) - qdqnode_counts = {'MatMul': 1, 'QuantizeLinear': 3, 'DequantizeLinear': 4, 'Transpose': 1} + quantize_static( + model_fp32_path, + model_uint8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + activation_type=activation_type, + weight_type=weight_type, + extra_options=extra_options, + ) + qdqnode_counts = { + "MatMul": 1, + "QuantizeLinear": 3, + "DequantizeLinear": 4, + "Transpose": 1, + } check_op_type_count(self, model_uint8_qdq_path, **qdqnode_counts) - qnode_io_qtypes = {'QuantizeLinear' : [['i', 2, activation_proto_qtype], ['o', 0, activation_proto_qtype]]} + qnode_io_qtypes = { + "QuantizeLinear": [ + ["i", 2, activation_proto_qtype], + ["o", 0, activation_proto_qtype], + ] + } check_qtype_by_node_type(self, model_uint8_qdq_path, qnode_io_qtypes) data_reader.rewind() check_model_correctness(self, model_fp32_path, model_uint8_qdq_path, data_reader.get_next()) @@ -102,7 +155,12 @@ def test_quantize_transpose(self): self.quantize_transpose_test(QuantType.QUInt8, QuantType.QUInt8) def test_quantize_transpose_s8s8(self): - self.quantize_transpose_test(QuantType.QInt8, QuantType.QInt8, extra_options = {'ActivationSymmetric' : True}) + self.quantize_transpose_test( + QuantType.QInt8, + QuantType.QInt8, + extra_options={"ActivationSymmetric": True}, + ) + -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_qat.py b/onnxruntime/test/python/quantization/test_qat.py index 048a9c6c7958f..34d36df30db73 100644 --- a/onnxruntime/test/python/quantization/test_qat.py +++ b/onnxruntime/test/python/quantization/test_qat.py @@ -3,25 +3,25 @@ # -*- coding: UTF-8 -*- -import numpy as np -import onnx -from onnx import helper, numpy_helper, TensorProto, ValueInfoProto -from onnx import shape_inference -import onnxruntime -from pathlib import Path import unittest import urllib.request +from pathlib import Path -from onnxruntime.quantization.quantize import ONNXQuantizer +import numpy as np +import onnx +from onnx import TensorProto, ValueInfoProto, helper, numpy_helper +from onnx import onnx_pb as onnx_proto +from onnx import shape_inference +import onnxruntime from onnxruntime.quantization.quant_utils import QuantizationMode -from onnx import onnx_pb as onnx_proto +from onnxruntime.quantization.quantize import ONNXQuantizer def generate_input_initializer(tensor_shape, tensor_dtype, input_name): - ''' - Helper function to generate initializers for inputs - ''' + """ + Helper function to generate initializers for inputs + """ tensor = np.random.ranf(tensor_shape).astype(tensor_dtype) init = numpy_helper.from_array(tensor, input_name) return init @@ -31,9 +31,9 @@ def generate_qat_model(model_names): test_models = [] test_initializers = [] - ''' + """ TEST_MODEL_CONFIG_1 - ''' + """ # Main graph: # # [A] [input_bias] @@ -54,42 +54,47 @@ def generate_qat_model(model_names): graph = helper.make_graph( [ - #nodes + # nodes helper.make_node("Add", ["A", "input_bias"], ["add_out"], "add0"), - helper.make_node("QuantizeLinear", ["add_out", "quant0_scale_const", "quant0_zp_const"], ["quant0_out"], - "qlinear0"), - helper.make_node("DequantizeLinear", ["quant0_out", "dequant0_scale_const", "dequant0_zp_const"], - ["dequant0_out"], "dqlinear0"), + helper.make_node( + "QuantizeLinear", + ["add_out", "quant0_scale_const", "quant0_zp_const"], + ["quant0_out"], + "qlinear0", + ), + helper.make_node( + "DequantizeLinear", + ["quant0_out", "dequant0_scale_const", "dequant0_zp_const"], + ["dequant0_out"], + "dqlinear0", + ), helper.make_node("MatMul", ["dequant0_out", "trans_out"], ["B"], "matmul"), ], - "QAT_model_1", #name - [ #input - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1']) - ], - [ #output - helper.make_tensor_value_info('B', TensorProto.FLOAT, [1024]) + "QAT_model_1", # name + [helper.make_tensor_value_info("A", TensorProto.FLOAT, ["unk_1"])], # input + [helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024])], # output + [ # initializers + helper.make_tensor("quant0_scale_const", TensorProto.FLOAT, [], [0.01961481384932995]), + helper.make_tensor("quant0_zp_const", TensorProto.INT8, [], [0]), + helper.make_tensor("dequant0_scale_const", TensorProto.FLOAT, [], [0.01961481384932995]), + helper.make_tensor("dequant0_zp_const", TensorProto.INT8, [], [0]), ], - [ #initializers - helper.make_tensor('quant0_scale_const', TensorProto.FLOAT, [], [0.01961481384932995]), - helper.make_tensor('quant0_zp_const', TensorProto.INT8, [], [0]), - helper.make_tensor('dequant0_scale_const', TensorProto.FLOAT, [], [0.01961481384932995]), - helper.make_tensor('dequant0_zp_const', TensorProto.INT8, [], [0]), - ]) - input_weight_1 = generate_input_initializer([1024, 1024], np.float32, 'trans_out') - input_bias_1 = generate_input_initializer([1024], np.float32, 'input_bias') + ) + input_weight_1 = generate_input_initializer([1024, 1024], np.float32, "trans_out") + input_bias_1 = generate_input_initializer([1024], np.float32, "input_bias") graph.initializer.add().CopyFrom(input_weight_1) graph.initializer.add().CopyFrom(input_bias_1) model_1 = onnx.helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model_1.ir_version = 7 # use stable onnx ir version + model_1.ir_version = 7 # use stable onnx ir version onnx.save(model_1, model_names[0]) test_models.extend([model_1]) initiazliers_1 = [input_weight_1, input_bias_1] test_initializers.append(initiazliers_1) - ''' + """ TEST_MODEL_CONFIG_2 - ''' + """ # Main graph: # @@ -110,50 +115,73 @@ def generate_qat_model(model_names): graph = helper.make_graph( [ - #nodes - helper.make_node("MaxPool", ["A"], ["maxpool_out"], "maxpool", kernel_shape = [1, 1]), - helper.make_node("QuantizeLinear", ["maxpool_out", "quant0_scale_const", "quant0_zp_const"], ["quant0_out"], - "qlinear0"), - helper.make_node("DequantizeLinear", ["quant0_out", "dequant0_scale_const", "dequant0_zp_const"], - ["dequant0_out"], "dqlinear0"), - helper.make_node("Conv", ["dequant0_out", "conv_weight_0", "conv_bias_0"], ["conv0_out"], "conv0"), - helper.make_node("QuantizeLinear", ["maxpool_out", "quant1_scale_const", "quant1_zp_const"], ["quant1_out"], - "qlinear1"), - helper.make_node("DequantizeLinear", ["quant1_out", "dequant1_scale_const", "dequant1_zp_const"], - ["dequant1_out"], "dqlinear1"), - helper.make_node("Conv", ["dequant1_out", "conv_weight_1", "conv_bias_1"], ["conv1_out"], "conv1"), + # nodes + helper.make_node("MaxPool", ["A"], ["maxpool_out"], "maxpool", kernel_shape=[1, 1]), + helper.make_node( + "QuantizeLinear", + ["maxpool_out", "quant0_scale_const", "quant0_zp_const"], + ["quant0_out"], + "qlinear0", + ), + helper.make_node( + "DequantizeLinear", + ["quant0_out", "dequant0_scale_const", "dequant0_zp_const"], + ["dequant0_out"], + "dqlinear0", + ), + helper.make_node( + "Conv", + ["dequant0_out", "conv_weight_0", "conv_bias_0"], + ["conv0_out"], + "conv0", + ), + helper.make_node( + "QuantizeLinear", + ["maxpool_out", "quant1_scale_const", "quant1_zp_const"], + ["quant1_out"], + "qlinear1", + ), + helper.make_node( + "DequantizeLinear", + ["quant1_out", "dequant1_scale_const", "dequant1_zp_const"], + ["dequant1_out"], + "dqlinear1", + ), + helper.make_node( + "Conv", + ["dequant1_out", "conv_weight_1", "conv_bias_1"], + ["conv1_out"], + "conv1", + ), helper.make_node("Add", ["conv0_out", "conv1_out"], ["B"], "add"), ], - "QAT_model_2", #name - [ #input - helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 64, 256, 256]) + "QAT_model_2", # name + [helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 64, 256, 256])], # input + [helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 256, 256, 256])], # output + [ # initializers + helper.make_tensor("quant0_scale_const", TensorProto.FLOAT, [], [0.2062656134366989]), + helper.make_tensor("quant0_zp_const", TensorProto.UINT8, [], [165]), + helper.make_tensor("dequant0_scale_const", TensorProto.FLOAT, [], [0.2062656134366989]), + helper.make_tensor("dequant0_zp_const", TensorProto.UINT8, [], [165]), + helper.make_tensor("quant1_scale_const", TensorProto.FLOAT, [], [0.10088317096233368]), + helper.make_tensor("quant1_zp_const", TensorProto.UINT8, [], [132]), + helper.make_tensor("dequant1_scale_const", TensorProto.FLOAT, [], [0.10088317096233368]), + helper.make_tensor("dequant1_zp_const", TensorProto.UINT8, [], [132]), ], - [ #output - helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 256, 256, 256]) - ], - [ #initializers - helper.make_tensor('quant0_scale_const', TensorProto.FLOAT, [], [0.2062656134366989]), - helper.make_tensor('quant0_zp_const', TensorProto.UINT8, [], [165]), - helper.make_tensor('dequant0_scale_const', TensorProto.FLOAT, [], [0.2062656134366989]), - helper.make_tensor('dequant0_zp_const', TensorProto.UINT8, [], [165]), - helper.make_tensor('quant1_scale_const', TensorProto.FLOAT, [], [0.10088317096233368]), - helper.make_tensor('quant1_zp_const', TensorProto.UINT8, [], [132]), - helper.make_tensor('dequant1_scale_const', TensorProto.FLOAT, [], [0.10088317096233368]), - helper.make_tensor('dequant1_zp_const', TensorProto.UINT8, [], [132]), - ]) - - conv_weight_0 = generate_input_initializer([256, 64, 1, 1], np.float32, 'conv_weight_0') - conv_bias_0 = generate_input_initializer([256], np.float32, 'conv_bias_0') + ) + + conv_weight_0 = generate_input_initializer([256, 64, 1, 1], np.float32, "conv_weight_0") + conv_bias_0 = generate_input_initializer([256], np.float32, "conv_bias_0") graph.initializer.add().CopyFrom(conv_weight_0) graph.initializer.add().CopyFrom(conv_bias_0) - conv_weight_1 = generate_input_initializer([256, 64, 1, 1], np.float32, 'conv_weight_1') - conv_bias_1 = generate_input_initializer([256], np.float32, 'conv_bias_1') + conv_weight_1 = generate_input_initializer([256, 64, 1, 1], np.float32, "conv_weight_1") + conv_bias_1 = generate_input_initializer([256], np.float32, "conv_bias_1") graph.initializer.add().CopyFrom(conv_weight_1) graph.initializer.add().CopyFrom(conv_bias_1) model_2 = onnx.helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model_2.ir_version = 7 # use stable onnx ir version + model_2.ir_version = 7 # use stable onnx ir version onnx.save(model_2, model_names[1]) test_models.extend([model_2]) @@ -164,9 +192,9 @@ def generate_qat_model(model_names): def generate_qat_support_model(model_names, test_initializers): - ''' - EXPECTED_TEST_RESULT_CONFIG_1 - ''' + """ + EXPECTED_TEST_RESULT_CONFIG_1 + """ test_qat_support_models = [] @@ -182,34 +210,35 @@ def generate_qat_support_model(model_names, test_initializers): # | # [B] graph = helper.make_graph( - [ #nodes + [ # nodes helper.make_node("Add", ["A", "input_bias"], ["add_out"], "add0"), helper.make_node("MatMul", ["add_out", "trans_out"], ["B"], "matmul"), ], - "QAT_support_model_1", #name + "QAT_support_model_1", # name [ - #input - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1']) + # input + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["unk_1"]) ], [ - #output - helper.make_tensor_value_info('B', TensorProto.FLOAT, [1024]) - ]) + # output + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1024]) + ], + ) - #initializers + # initializers init_1 = test_initializers[0] for init in init_1: graph.initializer.add().CopyFrom(init) model_1 = onnx.ModelProto() - model_1.ir_version = 7 # use stable onnx ir version + model_1.ir_version = 7 # use stable onnx ir version model_1 = onnx.helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) onnx.save(model_1, model_names[0]) test_qat_support_models.extend([model_1]) - ''' + """ EXPECTED_TEST_RESULT_CONFIG_2 - ''' + """ # Main graph: # [A] @@ -223,27 +252,34 @@ def generate_qat_support_model(model_names, test_initializers): # | # [B] graph = helper.make_graph( - [ #nodes + [ # nodes helper.make_node("MaxPool", ["A"], ["maxpool_out"], "maxpool"), - helper.make_node("Conv", ["maxpool_out", "conv_weight_0", "conv_bias_0"], ["conv0_out"], "conv0"), - helper.make_node("Conv", ["maxpool_out", "conv_weight_1", "conv_bias_1"], ["conv1_out"], "conv1"), + helper.make_node( + "Conv", + ["maxpool_out", "conv_weight_0", "conv_bias_0"], + ["conv0_out"], + "conv0", + ), + helper.make_node( + "Conv", + ["maxpool_out", "conv_weight_1", "conv_bias_1"], + ["conv1_out"], + "conv1", + ), helper.make_node("Add", ["conv0_out", "conv1_out"], ["B"], "add"), ], - "QAT_support_model_2", #name - [ #input - helper.make_tensor_value_info('A', TensorProto.FLOAT, [1, 64, 256, 256]) - ], - [ #output - helper.make_tensor_value_info('B', TensorProto.FLOAT, [1, 256, 256, 256]) - ]) + "QAT_support_model_2", # name + [helper.make_tensor_value_info("A", TensorProto.FLOAT, [1, 64, 256, 256])], # input + [helper.make_tensor_value_info("B", TensorProto.FLOAT, [1, 256, 256, 256])], # output + ) - #initializers + # initializers init_2 = test_initializers[1] for init in init_2: graph.initializer.add().CopyFrom(init) model_2 = onnx.ModelProto() - model_2.ir_version = 7 # use stable onnx ir version + model_2.ir_version = 7 # use stable onnx ir version model_2 = onnx.helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) onnx.save(model_1, model_names[1]) @@ -253,15 +289,15 @@ def generate_qat_support_model(model_names, test_initializers): def compare_two_models(model_1, model_2): - ''' + """ Helper function to check if two models are the same :param: model_1 - expected model :param: model_2 - actual model Return true if two models are the same. Otherwise return false. - ''' + """ check_1, check_2 = True, True - #check nodes + # check nodes for node_1 in model_1.graph.node: node_found = False for node_2 in model_2.graph.node: @@ -277,7 +313,7 @@ def compare_two_models(model_1, model_2): print("Error:Node {} in the expected model not found in test model.".format(node_1.name)) break - #check initializers: + # check initializers: for init_1 in model_1.graph.initializer: init1_arr = numpy_helper.to_array(init_1) init_found = False @@ -287,8 +323,9 @@ def compare_two_models(model_1, model_2): init2_arr = numpy_helper.to_array(init_2) if not np.array_equal(init1_arr, init2_arr): check_2 = False - print("Error: Initializer {} in test model dismatches with the expected model.".format( - init_2.name)) + print( + "Error: Initializer {} in test model dismatches with the expected model.".format(init_2.name) + ) break if not init_found: @@ -303,20 +340,34 @@ class TestQAT(unittest.TestCase): def test_remove_fakequant_nodes(self): model_names = ["qat_model_1.onnx", "qat_model_2.onnx"] - qat_support_model_names = ["qat_support_model_1.onnx", "qat_support_model_2.onnx"] + qat_support_model_names = [ + "qat_support_model_1.onnx", + "qat_support_model_2.onnx", + ] test_models, test_initializers = generate_qat_model(model_names) qat_support_models_expected = generate_qat_support_model(qat_support_model_names, test_initializers) for i in range(len(test_models)): - quantizer = ONNXQuantizer(test_models[i], False, False,QuantizationMode.IntegerOps, False, TensorProto.INT8, - TensorProto.INT8, None, None, None, ['Conv', 'MatMul', 'MaxPool']) - #test remove editting to the graph + quantizer = ONNXQuantizer( + test_models[i], + False, + False, + QuantizationMode.IntegerOps, + False, + TensorProto.INT8, + TensorProto.INT8, + None, + None, + None, + ["Conv", "MatMul", "MaxPool"], + ) + # test remove editting to the graph qat_support_model_actual = quantizer.remove_fake_quantized_nodes() assert compare_two_models(qat_support_models_expected[i], qat_support_model_actual) print("TEST_MODEL {} finished: ".format(i) + qat_support_model_names[i]) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_qdq.py b/onnxruntime/test/python/quantization/test_qdq.py index 3d055e62f3d08..a5721cdf83c1e 100644 --- a/onnxruntime/test/python/quantization/test_qdq.py +++ b/onnxruntime/test/python/quantization/test_qdq.py @@ -7,12 +7,15 @@ # -------------------------------------------------------------------------- import unittest -import onnx + import numpy as np -from onnx import helper, TensorProto -from onnxruntime.quantization import quantize_static, QuantType, QuantFormat, QuantizationMode, QDQQuantizer +import onnx +from onnx import TensorProto, helper from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count, check_op_type_order +from onnxruntime.quantization import QDQQuantizer, QuantFormat, QuantizationMode, QuantType, quantize_static + + class TestQDQFormat(unittest.TestCase): def input_feeds(self, n, name2shape): input_data_list = [] @@ -24,106 +27,117 @@ def input_feeds(self, n, name2shape): dr = TestDataFeeds(input_data_list) return dr + class TestQDQExtraOptions(unittest.TestCase): def test_qdq_extra_options(self): - # (input) - # | - # Add + # (input) + # | + # Add # | - # ReduceMean + # ReduceMean # | - # Add + # Add # | # (output) initializers = [] - input_tensor = helper.make_tensor_value_info('L', TensorProto.FLOAT, [5, 5]) - output_tensor = helper.make_tensor_value_info('O', TensorProto.FLOAT, [5, 5]) + input_tensor = helper.make_tensor_value_info("L", TensorProto.FLOAT, [5, 5]) + output_tensor = helper.make_tensor_value_info("O", TensorProto.FLOAT, [5, 5]) add_weight_data_1 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(add_weight_data_1, name="M")) add_weight_data_2 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(add_weight_data_2, name="N")) - add_node_1 = onnx.helper.make_node('Add', ['L', 'M'], ['P'], name='Add1') - reduce_mean_node = onnx.helper.make_node('ReduceMean', ['P'], ['Q'], keepdims=1, name='ReduceMean') - add_node_2 = onnx.helper.make_node('Add', ['Q', 'N'], ['O'], name='Add2') - - graph = helper.make_graph([add_node_1, reduce_mean_node, add_node_2], 'QDQ_Test_Finetune', [input_tensor], [output_tensor], initializer=initializers) + add_node_1 = onnx.helper.make_node("Add", ["L", "M"], ["P"], name="Add1") + reduce_mean_node = onnx.helper.make_node("ReduceMean", ["P"], ["Q"], keepdims=1, name="ReduceMean") + add_node_2 = onnx.helper.make_node("Add", ["Q", "N"], ["O"], name="Add2") + + graph = helper.make_graph( + [add_node_1, reduce_mean_node, add_node_2], + "QDQ_Test_Finetune", + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - test_model_path = './test_qdq_finetune.onnx' + test_model_path = "./test_qdq_finetune.onnx" onnx.save(model, test_model_path) compute_range = { - 'P': [0.1, 0.1], - 'Q': [0.1, 0.1], - 'M': [0.1, 0.1], - 'N': [0.1, 0.1], - 'L': [0.1, 0.1], - 'O': [0.1, 0.1], + "P": [0.1, 0.1], + "Q": [0.1, 0.1], + "M": [0.1, 0.1], + "N": [0.1, 0.1], + "L": [0.1, 0.1], + "O": [0.1, 0.1], } - op_types_to_quantize = ['Add'] + op_types_to_quantize = ["Add"] mode = QuantizationMode.QLinearOps model = onnx.load_model(test_model_path, False) quantizer = QDQQuantizer( model, - True, #per_channel - False, #reduce_range + True, # per_channel + False, # reduce_range mode, - True, #static - QuantType.QInt8, #weight_type - QuantType.QInt8, #activation_type + True, # static + QuantType.QInt8, # weight_type + QuantType.QInt8, # activation_type compute_range, - [], #nodes_to_quantize - ['Add2'], #nodes_to_exclude + [], # nodes_to_quantize + ["Add2"], # nodes_to_exclude op_types_to_quantize, - {'ActivationSymmetric' : True, 'AddQDQPairToWeight' : True, 'OpTypesToExcludeOutputQuantizatioin': []}) #extra_options + { + "ActivationSymmetric": True, + "AddQDQPairToWeight": True, + "OpTypesToExcludeOutputQuantizatioin": [], + }, + ) # extra_options quantizer.quantize_model() - qdq_model_path = './test_qdq_finetune_qdq.onnx' + qdq_model_path = "./test_qdq_finetune_qdq.onnx" quantizer.model.save_model_to_file(qdq_model_path, False) # QDQ pair should be added to Add1 but not Add2 # QDQ pair shoud be added to Add1 output as well. - qdq_added_to_node_output_flag = False + qdq_added_to_node_output_flag = False for node in quantizer.model.nodes(): - if node.name == 'Add1': + if node.name == "Add1": for input in node.input: self.assertTrue("DequantizeLinear" in input) for output in node.output: self.assertTrue("QuantizeLinear" not in output) - if node.name == 'Add2': + if node.name == "Add2": for input in node.input: self.assertTrue("DequantizeLinear" not in input) for output in node.output: self.assertTrue("QuantizeLinear" not in output) # This QuantizeLinear node should be followed by Add1 - if node.name == 'P_QuantizeLinear': + if node.name == "P_QuantizeLinear": qdq_added_to_node_output_flag = True - self.assertTrue(node.input[0] == 'P') + self.assertTrue(node.input[0] == "P") self.assertTrue(qdq_added_to_node_output_flag) - def test_qdq_extra_options_2(self): - # (input) - # | - # Add + # (input) + # | + # Add # / | \ - # MatMul MatMul MatMul + # MatMul MatMul MatMul # | | | # (output)(output)(output) initializers = [] - input_tensor = helper.make_tensor_value_info('L', TensorProto.FLOAT, [5, 5]) - output_tensor1 = helper.make_tensor_value_info('M', TensorProto.FLOAT, [5, 5]) - output_tensor2 = helper.make_tensor_value_info('N', TensorProto.FLOAT, [5, 5]) - output_tensor3 = helper.make_tensor_value_info('O', TensorProto.FLOAT, [5, 5]) + input_tensor = helper.make_tensor_value_info("L", TensorProto.FLOAT, [5, 5]) + output_tensor1 = helper.make_tensor_value_info("M", TensorProto.FLOAT, [5, 5]) + output_tensor2 = helper.make_tensor_value_info("N", TensorProto.FLOAT, [5, 5]) + output_tensor3 = helper.make_tensor_value_info("O", TensorProto.FLOAT, [5, 5]) add_weight_data = np.random.normal(0, 0.1, [5, 5]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(add_weight_data, name="P")) @@ -134,66 +148,86 @@ def test_qdq_extra_options_2(self): matmul_weight_data_3 = np.random.normal(0, 0.1, [5, 5]).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(matmul_weight_data_2, name="S")) - add_node = onnx.helper.make_node('Add', ['L', 'P'], ['T'], name='Add') - matmul_node_1 = onnx.helper.make_node('MatMul', ['T', 'Q'], ['M'], name='MatMul1') - matmul_node_2 = onnx.helper.make_node('MatMul', ['T', 'R'], ['N'], name='MatMul2') - matmul_node_3 = onnx.helper.make_node('MatMul', ['T', 'S'], ['O'], name='MatMul3') - - graph = helper.make_graph([add_node, matmul_node_1, matmul_node_2, matmul_node_3], 'QDQ_Test_Finetune_2', [input_tensor], [output_tensor1, output_tensor2, output_tensor3], initializer=initializers) + add_node = onnx.helper.make_node("Add", ["L", "P"], ["T"], name="Add") + matmul_node_1 = onnx.helper.make_node("MatMul", ["T", "Q"], ["M"], name="MatMul1") + matmul_node_2 = onnx.helper.make_node("MatMul", ["T", "R"], ["N"], name="MatMul2") + matmul_node_3 = onnx.helper.make_node("MatMul", ["T", "S"], ["O"], name="MatMul3") + + graph = helper.make_graph( + [add_node, matmul_node_1, matmul_node_2, matmul_node_3], + "QDQ_Test_Finetune_2", + [input_tensor], + [output_tensor1, output_tensor2, output_tensor3], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - test_model_path = './test_qdq_finetune_2.onnx' + test_model_path = "./test_qdq_finetune_2.onnx" onnx.save(model, test_model_path) compute_range = { - 'L': [0.1, 0.1], - 'M': [0.1, 0.1], - 'N': [0.1, 0.1], - 'O': [0.1, 0.1], - 'P': [0.1, 0.1], - 'Q': [0.1, 0.1], - 'R': [0.1, 0.1], - 'S': [0.1, 0.1], - 'T': [0.1, 0.1], + "L": [0.1, 0.1], + "M": [0.1, 0.1], + "N": [0.1, 0.1], + "O": [0.1, 0.1], + "P": [0.1, 0.1], + "Q": [0.1, 0.1], + "R": [0.1, 0.1], + "S": [0.1, 0.1], + "T": [0.1, 0.1], } - op_types_to_quantize = ['Add', 'MatMul'] + op_types_to_quantize = ["Add", "MatMul"] mode = QuantizationMode.QLinearOps model = onnx.load_model(test_model_path, False) quantizer = QDQQuantizer( model, - True, #per_channel - False, #reduce_range + True, # per_channel + False, # reduce_range mode, - True, #static - QuantType.QInt8, #weight_type - QuantType.QInt8, #activation_type + True, # static + QuantType.QInt8, # weight_type + QuantType.QInt8, # activation_type compute_range, - [], #nodes_to_quantize - ['Add'], #nodes_to_exclude + [], # nodes_to_quantize + ["Add"], # nodes_to_exclude op_types_to_quantize, - {'ActivationSymmetric' : True, 'AddQDQPairToWeight' : True, 'OpTypesToExcludeOutputQuantizatioin': op_types_to_quantize, 'DedicatedQDQPair': True}) #extra_options + { + "ActivationSymmetric": True, + "AddQDQPairToWeight": True, + "OpTypesToExcludeOutputQuantizatioin": op_types_to_quantize, + "DedicatedQDQPair": True, + }, + ) # extra_options quantizer.quantize_model() - qdq_model_path = './test_qdq_finetune_qdq_2.onnx' + qdq_model_path = "./test_qdq_finetune_qdq_2.onnx" quantizer.model.save_model_to_file(qdq_model_path, False) # Three dedicated QDQ pair should be generated and feed into each MatMul node - # Also QDQ pair should not be added to Add node + # Also QDQ pair should not be added to Add node # QDQ pair shoud not be added to node's output for node in quantizer.model.nodes(): - if node.name == 'MatMul1': + if node.name == "MatMul1": self.assertTrue("T_DequantizeLinear_1" in node.input) - if node.name == 'MatMul2': + if node.name == "MatMul2": self.assertTrue("T_DequantizeLinear_2" in node.input) - if node.name == 'MatMul3': + if node.name == "MatMul3": self.assertTrue("T_DequantizeLinear_3" in node.input) - if node.name == 'Add': + if node.name == "Add": for input in node.input: self.assertTrue("DequantizeLinear" not in input) # QDQ pair shoud not be added to MatMul's output - if node.op_type == 'QuantizeLinear': - self.assertTrue(node.input[0] not in ['M_QuantizeLinearInput', 'N_QuantizeLinearInput', 'O_QuantizeLinearInput']) + if node.op_type == "QuantizeLinear": + self.assertTrue( + node.input[0] + not in [ + "M_QuantizeLinearInput", + "N_QuantizeLinearInput", + "O_QuantizeLinearInput", + ] + ) + class TestQDQFormatConv(TestQDQFormat): def construct_model_conv(self, output_model_path, input_shape, weight_shape, output_shape, has_bias): @@ -202,84 +236,92 @@ def construct_model_conv(self, output_model_path, input_shape, weight_shape, out # Conv # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] # make Conv node - weight_name = 'conv_weight' - bias_name = 'conv_bias' + weight_name = "conv_weight" + bias_name = "conv_bias" conv_inputs = [input_name, weight_name] conv_outputs = [output_name] - conv_name = 'conv_node' + conv_name = "conv_node" conv_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(conv_weight_data, name=weight_name)) if has_bias: conv_inputs.append(bias_name) bias_data = np.random.normal(0, 0.05, (weight_shape[0])).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(bias_data, name=bias_name)) - conv_node = onnx.helper.make_node('Conv', conv_inputs, conv_outputs, name=conv_name) + conv_node = onnx.helper.make_node("Conv", conv_inputs, conv_outputs, name=conv_name) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) - graph_name = 'QDQ_Test_Conv' - graph = helper.make_graph([conv_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "QDQ_Test_Conv" + graph = helper.make_graph( + [conv_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) - def verify_quantize_conv(self, has_bias, per_channel, is_quant_type_int8 = False): + def verify_quantize_conv(self, has_bias, per_channel, is_quant_type_int8=False): np.random.seed(1) - model_fp32_path = 'conv_fp32.{}.{}.onnx'.format(has_bias, per_channel) - model_int8_qdq_path = 'conv_quant_qdq.{}.{}.onnx'.format(has_bias, per_channel) - model_int8_qop_path = 'conv_quant_qop.{}.{}.onnx'.format(has_bias, per_channel) - data_reader = self.input_feeds(1, {'input': [1, 8, 33, 33]}) - self.construct_model_conv(model_fp32_path, - [1, 8, 33, 33], - [16, 8, 3, 3], - [1, 16, 31, 31], - has_bias) - quantize_static(model_fp32_path, - model_int8_qdq_path, - data_reader, - quant_format=QuantFormat.QDQ, - per_channel = per_channel, - reduce_range = per_channel, - activation_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8 - ) + model_fp32_path = "conv_fp32.{}.{}.onnx".format(has_bias, per_channel) + model_int8_qdq_path = "conv_quant_qdq.{}.{}.onnx".format(has_bias, per_channel) + model_int8_qop_path = "conv_quant_qop.{}.{}.onnx".format(has_bias, per_channel) + data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) + self.construct_model_conv(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31], has_bias) + quantize_static( + model_fp32_path, + model_int8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + per_channel=per_channel, + reduce_range=per_channel, + activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + ) data_reader.rewind() - qdq_nodes = {'Conv': 1, 'QuantizeLinear': 2, 'DequantizeLinear': 4 if has_bias else 3} + qdq_nodes = { + "Conv": 1, + "QuantizeLinear": 2, + "DequantizeLinear": 4 if has_bias else 3, + } check_op_type_count(self, model_int8_qdq_path, **qdq_nodes) check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next()) data_reader.rewind() - quantize_static(model_fp32_path, - model_int8_qop_path, - data_reader, - quant_format=QuantFormat.QOperator, - per_channel = per_channel, - reduce_range = per_channel, - activation_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8 - ) + quantize_static( + model_fp32_path, + model_int8_qop_path, + data_reader, + quant_format=QuantFormat.QOperator, + per_channel=per_channel, + reduce_range=per_channel, + activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + ) data_reader.rewind() - qop_nodes = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1} + qop_nodes = {"QLinearConv": 1, "QuantizeLinear": 1, "DequantizeLinear": 1} check_op_type_count(self, model_int8_qop_path, **qop_nodes) check_model_correctness(self, model_fp32_path, model_int8_qop_path, data_reader.get_next()) def test_quantize_conv_without_bias(self): # only test cases per_channel=True and reduce_range=True to avoid saturation on avx2 and avx512 for weight type int8 - self.verify_quantize_conv(False, True, True) # has_bias:False, per_channel:True, is_quant_type_int8:True - self.verify_quantize_conv(True, True, True) # has_bias:True, per_channel:True, is_quant_type_int8:True + self.verify_quantize_conv(False, True, True) # has_bias:False, per_channel:True, is_quant_type_int8:True + self.verify_quantize_conv(True, True, True) # has_bias:True, per_channel:True, is_quant_type_int8:True + + self.verify_quantize_conv(False, False, False) # has_bias:False, per_channel:False, is_quant_type_int8:False + self.verify_quantize_conv(True, False, False) # has_bias:True, per_channel:False, is_quant_type_int8:False + self.verify_quantize_conv(False, True, False) # has_bias:False, per_channel:True, is_quant_type_int8:False + self.verify_quantize_conv(True, True, False) # has_bias:True, per_channel:True, is_quant_type_int8:False - self.verify_quantize_conv(False, False, False) # has_bias:False, per_channel:False, is_quant_type_int8:False - self.verify_quantize_conv(True, False, False) # has_bias:True, per_channel:False, is_quant_type_int8:False - self.verify_quantize_conv(False, True, False) # has_bias:False, per_channel:True, is_quant_type_int8:False - self.verify_quantize_conv(True, True, False) # has_bias:True, per_channel:True, is_quant_type_int8:False class TestQDQFormatConvClip(TestQDQFormat): def construct_model_conv_clip(self, output_model_path, input_shape, weight_shape, output_shape): @@ -292,91 +334,111 @@ def construct_model_conv_clip(self, output_model_path, input_shape, weight_shape # Reshape # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] # make Conv node - weight_name = 'conv_weight' + weight_name = "conv_weight" conv_inputs = [input_name, weight_name] - conv_outputs = ['conv_output'] - conv_name = 'conv_node' + conv_outputs = ["conv_output"] + conv_name = "conv_node" conv_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(conv_weight_data, name=weight_name)) - conv_node = onnx.helper.make_node('Conv', conv_inputs, conv_outputs, name=conv_name) + conv_node = onnx.helper.make_node("Conv", conv_inputs, conv_outputs, name=conv_name) # make Clip node - clip_min_name = 'clip_min' - clip_max_name = 'clip_max' + clip_min_name = "clip_min" + clip_max_name = "clip_max" clip_inputs = [conv_outputs[0], clip_min_name, clip_max_name] - clip_outputs = ['clip_output'] - clip_name = 'clip_node' + clip_outputs = ["clip_output"] + clip_name = "clip_node" initializers.append(onnx.numpy_helper.from_array(np.array(-1.0, dtype=np.float32), name=clip_min_name)) initializers.append(onnx.numpy_helper.from_array(np.array(1.0, dtype=np.float32), name=clip_max_name)) - clip_node = onnx.helper.make_node('Clip', clip_inputs, clip_outputs, name=clip_name) + clip_node = onnx.helper.make_node("Clip", clip_inputs, clip_outputs, name=clip_name) # make Identity node - reshape_name = 'reshape_node' - reshape_shape = 'reshape_shape' + reshape_name = "reshape_node" + reshape_shape = "reshape_shape" initializers.append(onnx.numpy_helper.from_array(np.array([-1], dtype=np.int64), name=reshape_shape)) - reshape_node = onnx.helper.make_node('Reshape', ['clip_output', reshape_shape], [output_name], name=reshape_name) + reshape_node = onnx.helper.make_node( + "Reshape", ["clip_output", reshape_shape], [output_name], name=reshape_name + ) # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) - graph_name = 'QDQ_Test_Conv_clip' - graph = helper.make_graph([conv_node, clip_node, reshape_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "QDQ_Test_Conv_clip" + graph = helper.make_graph( + [conv_node, clip_node, reshape_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) def verify(self, per_channel, is_quant_type_int8): np.random.seed(1) - model_fp32_path = 'conv_clip_fp32.{}.onnx'.format(per_channel) - model_int8_qdq_path = 'conv_clip_quant_qdq.{}.onnx'.format(per_channel) - model_int8_qop_path = 'conv_clip_quant_qop.{}.onnx'.format(per_channel) - data_reader = self.input_feeds(1, {'input': [1, 8, 33, 33]}) - self.construct_model_conv_clip(model_fp32_path, - [1, 8, 33, 33], - [16, 8, 3, 3], - [15376]) - quantize_static(model_fp32_path, - model_int8_qdq_path, - data_reader, - quant_format=QuantFormat.QDQ, - per_channel = per_channel, - reduce_range = per_channel, - activation_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8 - ) + model_fp32_path = "conv_clip_fp32.{}.onnx".format(per_channel) + model_int8_qdq_path = "conv_clip_quant_qdq.{}.onnx".format(per_channel) + model_int8_qop_path = "conv_clip_quant_qop.{}.onnx".format(per_channel) + data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) + self.construct_model_conv_clip(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [15376]) + quantize_static( + model_fp32_path, + model_int8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + per_channel=per_channel, + reduce_range=per_channel, + activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + ) data_reader.rewind() - #topo sort check - check_op_type_order(self, model_int8_qdq_path, ['DequantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'Conv', 'QuantizeLinear', 'DequantizeLinear', 'Reshape', 'QuantizeLinear', 'DequantizeLinear']) + # topo sort check + check_op_type_order( + self, + model_int8_qdq_path, + [ + "DequantizeLinear", + "QuantizeLinear", + "DequantizeLinear", + "Conv", + "QuantizeLinear", + "DequantizeLinear", + "Reshape", + "QuantizeLinear", + "DequantizeLinear", + ], + ) check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next()) data_reader.rewind() - quantize_static(model_fp32_path, - model_int8_qop_path, - data_reader, - quant_format=QuantFormat.QOperator, - per_channel = per_channel, - reduce_range = per_channel, - activation_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8 - ) + quantize_static( + model_fp32_path, + model_int8_qop_path, + data_reader, + quant_format=QuantFormat.QOperator, + per_channel=per_channel, + reduce_range=per_channel, + activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + ) data_reader.rewind() - qop_nodes = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1} + qop_nodes = {"QLinearConv": 1, "QuantizeLinear": 1, "DequantizeLinear": 1} check_op_type_count(self, model_int8_qop_path, **qop_nodes) check_model_correctness(self, model_fp32_path, model_int8_qop_path, data_reader.get_next()) def test_quantize_conv_without_bias(self): # only test cases per_channel=True and reduce_range=True to avoid saturation on avx2 and avx512 for weight type int8 - self.verify(True, True) # per_channel:False, is_quant_type_int8:True + self.verify(True, True) # per_channel:False, is_quant_type_int8:True - self.verify(False, False) # per_channel:False, is_quant_type_int8:False - self.verify(True, False) # per_channel:True, is_quant_type_int8:False + self.verify(False, False) # per_channel:False, is_quant_type_int8:False + self.verify(True, False) # per_channel:True, is_quant_type_int8:False class TestQDQFormatConvRelu(TestQDQFormat): @@ -388,78 +450,94 @@ def construct_model_conv_relu(self, output_model_path, input_shape, weight_shape # Relu # | # (output) - input_name = 'input' - output_name = 'output' + input_name = "input" + output_name = "output" initializers = [] # make Conv node - weight_name = 'conv_weight' + weight_name = "conv_weight" conv_inputs = [input_name, weight_name] - conv_outputs = ['conv_output'] - conv_name = 'conv_node' + conv_outputs = ["conv_output"] + conv_name = "conv_node" conv_weight_data = np.random.normal(0, 0.1, weight_shape).astype(np.float32) initializers.append(onnx.numpy_helper.from_array(conv_weight_data, name=weight_name)) - conv_node = onnx.helper.make_node('Conv', conv_inputs, conv_outputs, name=conv_name) + conv_node = onnx.helper.make_node("Conv", conv_inputs, conv_outputs, name=conv_name) # make Clip node - relu_node = onnx.helper.make_node('Relu', conv_outputs, [output_name], name='Relu') + relu_node = onnx.helper.make_node("Relu", conv_outputs, [output_name], name="Relu") # make graph input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, input_shape) output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, output_shape) - graph_name = 'QDQ_Test_Conv_clip' - graph = helper.make_graph([conv_node, relu_node], graph_name, - [input_tensor], [output_tensor], initializer=initializers) + graph_name = "QDQ_Test_Conv_clip" + graph = helper.make_graph( + [conv_node, relu_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) - model.ir_version = 7 # use stable onnx ir version + model.ir_version = 7 # use stable onnx ir version onnx.save(model, output_model_path) def verify(self, per_channel, is_quant_type_int8): np.random.seed(1) - model_fp32_path = 'conv_relu_fp32.{}.onnx'.format(per_channel) - model_int8_qdq_path = 'conv_relu_quant_qdq.{}.onnx'.format(per_channel) - model_int8_qop_path = 'conv_relu_quant_qop.{}.onnx'.format(per_channel) - data_reader = self.input_feeds(1, {'input': [1, 8, 33, 33]}) - self.construct_model_conv_relu(model_fp32_path, - [1, 8, 33, 33], - [16, 8, 3, 3], - [1, 16, 31, 31]) - quantize_static(model_fp32_path, - model_int8_qdq_path, - data_reader, - quant_format=QuantFormat.QDQ, - per_channel = per_channel, - reduce_range = per_channel, - activation_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8 - ) + model_fp32_path = "conv_relu_fp32.{}.onnx".format(per_channel) + model_int8_qdq_path = "conv_relu_quant_qdq.{}.onnx".format(per_channel) + model_int8_qop_path = "conv_relu_quant_qop.{}.onnx".format(per_channel) + data_reader = self.input_feeds(1, {"input": [1, 8, 33, 33]}) + self.construct_model_conv_relu(model_fp32_path, [1, 8, 33, 33], [16, 8, 3, 3], [1, 16, 31, 31]) + quantize_static( + model_fp32_path, + model_int8_qdq_path, + data_reader, + quant_format=QuantFormat.QDQ, + per_channel=per_channel, + reduce_range=per_channel, + activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + ) data_reader.rewind() - #topo sort check - check_op_type_order(self, model_int8_qdq_path, ['DequantizeLinear', 'QuantizeLinear', 'DequantizeLinear', 'Conv', 'QuantizeLinear', 'DequantizeLinear']) + # topo sort check + check_op_type_order( + self, + model_int8_qdq_path, + [ + "DequantizeLinear", + "QuantizeLinear", + "DequantizeLinear", + "Conv", + "QuantizeLinear", + "DequantizeLinear", + ], + ) check_model_correctness(self, model_fp32_path, model_int8_qdq_path, data_reader.get_next()) data_reader.rewind() - quantize_static(model_fp32_path, - model_int8_qop_path, - data_reader, - quant_format=QuantFormat.QOperator, - per_channel = per_channel, - reduce_range = per_channel, - activation_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, - weight_type = QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8 - ) + quantize_static( + model_fp32_path, + model_int8_qop_path, + data_reader, + quant_format=QuantFormat.QOperator, + per_channel=per_channel, + reduce_range=per_channel, + activation_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + weight_type=QuantType.QInt8 if is_quant_type_int8 else QuantType.QUInt8, + ) data_reader.rewind() - qop_nodes = {'QLinearConv': 1, 'QuantizeLinear': 1, 'DequantizeLinear': 1} + qop_nodes = {"QLinearConv": 1, "QuantizeLinear": 1, "DequantizeLinear": 1} check_op_type_count(self, model_int8_qop_path, **qop_nodes) check_model_correctness(self, model_fp32_path, model_int8_qop_path, data_reader.get_next()) def test_quantize_conv_without_bias(self): # only test cases per_channel=True and reduce_range=True to avoid saturation on avx2 and avx512 for weight type int8 - self.verify(True, True) # per_channel:False, is_quant_type_int8:True + self.verify(True, True) # per_channel:False, is_quant_type_int8:True + + self.verify(False, False) # per_channel:False, is_quant_type_int8:False + self.verify(True, False) # per_channel:True, is_quant_type_int8:False - self.verify(False, False) # per_channel:False, is_quant_type_int8:False - self.verify(True, False) # per_channel:True, is_quant_type_int8:False -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/quantization/test_symmetric_flag.py b/onnxruntime/test/python/quantization/test_symmetric_flag.py index 62d632faec65e..26f7ba6ce59b3 100644 --- a/onnxruntime/test/python/quantization/test_symmetric_flag.py +++ b/onnxruntime/test/python/quantization/test_symmetric_flag.py @@ -7,127 +7,154 @@ # -------------------------------------------------------------------------- import unittest + +import numpy as np import onnx +from onnx import TensorProto, helper, numpy_helper + from onnxruntime import quantization -import numpy as np -from onnx import helper, TensorProto, numpy_helper + class TestSymmetricFlag(unittest.TestCase): + def setUp(self): + + # Set up symmetrically and asymmetrically disributed values for activations + self.symmetric_activations = [ + -1 * np.ones([1, 2, 32, 32], dtype="float32"), + +1 * np.ones([1, 2, 32, 32], dtype="float32"), + ] + self.asymmetric_activations = [ + -1 * np.ones([1, 2, 32, 32], dtype="float32"), + +2 * np.ones([1, 2, 32, 32], dtype="float32"), + ] + + # Set up symmetrically and asymmetrically disributed values for weights + self.symmetric_weights = np.concatenate( + ( + -1 * np.ones([1, 1, 2, 2], dtype="float32"), + +1 * np.ones([1, 1, 2, 2], dtype="float32"), + ), + axis=1, + ) + self.asymmetric_weights = np.concatenate( + ( + -1 * np.ones([1, 1, 2, 2], dtype="float32"), + +2 * np.ones([1, 1, 2, 2], dtype="float32"), + ), + axis=1, + ) + + def perform_quantization(self, activations, weight, act_sym, wgt_sym): + + # One-layer convolution model + act = helper.make_tensor_value_info("ACT", TensorProto.FLOAT, activations[0].shape) + wgt = helper.make_tensor_value_info("WGT", TensorProto.FLOAT, weight.shape) + res = helper.make_tensor_value_info("RES", TensorProto.FLOAT, [None, None, None, None]) + wgt_init = numpy_helper.from_array(weight, "WGT") + conv_node = onnx.helper.make_node("Conv", ["ACT", "WGT"], ["RES"]) + graph = helper.make_graph([conv_node], "test", [act], [res], initializer=[wgt_init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) + onnx.save(model, "model.onnx") + + # Quantize model + class DummyDataReader(quantization.CalibrationDataReader): + def __init__(self): + self.iterator = ({"ACT": act} for act in activations) + + def get_next(self): + return next(self.iterator, None) + + quantization.quantize_static( + model_input="model.onnx", + model_output="quantized-model.onnx", + calibration_data_reader=DummyDataReader(), + quant_format=quantization.QuantFormat.QOperator, + activation_type=quantization.QuantType.QInt8, + weight_type=quantization.QuantType.QInt8, + op_types_to_quantize=["Conv", "MatMul"], + extra_options={"WeightSymmetric": wgt_sym, "ActivationSymmetric": act_sym}, + ) + + # Extract quantization parameters: scales and zero points for activations, weights, and results + model = onnx.load("quantized-model.onnx") + act_zp = [init for init in model.graph.initializer if init.name == "ACT_zero_point"][0].int32_data[0] + act_sc = [init for init in model.graph.initializer if init.name == "ACT_scale"][0].float_data[0] + wgt_zp = [init for init in model.graph.initializer if init.name == "WGT_zero_point"][0].int32_data[0] + wgt_sc = [init for init in model.graph.initializer if init.name == "WGT_scale"][0].float_data[0] + + # Return quantization parameters + return act_zp, act_sc, wgt_zp, wgt_sc + + def test_0(self): + + act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization( + self.asymmetric_activations, + self.asymmetric_weights, + act_sym=True, + wgt_sym=True, + ) + + # Calibration activations are asymmetric, but activation + # symmetrization flag is set to True, hence expect activation zero + # point = 0 + self.assertEqual(act_zp, 0) + + # Weights are asymmetric, but weight symmetrization flag is set to + # True, hence expect weight zero point = 0 + self.assertEqual(wgt_zp, 0) + + def test_1(self): + + act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization( + self.asymmetric_activations, + self.asymmetric_weights, + act_sym=False, + wgt_sym=False, + ) + + # Calibration activations are asymmetric, symmetrization flag not + # set, hence expect activation zero point != 0 + self.assertNotEqual(act_zp, 0) + + # Weights are asymmetric, weight symmetrization flag is set to + # False, hence expect weight zero point != 0 + self.assertNotEqual(wgt_zp, 0) + + def test_2(self): + + act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization( + self.symmetric_activations, + self.symmetric_weights, + act_sym=True, + wgt_sym=True, + ) + + # Calibration activations are symmetric, hence expect activation + # zero point == 0 (regardless of flag) + self.assertEqual(act_zp, 0) + + # Weights are symmetric, hence expect weight + # zero point == 0 (regardless of flag) + self.assertEqual(wgt_zp, 0) + + def test_3(self): + + act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization( + self.symmetric_activations, + self.symmetric_weights, + act_sym=False, + wgt_sym=False, + ) + + # Calibration activations are symmetric, hence expect activation + # zero point == 0 (regardless of flag) + self.assertEqual(act_zp, 0) + + # Weights are symmetric, hence expect weight + # zero point == 0 (regardless of flag) + self.assertEqual(wgt_zp, 0) + + +if __name__ == "__main__": - def setUp(self): - - # Set up symmetrically and asymmetrically disributed values for activations - self.symmetric_activations = [-1*np.ones([1, 2, 32, 32], dtype="float32"), - +1*np.ones([1, 2, 32, 32], dtype="float32")] - self.asymmetric_activations = [-1*np.ones([1, 2, 32, 32], dtype="float32"), - +2*np.ones([1, 2, 32, 32], dtype="float32")] - - # Set up symmetrically and asymmetrically disributed values for weights - self.symmetric_weights = np.concatenate((-1*np.ones([1, 1, 2, 2], dtype="float32"), - +1*np.ones([1, 1, 2, 2], dtype="float32")), axis = 1) - self.asymmetric_weights = np.concatenate((-1*np.ones([1, 1, 2, 2], dtype="float32"), - +2*np.ones([1, 1, 2, 2], dtype="float32")), axis = 1) - - - def perform_quantization(self, activations, weight, act_sym, wgt_sym): - - # One-layer convolution model - act = helper.make_tensor_value_info("ACT", TensorProto.FLOAT, activations[0].shape) - wgt = helper.make_tensor_value_info("WGT", TensorProto.FLOAT, weight.shape) - res = helper.make_tensor_value_info("RES", TensorProto.FLOAT, [None, None, None, None]) - wgt_init = numpy_helper.from_array(weight, "WGT") - conv_node = onnx.helper.make_node("Conv", ["ACT", "WGT"], ["RES"]) - graph = helper.make_graph([conv_node], "test", [act], [res], initializer=[wgt_init]) - model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 11)]) - onnx.save(model, "model.onnx") - - # Quantize model - class DummyDataReader(quantization.CalibrationDataReader): - def __init__(self): - self.iterator = ({"ACT": act} for act in activations) - def get_next(self): - return next(self.iterator, None) - quantization.quantize_static(model_input="model.onnx", - model_output="quantized-model.onnx", - calibration_data_reader=DummyDataReader(), - quant_format=quantization.QuantFormat.QOperator, - activation_type=quantization.QuantType.QInt8, - weight_type=quantization.QuantType.QInt8, - op_types_to_quantize=["Conv", "MatMul"], - extra_options = {"WeightSymmetric": wgt_sym, - "ActivationSymmetric": act_sym}) - - # Extract quantization parameters: scales and zero points for activations, weights, and results - model = onnx.load("quantized-model.onnx") - act_zp = [init for init in model.graph.initializer if init.name=="ACT_zero_point"][0].int32_data[0] - act_sc = [init for init in model.graph.initializer if init.name=="ACT_scale"][0].float_data[0] - wgt_zp = [init for init in model.graph.initializer if init.name=="WGT_zero_point"][0].int32_data[0] - wgt_sc = [init for init in model.graph.initializer if init.name=="WGT_scale"][0].float_data[0] - - # Return quantization parameters - return act_zp, act_sc, wgt_zp, wgt_sc - - def test_0(self): - - act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization(self.asymmetric_activations, - self.asymmetric_weights, - act_sym = True, - wgt_sym = True) - - # Calibration activations are asymmetric, but activation - # symmetrization flag is set to True, hence expect activation zero - # point = 0 - self.assertEqual(act_zp, 0) - - # Weights are asymmetric, but weight symmetrization flag is set to - # True, hence expect weight zero point = 0 - self.assertEqual(wgt_zp, 0) - - def test_1(self): - - act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization(self.asymmetric_activations, - self.asymmetric_weights, - act_sym = False, - wgt_sym = False) - - # Calibration activations are asymmetric, symmetrization flag not - # set, hence expect activation zero point != 0 - self.assertNotEqual(act_zp, 0) - - # Weights are asymmetric, weight symmetrization flag is set to - # False, hence expect weight zero point != 0 - self.assertNotEqual(wgt_zp, 0) - - def test_2(self): - - act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization(self.symmetric_activations, - self.symmetric_weights, - act_sym = True, - wgt_sym = True) - - # Calibration activations are symmetric, hence expect activation - # zero point == 0 (regardless of flag) - self.assertEqual(act_zp, 0) - - # Weights are symmetric, hence expect weight - # zero point == 0 (regardless of flag) - self.assertEqual(wgt_zp, 0) - - def test_3(self): - - act_zp, act_sc, wgt_zp, wgt_sc = self.perform_quantization(self.symmetric_activations, - self.symmetric_weights, - act_sym = False, - wgt_sym = False) - - # Calibration activations are symmetric, hence expect activation - # zero point == 0 (regardless of flag) - self.assertEqual(act_zp, 0) - - # Weights are symmetric, hence expect weight - # zero point == 0 (regardless of flag) - self.assertEqual(wgt_zp, 0) - -if __name__ == '__main__': - unittest.main() diff --git a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py index b40ecd73acaf2..bfb6dc48a617b 100644 --- a/onnxruntime/test/python/test_pytorch_export_contrib_ops.py +++ b/onnxruntime/test/python/test_pytorch_export_contrib_ops.py @@ -3,13 +3,15 @@ """Test export of PyTorch operators using ONNX Runtime contrib ops.""" +import copy +import io +import unittest + +import numpy as np import torch + import onnxruntime from onnxruntime.tools import pytorch_export_contrib_ops -import numpy as np -import unittest -import io -import copy def ort_test_with_input(ort_sess, input, output, rtol, atol): @@ -39,6 +41,7 @@ def to_numpy(tensor): # PyTorch and ORT. class ONNXExporterTest(unittest.TestCase): from torch.onnx.symbolic_helper import _export_onnx_opset_version + opset_version = _export_onnx_opset_version keep_initializers_as_inputs = True # For IR version 3 type export. @@ -46,13 +49,20 @@ def setUp(self): torch.manual_seed(0) pytorch_export_contrib_ops.register() - def run_test(self, model, input=None, - custom_opsets=None, - batch_size=2, - rtol=0.001, atol=1e-7, - do_constant_folding=True, - dynamic_axes=None, test_with_inputs=None, - input_names=None, output_names=None): + def run_test( + self, + model, + input=None, + custom_opsets=None, + batch_size=2, + rtol=0.001, + atol=1e-7, + do_constant_folding=True, + dynamic_axes=None, + test_with_inputs=None, + input_names=None, + output_names=None, + ): model.eval() if input is None: @@ -70,17 +80,21 @@ def run_test(self, model, input=None, # export the model to ONNX f = io.BytesIO() - torch.onnx.export(model, input_copy, f, - opset_version=self.opset_version, - do_constant_folding=do_constant_folding, - keep_initializers_as_inputs=self.keep_initializers_as_inputs, - dynamic_axes=dynamic_axes, - input_names=input_names, output_names=output_names, - custom_opsets=custom_opsets) + torch.onnx.export( + model, + input_copy, + f, + opset_version=self.opset_version, + do_constant_folding=do_constant_folding, + keep_initializers_as_inputs=self.keep_initializers_as_inputs, + dynamic_axes=dynamic_axes, + input_names=input_names, + output_names=output_names, + custom_opsets=custom_opsets, + ) # compute onnxruntime output prediction - ort_sess = onnxruntime.InferenceSession(f.getvalue(), - providers=onnxruntime.get_available_providers()) + ort_sess = onnxruntime.InferenceSession(f.getvalue(), providers=onnxruntime.get_available_providers()) input_copy = copy.deepcopy(input) ort_test_with_input(ort_sess, input_copy, output, rtol, atol) @@ -111,6 +125,7 @@ def test_gelu(self): def test_triu(self): for i in range(-5, 5): + class Module(torch.nn.Module): def forward(self, input): return input.triu(diagonal=i) @@ -126,6 +141,7 @@ def forward(self, input): self.run_test(model, x, custom_opsets={"com.microsoft": 1}) for i in range(-5, 5): + class Module2D(torch.nn.Module): def forward(self, input): return input.triu(diagonal=i) @@ -142,6 +158,7 @@ def forward(self, input): def test_tril(self): for i in range(-5, 5): + class Module(torch.nn.Module): def forward(self, input): return input.tril(diagonal=i) @@ -157,6 +174,7 @@ def forward(self, input): self.run_test(model, x, custom_opsets={"com.microsoft": 1}) for i in range(-5, 5): + class Module2D(torch.nn.Module): def forward(self, input): return input.tril(diagonal=i) @@ -174,10 +192,11 @@ def forward(self, input): # opset 9 tests, with keep_initializers_as_inputs=False for # IR version 4 style export. -ONNXExporterTest_opset9_IRv4 = type(str("TestONNXRuntime_opset9_IRv4"), - (unittest.TestCase,), - dict(ONNXExporterTest.__dict__, - keep_initializers_as_inputs=False)) +ONNXExporterTest_opset9_IRv4 = type( + str("TestONNXRuntime_opset9_IRv4"), + (unittest.TestCase,), + dict(ONNXExporterTest.__dict__, keep_initializers_as_inputs=False), +) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/bert_model_generator.py b/onnxruntime/test/python/transformers/bert_model_generator.py index 4764334adab5f..6534aedba3978 100644 --- a/onnxruntime/test/python/transformers/bert_model_generator.py +++ b/onnxruntime/test/python/transformers/bert_model_generator.py @@ -4,12 +4,13 @@ # license information. # -------------------------------------------------------------------------- -import onnx import math -import numpy as np from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper from packaging import version -from onnx import helper, TensorProto def float_tensor(name: str, shape: List[int], random=False): @@ -28,115 +29,226 @@ def reverse_if(inputs, reverse=False): return inputs -def create_bert_attention(input_hidden_size=16, - num_heads=2, - pruned_qk_hidden_size=16, - pruned_v_hidden_size=16, - use_float_mask=False, - switch_add_inputs=False): +def create_bert_attention( + input_hidden_size=16, + num_heads=2, + pruned_qk_hidden_size=16, + pruned_v_hidden_size=16, + use_float_mask=False, + switch_add_inputs=False, +): # unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input). - has_unsqueeze_two_inputs = (version.parse(onnx.__version__) >= version.parse('1.8.0')) + has_unsqueeze_two_inputs = version.parse(onnx.__version__) >= version.parse("1.8.0") # nodes in attention subgraph nodes = [ helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"), - helper.make_node("LayerNormalization", ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], - ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999747378752), - + helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), # q nodes helper.make_node("MatMul", ["layernorm_out", "matmul_q_weight"], ["matmul_q_out"], "matmul_q"), - helper.make_node("Add", reverse_if(["matmul_q_out", "add_q_weight"], switch_add_inputs), ["add_q_out"], "add_q"), - helper.make_node("Reshape", ["add_q_out", "reshape_weight_qk"], ["reshape_q_out"], "reshape_q"), - helper.make_node("Transpose", ["reshape_q_out"], ["transpose_q_out"], "transpose_q", perm=[0, 2, 1, 3]), - + helper.make_node( + "Add", + reverse_if(["matmul_q_out", "add_q_weight"], switch_add_inputs), + ["add_q_out"], + "add_q", + ), + helper.make_node( + "Reshape", + ["add_q_out", "reshape_weight_qk"], + ["reshape_q_out"], + "reshape_q", + ), + helper.make_node( + "Transpose", + ["reshape_q_out"], + ["transpose_q_out"], + "transpose_q", + perm=[0, 2, 1, 3], + ), # k nodes helper.make_node("MatMul", ["layernorm_out", "matmul_k_weight"], ["matmul_k_out"], "matmul_k"), - helper.make_node("Add", reverse_if(["matmul_k_out", "add_k_weight"], switch_add_inputs), ["add_k_out"], "add_k"), - helper.make_node("Reshape", ["add_k_out", "reshape_weight_qk"], ["reshape_k_out"], "reshape_k"), - helper.make_node("Transpose", ["reshape_k_out"], ["transpose_k_out"], "transpose_k", perm=[0, 2, 3, 1]), - + helper.make_node( + "Add", + reverse_if(["matmul_k_out", "add_k_weight"], switch_add_inputs), + ["add_k_out"], + "add_k", + ), + helper.make_node( + "Reshape", + ["add_k_out", "reshape_weight_qk"], + ["reshape_k_out"], + "reshape_k", + ), + helper.make_node( + "Transpose", + ["reshape_k_out"], + ["transpose_k_out"], + "transpose_k", + perm=[0, 2, 3, 1], + ), # mask nodes - helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") if has_unsqueeze_two_inputs \ - else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1]), - helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1") if has_unsqueeze_two_inputs \ - else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]), - + helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") + if has_unsqueeze_two_inputs + else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1]), + helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1") + if has_unsqueeze_two_inputs + else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]), # when attention_mask is float type, no need to cast helper.make_node("Cast", ["unsqueeze1_out"], ["cast_out"], "cast", to=1) if not use_float_mask else None, - helper.make_node("Sub", ["sub_weight", "unsqueeze1_out" if use_float_mask else "cast_out"], ["sub_out"], "sub"), + helper.make_node( + "Sub", + ["sub_weight", "unsqueeze1_out" if use_float_mask else "cast_out"], + ["sub_out"], + "sub", + ), helper.make_node("Mul", ["sub_out", "mul_weight"], ["mul_mask_out"], "mul_mask"), - # qk nodes - helper.make_node("MatMul", ["transpose_q_out", "transpose_k_out"], ["matmul_qk_out"], "matmul_qk"), + helper.make_node( + "MatMul", + ["transpose_q_out", "transpose_k_out"], + ["matmul_qk_out"], + "matmul_qk", + ), helper.make_node("Div", ["matmul_qk_out", "div_weight"], ["div_qk_out"], "div_qk"), - helper.make_node("Add", reverse_if(["div_qk_out", "mul_mask_out"], switch_add_inputs), ["add_qk_out"], "add_qk"), + helper.make_node( + "Add", + reverse_if(["div_qk_out", "mul_mask_out"], switch_add_inputs), + ["add_qk_out"], + "add_qk", + ), helper.make_node("Softmax", ["add_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3), - # v nodes helper.make_node("MatMul", ["layernorm_out", "matmul_v_weight"], ["matmul_v_out"], "matmul_v"), helper.make_node("Add", ["matmul_v_out", "add_v_weight"], ["add_v_out"], "add_v"), helper.make_node("Reshape", ["add_v_out", "reshape_weight_v"], ["reshape_v_out"], "reshape_v"), - helper.make_node("Transpose", ["reshape_v_out"], ["transpose_v_out"], "transpose_v", perm=[0, 2, 1, 3]), - + helper.make_node( + "Transpose", + ["reshape_v_out"], + ["transpose_v_out"], + "transpose_v", + perm=[0, 2, 1, 3], + ), # qkv nodes - helper.make_node("MatMul", ["softmax_qk_out", "transpose_v_out"], ["matmul_qkv_1_out"], "matmul_qkv_1"), - helper.make_node("Transpose", ["matmul_qkv_1_out"], ["transpose_qkv_out"], "transpose_qkv", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["transpose_qkv_out", "reshape_weight_qkv"], ["reshape_qkv_out"], "reshape_qkv"), - helper.make_node("MatMul", ["reshape_qkv_out", "matmul_qkv_weight"], ["matmul_qkv_2_out"], "matmul_qkv_2"), - helper.make_node("Add", reverse_if(["matmul_qkv_2_out", "add_qkv_weight"], switch_add_inputs), ["add_qkv_out"], "add_qkv"), - helper.make_node("Add", reverse_if(["add_qkv_out", "layernorm_out"], switch_add_inputs), ["skip_output"], "add_skip"), - helper.make_node("LayerNormalization", ["skip_output", "layer_norm_weight", "layer_norm_bias"], ["output"], - "layernorm2", - axis=-1, - epsion=0.000009999999747378752), + helper.make_node( + "MatMul", + ["softmax_qk_out", "transpose_v_out"], + ["matmul_qkv_1_out"], + "matmul_qkv_1", + ), + helper.make_node( + "Transpose", + ["matmul_qkv_1_out"], + ["transpose_qkv_out"], + "transpose_qkv", + perm=[0, 2, 1, 3], + ), + helper.make_node( + "Reshape", + ["transpose_qkv_out", "reshape_weight_qkv"], + ["reshape_qkv_out"], + "reshape_qkv", + ), + helper.make_node( + "MatMul", + ["reshape_qkv_out", "matmul_qkv_weight"], + ["matmul_qkv_2_out"], + "matmul_qkv_2", + ), + helper.make_node( + "Add", + reverse_if(["matmul_qkv_2_out", "add_qkv_weight"], switch_add_inputs), + ["add_qkv_out"], + "add_qkv", + ), + helper.make_node( + "Add", + reverse_if(["add_qkv_out", "layernorm_out"], switch_add_inputs), + ["skip_output"], + "add_skip", + ), + helper.make_node( + "LayerNormalization", + ["skip_output", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm2", + axis=-1, + epsion=0.000009999999747378752, + ), ] pruned_qk_head_size = int(pruned_qk_hidden_size / num_heads) pruned_v_head_size = int(pruned_v_hidden_size / num_heads) initializers = [ # initializers - float_tensor('layer_norm_weight', [input_hidden_size]), - float_tensor('layer_norm_bias', [input_hidden_size]), - float_tensor('matmul_q_weight', [input_hidden_size, pruned_qk_hidden_size]), - float_tensor('matmul_k_weight', [input_hidden_size, pruned_qk_hidden_size]), - float_tensor('matmul_v_weight', [input_hidden_size, pruned_v_hidden_size]), - float_tensor('matmul_qkv_weight', [pruned_v_hidden_size, input_hidden_size]), - float_tensor('add_q_weight', [pruned_qk_hidden_size]), - float_tensor('add_k_weight', [pruned_qk_hidden_size]), - float_tensor('add_v_weight', [pruned_v_hidden_size]), - float_tensor('add_qkv_weight', [input_hidden_size]), - helper.make_tensor('div_weight', TensorProto.FLOAT, [1], [math.sqrt(pruned_qk_head_size)]), - helper.make_tensor('sub_weight', TensorProto.FLOAT, [1], [1.0]), - helper.make_tensor('mul_weight', TensorProto.FLOAT, [1], [-10000]), - helper.make_tensor('reshape_weight_qk', TensorProto.INT64, [4], [0, 0, num_heads, pruned_qk_head_size]), - helper.make_tensor('reshape_weight_v', TensorProto.INT64, [4], [0, 0, num_heads, pruned_v_head_size]), - helper.make_tensor('reshape_weight_qkv', TensorProto.INT64, [3], [0, 0, pruned_v_hidden_size]), + float_tensor("layer_norm_weight", [input_hidden_size]), + float_tensor("layer_norm_bias", [input_hidden_size]), + float_tensor("matmul_q_weight", [input_hidden_size, pruned_qk_hidden_size]), + float_tensor("matmul_k_weight", [input_hidden_size, pruned_qk_hidden_size]), + float_tensor("matmul_v_weight", [input_hidden_size, pruned_v_hidden_size]), + float_tensor("matmul_qkv_weight", [pruned_v_hidden_size, input_hidden_size]), + float_tensor("add_q_weight", [pruned_qk_hidden_size]), + float_tensor("add_k_weight", [pruned_qk_hidden_size]), + float_tensor("add_v_weight", [pruned_v_hidden_size]), + float_tensor("add_qkv_weight", [input_hidden_size]), + helper.make_tensor("div_weight", TensorProto.FLOAT, [1], [math.sqrt(pruned_qk_head_size)]), + helper.make_tensor("sub_weight", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("mul_weight", TensorProto.FLOAT, [1], [-10000]), + helper.make_tensor( + "reshape_weight_qk", + TensorProto.INT64, + [4], + [0, 0, num_heads, pruned_qk_head_size], + ), + helper.make_tensor( + "reshape_weight_v", + TensorProto.INT64, + [4], + [0, 0, num_heads, pruned_v_head_size], + ), + helper.make_tensor("reshape_weight_qkv", TensorProto.INT64, [3], [0, 0, pruned_v_hidden_size]), ] if has_unsqueeze_two_inputs: - initializers.append(helper.make_tensor('axes_1', TensorProto.INT64, [1], [1])) - initializers.append(helper.make_tensor('axes_2', TensorProto.INT64, [1], [2])) + initializers.append(helper.make_tensor("axes_1", TensorProto.INT64, [1], [1])) + initializers.append(helper.make_tensor("axes_2", TensorProto.INT64, [1], [2])) batch_size = 1 sequence_length = 3 graph = helper.make_graph( [node for node in nodes if node], - "AttentionFusionPrunedModel", #name + "AttentionFusionPrunedModel", # name [ # inputs - helper.make_tensor_value_info('input_1', TensorProto.FLOAT, - [batch_size, sequence_length, input_hidden_size]), - helper.make_tensor_value_info('input_2', TensorProto.FLOAT, - [batch_size, sequence_length, input_hidden_size]), - helper.make_tensor_value_info('input_mask', TensorProto.FLOAT if use_float_mask else TensorProto.INT64, - [batch_size, sequence_length]) + helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), + helper.make_tensor_value_info( + "input_2", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), + helper.make_tensor_value_info( + "input_mask", + TensorProto.FLOAT if use_float_mask else TensorProto.INT64, + [batch_size, sequence_length], + ), ], [ # outputs - helper.make_tensor_value_info('output', TensorProto.FLOAT, - [batch_size, sequence_length, input_hidden_size]), + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) return model @@ -144,98 +256,163 @@ def create_bert_attention(input_hidden_size=16, def create_tf2onnx_attention_3d(input_hidden_size=16, num_heads=4, head_size=4, use_float_mask=False): # unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input). - has_unsqueeze_two_inputs = (version.parse(onnx.__version__) >= version.parse('1.8.0')) + has_unsqueeze_two_inputs = version.parse(onnx.__version__) >= version.parse("1.8.0") # nodes in attention subgraph nodes = [ helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"), - helper.make_node("LayerNormalization", ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], - ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999747378752), - + helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), # q nodes - helper.make_node("Einsum", ["layernorm_out", "einsum_q_weight"], ["einsum_q_out"], "einsum_q", equation="abc,cde->abde"), + helper.make_node( + "Einsum", + ["layernorm_out", "einsum_q_weight"], + ["einsum_q_out"], + "einsum_q", + equation="abc,cde->abde", + ), helper.make_node("Add", ["einsum_q_out", "add_q_weight"], ["add_q_out"], "add_q"), - # k nodes - helper.make_node("Einsum", ["layernorm_out", "einsum_k_weight"], ["einsum_k_out"], "einsum_k", equation="abc,cde->abde"), + helper.make_node( + "Einsum", + ["layernorm_out", "einsum_k_weight"], + ["einsum_k_out"], + "einsum_k", + equation="abc,cde->abde", + ), helper.make_node("Add", ["einsum_k_out", "add_k_weight"], ["add_k_out"], "add_k"), helper.make_node("Mul", ["add_k_out", "mul_weight_1"], ["mul_k_out"], "mul_k"), - # mask nodes - helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") if has_unsqueeze_two_inputs \ - else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1, 2]), - helper.make_node("Slice", ["unsqueeze0_out", "slice_start", "slice_end", "slice_axes", "slice_steps"], ["slice_out"], "slice"), - + helper.make_node("Unsqueeze", ["input_mask", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") + if has_unsqueeze_two_inputs + else helper.make_node("Unsqueeze", ["input_mask"], ["unsqueeze0_out"], "unsqueeze0", axes=[1, 2]), + helper.make_node( + "Slice", + ["unsqueeze0_out", "slice_start", "slice_end", "slice_axes", "slice_steps"], + ["slice_out"], + "slice", + ), # when attention_mask is float type, no need to cast helper.make_node("Cast", ["slice_out"], ["cast_out"], "cast", to=1) if not use_float_mask else None, - helper.make_node("Sub", ["sub_weight", "unsqueeze1_out" if use_float_mask else "cast_out"], ["sub_out"], "sub"), + helper.make_node( + "Sub", + ["sub_weight", "unsqueeze1_out" if use_float_mask else "cast_out"], + ["sub_out"], + "sub", + ), helper.make_node("Mul", ["sub_out", "mul_weight_2"], ["mul_mask_out"], "mul_mask"), - # qk nodes - helper.make_node("Einsum", ["add_q_out", "mul_k_out"], ["einsum_qk_out"], "einsum_qk", equation="aecd,abcd->acbe"), + helper.make_node( + "Einsum", + ["add_q_out", "mul_k_out"], + ["einsum_qk_out"], + "einsum_qk", + equation="aecd,abcd->acbe", + ), helper.make_node("Add", ["einsum_qk_out", "mul_mask_out"], ["add_qk_out"], "add_qk"), helper.make_node("Softmax", ["add_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3), - # v nodes - helper.make_node("Einsum", ["layernorm_out", "einsum_v_weight"], ["einsum_v_out"], "einsum_v", equation="abc,cde->abde"), + helper.make_node( + "Einsum", + ["layernorm_out", "einsum_v_weight"], + ["einsum_v_out"], + "einsum_v", + equation="abc,cde->abde", + ), helper.make_node("Add", ["einsum_v_out", "add_v_weight"], ["add_v_out"], "add_v"), - # qkv nodes - helper.make_node("Einsum", ["softmax_qk_out", "add_v_out"], ["einsum_qkv_1_out"], "einsum_qkv_1", equation="acbe,aecd->abcd"), - helper.make_node("Einsum", ["einsum_qkv_1_out", "einsum_qkv_weight"], ["einsum_qkv_2_out"], "einsum_qkv_2", equation="abcd,cde->abe"), + helper.make_node( + "Einsum", + ["softmax_qk_out", "add_v_out"], + ["einsum_qkv_1_out"], + "einsum_qkv_1", + equation="acbe,aecd->abcd", + ), + helper.make_node( + "Einsum", + ["einsum_qkv_1_out", "einsum_qkv_weight"], + ["einsum_qkv_2_out"], + "einsum_qkv_2", + equation="abcd,cde->abe", + ), helper.make_node("Add", ["einsum_qkv_2_out", "add_qkv_weight"], ["add_qkv_out"], "add_qkv"), helper.make_node("Add", ["add_qkv_out", "layernorm_out"], ["skip_output"], "add_skip"), - helper.make_node("LayerNormalization", ["skip_output", "layer_norm_weight", "layer_norm_bias"], ["output"], - "layernorm2", - axis=-1, - epsion=0.000009999999747378752), + helper.make_node( + "LayerNormalization", + ["skip_output", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm2", + axis=-1, + epsion=0.000009999999747378752, + ), ] initializers = [ # initializers - float_tensor('layer_norm_weight', [input_hidden_size]), - float_tensor('layer_norm_bias', [input_hidden_size]), - float_tensor('einsum_q_weight', [input_hidden_size, num_heads, head_size]), - float_tensor('einsum_k_weight', [input_hidden_size, num_heads, head_size]), - float_tensor('einsum_v_weight', [input_hidden_size, num_heads, head_size]), - float_tensor('einsum_qkv_weight', [num_heads, head_size, input_hidden_size]), - float_tensor('add_q_weight', [num_heads, head_size]), - float_tensor('add_k_weight', [num_heads, head_size]), - float_tensor('add_v_weight', [num_heads, head_size]), - float_tensor('add_qkv_weight', [input_hidden_size]), - helper.make_tensor('sub_weight', TensorProto.FLOAT, [1], [1.0]), - helper.make_tensor('mul_weight_1', TensorProto.FLOAT, [1], [-10000]), - helper.make_tensor('mul_weight_2', TensorProto.FLOAT, [1], [0.125]), - helper.make_tensor('reshape_weight_1', TensorProto.INT64, [4], [0, 0, num_heads, head_size]), - helper.make_tensor('slice_start', TensorProto.INT32, [4], [0, 0, 0, 0]), - helper.make_tensor('slice_end', TensorProto.INT32, [4], [1000000000, 1000000000, 1000000000, 1000000000]), - helper.make_tensor('slice_axes', TensorProto.INT32, [4], [0, 1, 2, 3]), - helper.make_tensor('slice_steps', TensorProto.INT32, [4], [1, 1, 1, 1]) + float_tensor("layer_norm_weight", [input_hidden_size]), + float_tensor("layer_norm_bias", [input_hidden_size]), + float_tensor("einsum_q_weight", [input_hidden_size, num_heads, head_size]), + float_tensor("einsum_k_weight", [input_hidden_size, num_heads, head_size]), + float_tensor("einsum_v_weight", [input_hidden_size, num_heads, head_size]), + float_tensor("einsum_qkv_weight", [num_heads, head_size, input_hidden_size]), + float_tensor("add_q_weight", [num_heads, head_size]), + float_tensor("add_k_weight", [num_heads, head_size]), + float_tensor("add_v_weight", [num_heads, head_size]), + float_tensor("add_qkv_weight", [input_hidden_size]), + helper.make_tensor("sub_weight", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("mul_weight_1", TensorProto.FLOAT, [1], [-10000]), + helper.make_tensor("mul_weight_2", TensorProto.FLOAT, [1], [0.125]), + helper.make_tensor("reshape_weight_1", TensorProto.INT64, [4], [0, 0, num_heads, head_size]), + helper.make_tensor("slice_start", TensorProto.INT32, [4], [0, 0, 0, 0]), + helper.make_tensor( + "slice_end", + TensorProto.INT32, + [4], + [1000000000, 1000000000, 1000000000, 1000000000], + ), + helper.make_tensor("slice_axes", TensorProto.INT32, [4], [0, 1, 2, 3]), + helper.make_tensor("slice_steps", TensorProto.INT32, [4], [1, 1, 1, 1]), ] if has_unsqueeze_two_inputs: - initializers.append(helper.make_tensor('axes_1', TensorProto.INT64, [2], [1, 2])) + initializers.append(helper.make_tensor("axes_1", TensorProto.INT64, [2], [1, 2])) batch_size = 1 sequence_length = 3 graph = helper.make_graph( [node for node in nodes if node], - "AttentionFusionPrunedModel", #name + "AttentionFusionPrunedModel", # name [ # inputs - helper.make_tensor_value_info('input_1', TensorProto.FLOAT, - [batch_size, sequence_length, input_hidden_size]), - helper.make_tensor_value_info('input_2', TensorProto.FLOAT, - [batch_size, sequence_length, input_hidden_size]), - helper.make_tensor_value_info('input_mask', TensorProto.FLOAT if use_float_mask else TensorProto.INT64, - [batch_size, sequence_length]) + helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), + helper.make_tensor_value_info( + "input_2", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), + helper.make_tensor_value_info( + "input_mask", + TensorProto.FLOAT if use_float_mask else TensorProto.INT64, + [batch_size, sequence_length], + ), ], [ # outputs - helper.make_tensor_value_info('output', TensorProto.FLOAT, - [batch_size, sequence_length, input_hidden_size]), + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + [batch_size, sequence_length, input_hidden_size], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) return model diff --git a/onnxruntime/test/python/transformers/gpt2_model_generator.py b/onnxruntime/test/python/transformers/gpt2_model_generator.py index ce8bc3d0cddc8..9a34744b42e74 100644 --- a/onnxruntime/test/python/transformers/gpt2_model_generator.py +++ b/onnxruntime/test/python/transformers/gpt2_model_generator.py @@ -4,214 +4,544 @@ # license information. # -------------------------------------------------------------------------- -import onnx import math -import numpy from typing import List -from packaging import version -from onnx import helper, TensorProto + +import numpy +import onnx from bert_model_generator import float_tensor, reverse_if +from onnx import TensorProto, helper +from packaging import version def create_gpt2_attention(hidden_size=64, num_heads=4, max_seq_len=32, switch_add_inputs=False): # unsqueeze in opset version 13 has two inputs (axis is moved from attribute to input). - is_opset_13_or_newer = (version.parse(onnx.__version__) >= version.parse('1.8.0')) + is_opset_13_or_newer = version.parse(onnx.__version__) >= version.parse("1.8.0") # nodes in attention subgraph nodes = [ helper.make_node("Add", ["input_1", "input_2"], ["layernorm_input"], "add_layernorm"), - helper.make_node("LayerNormalization", ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], - ["layernorm_out"], - "layernorm", - epsion=0.000009999999747378752), - + helper.make_node( + "LayerNormalization", + ["layernorm_input", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + epsion=0.000009999999747378752, + ), # fully connection nodes - helper.make_node("MatMul", ["layernorm_out", "matmul_fc_weight"], ["matmul_fc_out"], "matmul_fc"), - helper.make_node("Add", reverse_if(["matmul_fc_out", "add_fc_weight"], switch_add_inputs), ["fc_out"], "add_fc"), - - helper.make_node("Split", ["fc_out", "split_q_k_v"], ["q", "k", "v"], "split_qkv", axis=2) if is_opset_13_or_newer \ - else helper.make_node("Split", ["fc_out"], ["q", "k", "v"], "split_qkv", axis=2, split=[hidden_size, hidden_size, hidden_size]), - + helper.make_node( + "MatMul", + ["layernorm_out", "matmul_fc_weight"], + ["matmul_fc_out"], + "matmul_fc", + ), + helper.make_node( + "Add", + reverse_if(["matmul_fc_out", "add_fc_weight"], switch_add_inputs), + ["fc_out"], + "add_fc", + ), + helper.make_node("Split", ["fc_out", "split_q_k_v"], ["q", "k", "v"], "split_qkv", axis=2) + if is_opset_13_or_newer + else helper.make_node( + "Split", + ["fc_out"], + ["q", "k", "v"], + "split_qkv", + axis=2, + split=[hidden_size, hidden_size, hidden_size], + ), # q nodes helper.make_node("Reshape", ["q", "reshape_x_shape"], ["reshape_q_out"], "reshape_q"), - helper.make_node("Transpose", ["reshape_q_out"], ["transpose_q_out"], "transpose_q", perm=[0, 2, 1, 3]), - + helper.make_node( + "Transpose", + ["reshape_q_out"], + ["transpose_q_out"], + "transpose_q", + perm=[0, 2, 1, 3], + ), # k nodes helper.make_node("Reshape", ["k", "reshape_x_shape"], ["reshape_k_out"], "reshape_k"), - helper.make_node("Transpose", ["reshape_k_out"], ["transpose_k_out"], "transpose_k", perm=[0, 2, 1, 3]), - + helper.make_node( + "Transpose", + ["reshape_k_out"], + ["transpose_k_out"], + "transpose_k", + perm=[0, 2, 1, 3], + ), # v nodes helper.make_node("Reshape", ["v", "reshape_x_shape"], ["reshape_v_out"], "reshape_v"), - helper.make_node("Transpose", ["reshape_v_out"], ["transpose_v_out"], "transpose_v", perm=[0, 2, 1, 3]), - + helper.make_node( + "Transpose", + ["reshape_v_out"], + ["transpose_v_out"], + "transpose_v", + perm=[0, 2, 1, 3], + ), # past - helper.make_node("Split", ["past", "split_1_1"], ["split_k", "split_v"], "split_past", axis=0) if is_opset_13_or_newer \ - else helper.make_node("Split", ["past"], ["split_k", "split_v"], "split_past", axis=0, split=[1, 1]), - - helper.make_node("Squeeze", ["split_k", "axes_0"], ["past_k"], "squeeze_past_k") if is_opset_13_or_newer \ - else helper.make_node("Squeeze", ["split_k"], ["past_k"], "squeeze_past_k", axes=[0]), - helper.make_node("Concat", ["past_k", "transpose_k_out"], ["concat_k_out"], "concat_k", axis=-2), - - helper.make_node("Transpose", ["concat_k_out"], ["concat_k_transpose_out"], "transpose_concat_k", perm=[0, 1, 3, 2]), - - helper.make_node("Squeeze", ["split_v", "axes_0"], ["past_v"], "squeeze_past_v") if is_opset_13_or_newer \ - else helper.make_node("Squeeze", ["split_v"], ["past_v"], "squeeze_past_v", axes=[0]), - helper.make_node("Concat", ["past_v", "transpose_v_out"], ["concat_v_out"], "concat_v", axis=-2), - + helper.make_node("Split", ["past", "split_1_1"], ["split_k", "split_v"], "split_past", axis=0) + if is_opset_13_or_newer + else helper.make_node( + "Split", + ["past"], + ["split_k", "split_v"], + "split_past", + axis=0, + split=[1, 1], + ), + helper.make_node("Squeeze", ["split_k", "axes_0"], ["past_k"], "squeeze_past_k") + if is_opset_13_or_newer + else helper.make_node("Squeeze", ["split_k"], ["past_k"], "squeeze_past_k", axes=[0]), + helper.make_node( + "Concat", + ["past_k", "transpose_k_out"], + ["concat_k_out"], + "concat_k", + axis=-2, + ), + helper.make_node( + "Transpose", + ["concat_k_out"], + ["concat_k_transpose_out"], + "transpose_concat_k", + perm=[0, 1, 3, 2], + ), + helper.make_node("Squeeze", ["split_v", "axes_0"], ["past_v"], "squeeze_past_v") + if is_opset_13_or_newer + else helper.make_node("Squeeze", ["split_v"], ["past_v"], "squeeze_past_v", axes=[0]), + helper.make_node( + "Concat", + ["past_v", "transpose_v_out"], + ["concat_v_out"], + "concat_v", + axis=-2, + ), # present - helper.make_node("Unsqueeze", ["concat_k_out", "axes_0"], ["concat_k_unsqueeze_out"], "concat_k_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["concat_k_out"], ["concat_k_unsqueeze_out"], "concat_k_unsqueeze", axes=[0]), - - helper.make_node("Unsqueeze", ["concat_v_out", "axes_0"], ["concat_v_unsqueeze_out"], "concat_v_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["concat_v_out"], ["concat_v_unsqueeze_out"], "concat_v_unsqueeze", axes=[0]), - - helper.make_node("Concat", ["concat_k_unsqueeze_out", "concat_v_unsqueeze_out"], ["present"], "concat_present", axis=0), - + helper.make_node( + "Unsqueeze", + ["concat_k_out", "axes_0"], + ["concat_k_unsqueeze_out"], + "concat_k_unsqueeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["concat_k_out"], + ["concat_k_unsqueeze_out"], + "concat_k_unsqueeze", + axes=[0], + ), + helper.make_node( + "Unsqueeze", + ["concat_v_out", "axes_0"], + ["concat_v_unsqueeze_out"], + "concat_v_unsqueeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["concat_v_out"], + ["concat_v_unsqueeze_out"], + "concat_v_unsqueeze", + axes=[0], + ), + helper.make_node( + "Concat", + ["concat_k_unsqueeze_out", "concat_v_unsqueeze_out"], + ["present"], + "concat_present", + axis=0, + ), helper.make_node("Shape", ["transpose_q_out"], ["transpose_q_shape_out"], "transpose_q_shape"), - helper.make_node("Slice", ["transpose_q_shape_out", "starts_n2", "ends_n1", "axes_0"], ["transpose_q_shape_slice_out"], "transpose_q_shape_slice"), - - helper.make_node("Squeeze", ["transpose_q_shape_slice_out", "axes_0"], ["transpose_q_shape_slice_squeeze_out"], "transpose_q_shape_slice_squeeze") if is_opset_13_or_newer \ - else helper.make_node("Squeeze", ["transpose_q_shape_slice_out"], ["transpose_q_shape_slice_squeeze_out"], "transpose_q_shape_slice_squeeze", axes=[0]), - + helper.make_node( + "Slice", + ["transpose_q_shape_out", "starts_n2", "ends_n1", "axes_0"], + ["transpose_q_shape_slice_out"], + "transpose_q_shape_slice", + ), + helper.make_node( + "Squeeze", + ["transpose_q_shape_slice_out", "axes_0"], + ["transpose_q_shape_slice_squeeze_out"], + "transpose_q_shape_slice_squeeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Squeeze", + ["transpose_q_shape_slice_out"], + ["transpose_q_shape_slice_squeeze_out"], + "transpose_q_shape_slice_squeeze", + axes=[0], + ), helper.make_node("Shape", ["concat_k_out"], ["concat_k_shape_out"], "concat_k_shape"), - helper.make_node("Slice", ["concat_k_shape_out", "starts_n2", "ends_n1", "axes_0"], ["concat_k_shape_slice_out"], "concat_k_shape_slice"), - - helper.make_node("Squeeze", ["concat_k_shape_slice_out", "axes_0"], ["concat_k_shape_slice_squeeze_out"], "concat_k_shape_slice_squeeze") if is_opset_13_or_newer \ - else helper.make_node("Squeeze", ["concat_k_shape_slice_out"], ["concat_k_shape_slice_squeeze_out"], "concat_k_shape_slice_squeeze", axes=[0]), - - helper.make_node("Sub", ["concat_k_shape_slice_squeeze_out", "transpose_q_shape_slice_squeeze_out"], ["sub_out"], "sub"), - - helper.make_node("Unsqueeze", ["sub_out", "axes_0"], ["sub_unsqueeze_out"], "sub_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["sub_out"], ["sub_unsqueeze_out"], "sub_unsqueeze", axes=[0]), - - helper.make_node("Unsqueeze", ["concat_k_shape_slice_squeeze_out", "axes_0"], ["concat_k_shape_slice_squeeze_unsqueeze_out"], "concat_k_shape_slice_squeeze_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["concat_k_shape_slice_squeeze_out"], ["concat_k_shape_slice_squeeze_unsqueeze_out"], "concat_k_shape_slice_squeeze_unsqueeze", axes=[0]), - - helper.make_node("Slice", ["undir_mask", "sub_unsqueeze_out", "concat_k_shape_slice_squeeze_unsqueeze_out", "axes_2", "steps_1"], ["undir_mask_slice_out"], "undir_mask_slice"), - helper.make_node("Slice", ["undir_mask_slice_out", "starts_0", "concat_k_shape_slice_squeeze_unsqueeze_out", "axes_3", "steps_1"], ["mask_slice_slice_out"], "mask_slice_slice"), - helper.make_node("Cast", ["mask_slice_slice_out"], ["undir_mask_out"], "undir_mask_cast", to=9), - + helper.make_node( + "Slice", + ["concat_k_shape_out", "starts_n2", "ends_n1", "axes_0"], + ["concat_k_shape_slice_out"], + "concat_k_shape_slice", + ), + helper.make_node( + "Squeeze", + ["concat_k_shape_slice_out", "axes_0"], + ["concat_k_shape_slice_squeeze_out"], + "concat_k_shape_slice_squeeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Squeeze", + ["concat_k_shape_slice_out"], + ["concat_k_shape_slice_squeeze_out"], + "concat_k_shape_slice_squeeze", + axes=[0], + ), + helper.make_node( + "Sub", + ["concat_k_shape_slice_squeeze_out", "transpose_q_shape_slice_squeeze_out"], + ["sub_out"], + "sub", + ), + helper.make_node("Unsqueeze", ["sub_out", "axes_0"], ["sub_unsqueeze_out"], "sub_unsqueeze") + if is_opset_13_or_newer + else helper.make_node("Unsqueeze", ["sub_out"], ["sub_unsqueeze_out"], "sub_unsqueeze", axes=[0]), + helper.make_node( + "Unsqueeze", + ["concat_k_shape_slice_squeeze_out", "axes_0"], + ["concat_k_shape_slice_squeeze_unsqueeze_out"], + "concat_k_shape_slice_squeeze_unsqueeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["concat_k_shape_slice_squeeze_out"], + ["concat_k_shape_slice_squeeze_unsqueeze_out"], + "concat_k_shape_slice_squeeze_unsqueeze", + axes=[0], + ), + helper.make_node( + "Slice", + [ + "undir_mask", + "sub_unsqueeze_out", + "concat_k_shape_slice_squeeze_unsqueeze_out", + "axes_2", + "steps_1", + ], + ["undir_mask_slice_out"], + "undir_mask_slice", + ), + helper.make_node( + "Slice", + [ + "undir_mask_slice_out", + "starts_0", + "concat_k_shape_slice_squeeze_unsqueeze_out", + "axes_3", + "steps_1", + ], + ["mask_slice_slice_out"], + "mask_slice_slice", + ), + helper.make_node( + "Cast", + ["mask_slice_slice_out"], + ["undir_mask_out"], + "undir_mask_cast", + to=9, + ), # mask nodes - helper.make_node("Reshape", ["input_mask", "input_mask_shape"], ["input_mask_reshape_out"], "input_mask_reshape"), - - helper.make_node("Unsqueeze", ["input_mask_reshape_out", "axes_1"], ["unsqueeze0_out"], "unsqueeze0") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["input_mask_reshape_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[1]), - - helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]), - + helper.make_node( + "Reshape", + ["input_mask", "input_mask_shape"], + ["input_mask_reshape_out"], + "input_mask_reshape", + ), + helper.make_node( + "Unsqueeze", + ["input_mask_reshape_out", "axes_1"], + ["unsqueeze0_out"], + "unsqueeze0", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["input_mask_reshape_out"], + ["unsqueeze0_out"], + "unsqueeze0", + axes=[1], + ), + helper.make_node("Unsqueeze", ["unsqueeze0_out", "axes_2"], ["unsqueeze1_out"], "unsqueeze1") + if is_opset_13_or_newer + else helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]), helper.make_node("Sub", ["sub_weight", "unsqueeze1_out"], ["mask_sub_out"], "sub_mask"), helper.make_node("Mul", ["mask_sub_out", "mul_weight"], ["mul_mask_out"], "mul_mask"), - - # qk nodes - helper.make_node("MatMul", ["transpose_q_out", "concat_k_transpose_out"], ["qk_out"], "matmul_qk"), + helper.make_node( + "MatMul", + ["transpose_q_out", "concat_k_transpose_out"], + ["qk_out"], + "matmul_qk", + ), helper.make_node("Div", ["qk_out", "div_weight"], ["qk_norm_out"], "qk_norm"), - - helper.make_node("Where", ["undir_mask_out", "qk_norm_out", "where_weight"], ["where_out"], "where"), - - helper.make_node("Add", reverse_if(["where_out", "mul_mask_out"], switch_add_inputs), ["add_mask_out"], "add_mask"), - + helper.make_node( + "Where", + ["undir_mask_out", "qk_norm_out", "where_weight"], + ["where_out"], + "where", + ), + helper.make_node( + "Add", + reverse_if(["where_out", "mul_mask_out"], switch_add_inputs), + ["add_mask_out"], + "add_mask", + ), helper.make_node("Softmax", ["add_mask_out"], ["softmax_out"], "softmax", axis=3), - # qkv nodes - helper.make_node("MatMul", ["softmax_out", "concat_v_out"], ["matmul_qkv_1_out"], "matmul_qk_v"), - helper.make_node("Transpose", ["matmul_qkv_1_out"], ["transpose_qkv_out"], "transpose_qkv", perm=[0, 2, 1, 3]), - helper.make_node("Reshape", ["transpose_qkv_out", "reshape_weight_qkv"], ["reshape_qkv_out"], "reshape_qkv"), + helper.make_node( + "MatMul", + ["softmax_out", "concat_v_out"], + ["matmul_qkv_1_out"], + "matmul_qk_v", + ), + helper.make_node( + "Transpose", + ["matmul_qkv_1_out"], + ["transpose_qkv_out"], + "transpose_qkv", + perm=[0, 2, 1, 3], + ), + helper.make_node( + "Reshape", + ["transpose_qkv_out", "reshape_weight_qkv"], + ["reshape_qkv_out"], + "reshape_qkv", + ), helper.make_node("Shape", ["reshape_qkv_out"], ["qkv_shape"], "shape_qkv"), - - helper.make_node("Slice", ["qkv_shape", "starts_n1", "ends_inf", "axes_0"], ["qkv_shape_slice_out"], "qkv_shape_slice"), - helper.make_node("Squeeze", ["qkv_shape_slice_out", "axes_0"], ["qkv_shape_slice_squeeze_out"], "qkv_shape_slice_squeeze") if is_opset_13_or_newer \ - else helper.make_node("Squeeze", ["qkv_shape_slice_out"], ["qkv_shape_slice_squeeze_out"], "qkv_shape_slice_squeeze", axes=[0]), - - helper.make_node("Unsqueeze", ["qkv_shape_slice_squeeze_out", "axes_0"], ["qkv_shape_slice_squeeze_unsqueeze_out"], "qkv_shape_slice_squeeze_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["qkv_shape_slice_squeeze_out"], ["qkv_shape_slice_squeeze_unsqueeze_out"], "qkv_shape_slice_squeeze_unsqueeze", axes=[0]), - - helper.make_node("Concat", ["concat_n1", "qkv_shape_slice_squeeze_unsqueeze_out"], ["qkv_shape_slice_squeeze_unsqueeze_concat_out"], "qkv_shape_slice_squeeze_unsqueeze_concat", axis=0), - - helper.make_node("Reshape", ["reshape_qkv_out", "qkv_shape_slice_squeeze_unsqueeze_concat_out"], ["qkv_reshape_out"], "qkv_reshape"), - helper.make_node("Gemm", ["qkv_reshape_out", "gemm_weight", "gemm_bias"], ["gemm_out"], "gemm", alpha=1.0, beta=1.0, transA=0, transB=0), - - helper.make_node("Gather", ["qkv_shape", "indices_1"], ["qkv_shape_1"], "shape_qkv_gather_1", axis=0), - helper.make_node("Gather", ["qkv_shape", "indices_0"], ["qkv_shape_0"], "shape_qkv_gather_0", axis=0), - - helper.make_node("Unsqueeze", ["qkv_shape_1", "axes_0"], ["qkv_shape_1_unsqueeze_out"], "qkv_shape_1_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["qkv_shape_1"], ["qkv_shape_1_unsqueeze_out"], "qkv_shape_1_unsqueeze", axes=[0]), - - helper.make_node("Unsqueeze", ["qkv_shape_0", "axes_0"], ["qkv_shape_0_unsqueeze_out"], "qkv_shape_0_unsqueeze") if is_opset_13_or_newer \ - else helper.make_node("Unsqueeze", ["qkv_shape_0"], ["qkv_shape_0_unsqueeze_out"], "qkv_shape_0_unsqueeze", axes=[0]), - - helper.make_node("Concat", ["qkv_shape_0_unsqueeze_out", "qkv_shape_1_unsqueeze_out", "qkv_hidden"], ["shape_qkv_concat_out"], "shape_qkv_concat", axis=0), - - helper.make_node("Reshape", ["gemm_out", "shape_qkv_concat_out"], ["gemm_reshape_out"], "gemm_reshape"), - - - helper.make_node("Add", reverse_if(["gemm_reshape_out", "layernorm_input"], switch_add_inputs), ["skip_output"], "add_skip"), - helper.make_node("LayerNormalization", ["skip_output", "layer_norm_weight", "layer_norm_bias"], ["output"], - "layernorm2", - epsion=0.000009999999747378752), + helper.make_node( + "Slice", + ["qkv_shape", "starts_n1", "ends_inf", "axes_0"], + ["qkv_shape_slice_out"], + "qkv_shape_slice", + ), + helper.make_node( + "Squeeze", + ["qkv_shape_slice_out", "axes_0"], + ["qkv_shape_slice_squeeze_out"], + "qkv_shape_slice_squeeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Squeeze", + ["qkv_shape_slice_out"], + ["qkv_shape_slice_squeeze_out"], + "qkv_shape_slice_squeeze", + axes=[0], + ), + helper.make_node( + "Unsqueeze", + ["qkv_shape_slice_squeeze_out", "axes_0"], + ["qkv_shape_slice_squeeze_unsqueeze_out"], + "qkv_shape_slice_squeeze_unsqueeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["qkv_shape_slice_squeeze_out"], + ["qkv_shape_slice_squeeze_unsqueeze_out"], + "qkv_shape_slice_squeeze_unsqueeze", + axes=[0], + ), + helper.make_node( + "Concat", + ["concat_n1", "qkv_shape_slice_squeeze_unsqueeze_out"], + ["qkv_shape_slice_squeeze_unsqueeze_concat_out"], + "qkv_shape_slice_squeeze_unsqueeze_concat", + axis=0, + ), + helper.make_node( + "Reshape", + ["reshape_qkv_out", "qkv_shape_slice_squeeze_unsqueeze_concat_out"], + ["qkv_reshape_out"], + "qkv_reshape", + ), + helper.make_node( + "Gemm", + ["qkv_reshape_out", "gemm_weight", "gemm_bias"], + ["gemm_out"], + "gemm", + alpha=1.0, + beta=1.0, + transA=0, + transB=0, + ), + helper.make_node( + "Gather", + ["qkv_shape", "indices_1"], + ["qkv_shape_1"], + "shape_qkv_gather_1", + axis=0, + ), + helper.make_node( + "Gather", + ["qkv_shape", "indices_0"], + ["qkv_shape_0"], + "shape_qkv_gather_0", + axis=0, + ), + helper.make_node( + "Unsqueeze", + ["qkv_shape_1", "axes_0"], + ["qkv_shape_1_unsqueeze_out"], + "qkv_shape_1_unsqueeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["qkv_shape_1"], + ["qkv_shape_1_unsqueeze_out"], + "qkv_shape_1_unsqueeze", + axes=[0], + ), + helper.make_node( + "Unsqueeze", + ["qkv_shape_0", "axes_0"], + ["qkv_shape_0_unsqueeze_out"], + "qkv_shape_0_unsqueeze", + ) + if is_opset_13_or_newer + else helper.make_node( + "Unsqueeze", + ["qkv_shape_0"], + ["qkv_shape_0_unsqueeze_out"], + "qkv_shape_0_unsqueeze", + axes=[0], + ), + helper.make_node( + "Concat", + ["qkv_shape_0_unsqueeze_out", "qkv_shape_1_unsqueeze_out", "qkv_hidden"], + ["shape_qkv_concat_out"], + "shape_qkv_concat", + axis=0, + ), + helper.make_node( + "Reshape", + ["gemm_out", "shape_qkv_concat_out"], + ["gemm_reshape_out"], + "gemm_reshape", + ), + helper.make_node( + "Add", + reverse_if(["gemm_reshape_out", "layernorm_input"], switch_add_inputs), + ["skip_output"], + "add_skip", + ), + helper.make_node( + "LayerNormalization", + ["skip_output", "layer_norm_weight", "layer_norm_bias"], + ["output"], + "layernorm2", + epsion=0.000009999999747378752, + ), ] head_size = int(hidden_size // num_heads) - unidir_mask = numpy.tril(numpy.ones( - (max_seq_len, max_seq_len))).reshape([max_seq_len * max_seq_len]).astype(numpy.uint8) + unidir_mask = ( + numpy.tril(numpy.ones((max_seq_len, max_seq_len))).reshape([max_seq_len * max_seq_len]).astype(numpy.uint8) + ) initializers = [ # initializers - float_tensor('layer_norm_weight', [hidden_size]), - float_tensor('layer_norm_bias', [hidden_size]), - float_tensor('matmul_fc_weight', [hidden_size, 3 * hidden_size]), - float_tensor('add_fc_weight', [3 * hidden_size]), - float_tensor('gemm_weight', [hidden_size, hidden_size]), - float_tensor('gemm_bias', [hidden_size]), - helper.make_tensor('undir_mask', TensorProto.UINT8, [1, 1, max_seq_len, max_seq_len], unidir_mask.tolist()), - helper.make_tensor('div_weight', TensorProto.FLOAT, [], [math.sqrt(head_size)]), - helper.make_tensor('sub_weight', TensorProto.FLOAT, [], [1.0]), - helper.make_tensor('where_weight', TensorProto.FLOAT, [], [-10000.]), - helper.make_tensor('mul_weight', TensorProto.FLOAT, [], [-10000]), - helper.make_tensor('input_mask_shape', TensorProto.INT64, [2], [0, -1]), - helper.make_tensor('starts_0', TensorProto.INT64, [1], [0]), - helper.make_tensor('concat_n1', TensorProto.INT64, [1], [-1]), - helper.make_tensor('starts_n1', TensorProto.INT64, [1], [-1]), - helper.make_tensor('ends_inf', TensorProto.INT64, [1], [9223372036854775807]), - helper.make_tensor('starts_n2', TensorProto.INT64, [1], [-2]), - helper.make_tensor('ends_n1', TensorProto.INT64, [1], [-1]), - helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), - helper.make_tensor('axes_2', TensorProto.INT64, [1], [2]), - helper.make_tensor('axes_3', TensorProto.INT64, [1], [3]), - helper.make_tensor('steps_1', TensorProto.INT64, [1], [1]), - helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), - helper.make_tensor('qkv_hidden', TensorProto.INT64, [1], [hidden_size]), - helper.make_tensor('reshape_x_shape', TensorProto.INT64, [4], [0, 0, num_heads, head_size]), - helper.make_tensor('reshape_weight_qkv', TensorProto.INT64, [3], [0, 0, hidden_size]), + float_tensor("layer_norm_weight", [hidden_size]), + float_tensor("layer_norm_bias", [hidden_size]), + float_tensor("matmul_fc_weight", [hidden_size, 3 * hidden_size]), + float_tensor("add_fc_weight", [3 * hidden_size]), + float_tensor("gemm_weight", [hidden_size, hidden_size]), + float_tensor("gemm_bias", [hidden_size]), + helper.make_tensor( + "undir_mask", + TensorProto.UINT8, + [1, 1, max_seq_len, max_seq_len], + unidir_mask.tolist(), + ), + helper.make_tensor("div_weight", TensorProto.FLOAT, [], [math.sqrt(head_size)]), + helper.make_tensor("sub_weight", TensorProto.FLOAT, [], [1.0]), + helper.make_tensor("where_weight", TensorProto.FLOAT, [], [-10000.0]), + helper.make_tensor("mul_weight", TensorProto.FLOAT, [], [-10000]), + helper.make_tensor("input_mask_shape", TensorProto.INT64, [2], [0, -1]), + helper.make_tensor("starts_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("concat_n1", TensorProto.INT64, [1], [-1]), + helper.make_tensor("starts_n1", TensorProto.INT64, [1], [-1]), + helper.make_tensor("ends_inf", TensorProto.INT64, [1], [9223372036854775807]), + helper.make_tensor("starts_n2", TensorProto.INT64, [1], [-2]), + helper.make_tensor("ends_n1", TensorProto.INT64, [1], [-1]), + helper.make_tensor("axes_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("axes_2", TensorProto.INT64, [1], [2]), + helper.make_tensor("axes_3", TensorProto.INT64, [1], [3]), + helper.make_tensor("steps_1", TensorProto.INT64, [1], [1]), + helper.make_tensor("indices_0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices_1", TensorProto.INT64, [], [1]), + helper.make_tensor("qkv_hidden", TensorProto.INT64, [1], [hidden_size]), + helper.make_tensor("reshape_x_shape", TensorProto.INT64, [4], [0, 0, num_heads, head_size]), + helper.make_tensor("reshape_weight_qkv", TensorProto.INT64, [3], [0, 0, hidden_size]), ] if is_opset_13_or_newer: - initializers.append(helper.make_tensor('split_1_1', TensorProto.INT64, [2], [1, 1])) + initializers.append(helper.make_tensor("split_1_1", TensorProto.INT64, [2], [1, 1])) initializers.append( - helper.make_tensor('split_q_k_v', TensorProto.INT64, [3], [hidden_size, hidden_size, hidden_size])) - initializers.append(helper.make_tensor('axes_1', TensorProto.INT64, [1], [1])) + helper.make_tensor( + "split_q_k_v", + TensorProto.INT64, + [3], + [hidden_size, hidden_size, hidden_size], + ) + ) + initializers.append(helper.make_tensor("axes_1", TensorProto.INT64, [1], [1])) batch_size = 1 sequence_length = 3 past_sequence_length = 2 graph = helper.make_graph( [node for node in nodes if node], - "GPT2", #name + "GPT2", # name [ # inputs - helper.make_tensor_value_info('input_1', TensorProto.FLOAT, ['batch_size', 'sequence_length', hidden_size]), - helper.make_tensor_value_info('input_2', TensorProto.FLOAT, ['batch_size', 'sequence_length', hidden_size]), - helper.make_tensor_value_info('input_mask', TensorProto.FLOAT, - ['batch_size', 'past_sequence_length + sequence_length']), - helper.make_tensor_value_info('past', TensorProto.FLOAT, - [2, 'batch_size', num_heads, 'past_sequence_length', head_size]) + helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + ["batch_size", "sequence_length", hidden_size], + ), + helper.make_tensor_value_info( + "input_2", + TensorProto.FLOAT, + ["batch_size", "sequence_length", hidden_size], + ), + helper.make_tensor_value_info( + "input_mask", + TensorProto.FLOAT, + ["batch_size", "past_sequence_length + sequence_length"], + ), + helper.make_tensor_value_info( + "past", + TensorProto.FLOAT, + [2, "batch_size", num_heads, "past_sequence_length", head_size], + ), ], [ # outputs - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['batch_size', 'sequence_length', hidden_size]), helper.make_tensor_value_info( - 'present', TensorProto.FLOAT, - [2, 'batch_size', num_heads, 'past_sequence_length + sequence_length', head_size]), + "output", + TensorProto.FLOAT, + ["batch_size", "sequence_length", hidden_size], + ), + helper.make_tensor_value_info( + "present", + TensorProto.FLOAT, + [ + 2, + "batch_size", + num_heads, + "past_sequence_length + sequence_length", + head_size, + ], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) return model diff --git a/onnxruntime/test/python/transformers/model_loader.py b/onnxruntime/test/python/transformers/model_loader.py index 0ea5b6411d93e..126df89240c70 100644 --- a/onnxruntime/test/python/transformers/model_loader.py +++ b/onnxruntime/test/python/transformers/model_loader.py @@ -4,11 +4,12 @@ # license information. # -------------------------------------------------------------------------- -import unittest import os -from onnx import ModelProto, TensorProto, numpy_helper, external_data_helper, load_model +import unittest +from onnx import ModelProto, TensorProto, external_data_helper, load_model, numpy_helper from parity_utilities import find_transformers_source + if find_transformers_source(): from fusion_utils import NumpyHelper else: @@ -56,14 +57,14 @@ def load_model_with_dummy_external_data(path: str) -> ModelProto: def get_test_data_path(sub_dir: str, file: str): - relative_path = os.path.join(os.path.dirname(__file__), 'test_data', sub_dir, file) - if (os.path.exists(relative_path)): + relative_path = os.path.join(os.path.dirname(__file__), "test_data", sub_dir, file) + if os.path.exists(relative_path): return relative_path - return os.path.join('.', 'transformers', 'test_data', sub_dir, file) + return os.path.join(".", "transformers", "test_data", sub_dir, file) def get_fusion_test_model(file: str): - relative_path = os.path.join(os.path.dirname(__file__), '..', '..', 'testdata', 'transform', 'fusion', file) - if (os.path.exists(relative_path)): + relative_path = os.path.join(os.path.dirname(__file__), "..", "..", "testdata", "transform", "fusion", file) + if os.path.exists(relative_path): return relative_path - return os.path.join('.', 'testdata', 'transform', 'fusion', file) + return os.path.join(".", "testdata", "transform", "fusion", file) diff --git a/onnxruntime/test/python/transformers/parity_utilities.py b/onnxruntime/test/python/transformers/parity_utilities.py index d450700581eda..0fe61d9b600fd 100644 --- a/onnxruntime/test/python/transformers/parity_utilities.py +++ b/onnxruntime/test/python/transformers/parity_utilities.py @@ -6,20 +6,36 @@ import os import sys + import numpy import torch def find_transformers_source(sub_dir_paths=[]): - source_dir = os.path.join(os.path.dirname(__file__), '..', '..', '..', 'python', 'tools', 'transformers', *sub_dir_paths) - if (os.path.exists(source_dir)): + source_dir = os.path.join( + os.path.dirname(__file__), + "..", + "..", + "..", + "python", + "tools", + "transformers", + *sub_dir_paths, + ) + if os.path.exists(source_dir): if source_dir not in sys.path: sys.path.append(source_dir) return True - return False + return False -def create_inputs(batch_size=1, sequence_length=1, hidden_size=768, float16=False, device=torch.device('cuda')): +def create_inputs( + batch_size=1, + sequence_length=1, + hidden_size=768, + float16=False, + device=torch.device("cuda"), +): float_type = torch.float16 if float16 else torch.float32 input = torch.normal(mean=0.0, std=10.0, size=(batch_size, sequence_length, hidden_size)).to(float_type).to(device) return input @@ -27,42 +43,54 @@ def create_inputs(batch_size=1, sequence_length=1, hidden_size=768, float16=Fals def export_onnx(model, onnx_model_path, float16, hidden_size, device): from pathlib import Path + Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) input_hidden_states = create_inputs(hidden_size=hidden_size, float16=float16, device=device) with torch.no_grad(): outputs = model(input_hidden_states) - dynamic_axes = {'input': {0: 'batch_size', 1: 'seq_len'}, "output": {0: 'batch_size', 1: 'seq_len'}} - - torch.onnx.export(model, - args=(input_hidden_states), - f=onnx_model_path, - input_names=['input'], - output_names=["output"], - dynamic_axes=dynamic_axes, - opset_version=11, - do_constant_folding=True) + dynamic_axes = { + "input": {0: "batch_size", 1: "seq_len"}, + "output": {0: "batch_size", 1: "seq_len"}, + } + + torch.onnx.export( + model, + args=(input_hidden_states), + f=onnx_model_path, + input_names=["input"], + output_names=["output"], + dynamic_axes=dynamic_axes, + opset_version=11, + do_constant_folding=True, + ) print("exported:", onnx_model_path) -def optimize_onnx(input_onnx_path, optimized_onnx_path, expected_op=None, use_gpu=False, opt_level=None): +def optimize_onnx( + input_onnx_path, + optimized_onnx_path, + expected_op=None, + use_gpu=False, + opt_level=None, +): if find_transformers_source(): from optimizer import optimize_model else: from onnxruntime.transformers.optimizer import optimize_model - onnx_model = optimize_model(input_onnx_path, model_type='gpt2', use_gpu=use_gpu, opt_level=opt_level) + onnx_model = optimize_model(input_onnx_path, model_type="gpt2", use_gpu=use_gpu, opt_level=opt_level) onnx_model.save_model_to_file(optimized_onnx_path) if expected_op is not None: - assert len(onnx_model.get_nodes_by_op_type(expected_op)) == 1, \ - f"Expected {expected_op} node not found in the optimized model {optimized_onnx_path}" + assert ( + len(onnx_model.get_nodes_by_op_type(expected_op)) == 1 + ), f"Expected {expected_op} node not found in the optimized model {optimized_onnx_path}" def diff_outputs(torch_outputs, ort_outputs, index): - """ Returns the maximum difference between PyTorch and OnnxRuntime outputs. - """ + """Returns the maximum difference between PyTorch and OnnxRuntime outputs.""" expected_outputs = torch_outputs[index].cpu().numpy() diff = numpy.abs(expected_outputs - ort_outputs[index]) return numpy.amax(diff) @@ -81,10 +109,12 @@ def compare_outputs(torch_outputs, ort_outputs, atol=1e-06, verbose=True): is_all_close(bool): whether all elements are close. max_abs_diff(float): maximum absolute difference. """ - same = numpy.asarray([ - numpy.allclose(ort_outputs[i], torch_outputs[i].cpu().numpy(), atol=atol, rtol=0) - for i in range(len(ort_outputs)) - ]) + same = numpy.asarray( + [ + numpy.allclose(ort_outputs[i], torch_outputs[i].cpu().numpy(), atol=atol, rtol=0) + for i in range(len(ort_outputs)) + ] + ) max_abs_diff = [diff_outputs(torch_outputs, ort_outputs, i) for i in range(len(ort_outputs))] @@ -94,43 +124,47 @@ def compare_outputs(torch_outputs, ort_outputs, atol=1e-06, verbose=True): diff = numpy.fabs(ort_outputs[i] - torch_outputs[i].cpu().numpy()) idx = numpy.unravel_index(diff.argmax(), diff.shape) print( - f'Output {i}, diff={diff[idx]:.9f} index={idx} ort={ort_outputs[i][idx]:.9f} torch={float(torch_outputs[i][idx]):.9f}' + f"Output {i}, diff={diff[idx]:.9f} index={idx} ort={ort_outputs[i][idx]:.9f} torch={float(torch_outputs[i][idx]):.9f}" ) return is_all_close, max(max_abs_diff) def create_ort_session(onnx_model_path, use_gpu=True): - from onnxruntime import SessionOptions, InferenceSession, GraphOptimizationLevel, __version__ as onnxruntime_version + from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions + from onnxruntime import __version__ as onnxruntime_version + sess_options = SessionOptions() sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL sess_options.intra_op_num_threads = 2 sess_options.log_severity_level = 2 - execution_providers = ['CPUExecutionProvider'] if not use_gpu else ['CUDAExecutionProvider', 'CPUExecutionProvider'] + execution_providers = ["CPUExecutionProvider"] if not use_gpu else ["CUDAExecutionProvider", "CPUExecutionProvider"] return InferenceSession(onnx_model_path, sess_options, providers=execution_providers) def onnxruntime_inference(ort_session, input): - ort_inputs = {'input': numpy.ascontiguousarray(input.cpu().numpy())} + ort_inputs = {"input": numpy.ascontiguousarray(input.cpu().numpy())} ort_outputs = ort_session.run(None, ort_inputs) return ort_outputs -def run_parity(model, - onnx_model_path, - batch_size, - hidden_size, - sequence_length, - float16, - device, - optimized, - test_cases=100, - verbose=False, - tolerance=None): +def run_parity( + model, + onnx_model_path, + batch_size, + hidden_size, + sequence_length, + float16, + device, + optimized, + test_cases=100, + verbose=False, + tolerance=None, +): passed_cases = 0 max_diffs = [] printed = False # print only one sample - ort_session = create_ort_session(onnx_model_path, device.type == 'cuda') + ort_session = create_ort_session(onnx_model_path, device.type == "cuda") for i in range(test_cases): input_hidden_states = create_inputs(batch_size, sequence_length, hidden_size, float16, device) @@ -147,7 +181,7 @@ def run_parity(model, passed_cases += 1 elif verbose and not printed: printed = True - numpy.set_printoptions(precision=10, floatmode='fixed') + numpy.set_printoptions(precision=10, floatmode="fixed") torch.set_printoptions(precision=10) print("input", input_hidden_states) print("torch_outputs", torch_outputs) diff --git a/onnxruntime/test/python/transformers/test_attention_fusion.py b/onnxruntime/test/python/transformers/test_attention_fusion.py index 5e593704c7032..f9d64bb36c2b8 100644 --- a/onnxruntime/test/python/transformers/test_attention_fusion.py +++ b/onnxruntime/test/python/transformers/test_attention_fusion.py @@ -4,27 +4,28 @@ # license information. # -------------------------------------------------------------------------- -import unittest import os +import unittest + import onnx from bert_model_generator import create_bert_attention, create_tf2onnx_attention_3d from gpt2_model_generator import create_gpt2_attention from model_loader import get_test_data_path - from parity_utilities import find_transformers_source + if find_transformers_source(): - from optimizer import optimize_model, optimize_by_fusion from onnx_model import OnnxModel + from optimizer import optimize_by_fusion, optimize_model else: - from onnxruntime.transformers.optimizer import optimize_model, optimize_by_fusion from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_by_fusion, optimize_model class TestFusion(unittest.TestCase): def verify_fusion(self, optimized_model, expected_model_filename): optimized_model.topological_sort() - expected_model_path = os.path.join(os.path.dirname(__file__), 'test_data', 'models', expected_model_filename) + expected_model_path = os.path.join(os.path.dirname(__file__), "test_data", "models", expected_model_filename) expected_model = OnnxModel(onnx.load(expected_model_path)) expected_model.topological_sort() @@ -32,93 +33,107 @@ def verify_fusion(self, optimized_model, expected_model_filename): def test_attention_fusion(self): model = create_bert_attention() - dir = '.' + dir = "." model_path = os.path.join(dir, "attention.onnx") onnx.save(model, model_path) optimized_model = optimize_model(model_path) os.remove(model_path) - self.verify_fusion(optimized_model, 'attention_opt.onnx') + self.verify_fusion(optimized_model, "attention_opt.onnx") def test_attention_fusion_pruned_model(self): - model = create_bert_attention(input_hidden_size=16, - num_heads=2, - pruned_qk_hidden_size=8, - pruned_v_hidden_size=8) - dir = '.' + model = create_bert_attention( + input_hidden_size=16, + num_heads=2, + pruned_qk_hidden_size=8, + pruned_v_hidden_size=8, + ) + dir = "." model_path = os.path.join(dir, "pruned_attention.onnx") onnx.save(model, model_path) optimized_model = optimize_model(model_path) os.remove(model_path) - self.verify_fusion(optimized_model, 'pruned_attention_opt.onnx') + self.verify_fusion(optimized_model, "pruned_attention_opt.onnx") def test_attention_fusion_reverse_add_order(self): - model = create_bert_attention(input_hidden_size=16, - num_heads=2, - pruned_qk_hidden_size=8, - pruned_v_hidden_size=8, - switch_add_inputs=True) - dir = '.' + model = create_bert_attention( + input_hidden_size=16, + num_heads=2, + pruned_qk_hidden_size=8, + pruned_v_hidden_size=8, + switch_add_inputs=True, + ) + dir = "." model_path = os.path.join(dir, "bert_attention_reverse_add_order.onnx") onnx.save(model, model_path) optimized_model = optimize_model(model_path) os.remove(model_path) # reverse add input order will get same optimized model - self.verify_fusion(optimized_model, 'pruned_attention_opt.onnx') + self.verify_fusion(optimized_model, "pruned_attention_opt.onnx") def test_attention_fusion_for_varied_qkv_dimensions(self): - model = create_bert_attention(input_hidden_size=16, - num_heads=2, - pruned_qk_hidden_size=24, - pruned_v_hidden_size=16) - dir = '.' + model = create_bert_attention( + input_hidden_size=16, + num_heads=2, + pruned_qk_hidden_size=24, + pruned_v_hidden_size=16, + ) + dir = "." model_path = os.path.join(dir, "attention_with_varied_qkv.onnx") onnx.save(model, model_path) optimized_model = optimize_model(model_path) os.remove(model_path) - self.verify_fusion(optimized_model, 'attention_with_varied_qkv_opt.onnx') + self.verify_fusion(optimized_model, "attention_with_varied_qkv_opt.onnx") def test_attention_fusion_for_varied_qkv_dimensions_with_wrong_opt_parameters(self): - model = create_bert_attention(input_hidden_size=16, - num_heads=2, - pruned_qk_hidden_size=24, - pruned_v_hidden_size=16) - dir = '.' + model = create_bert_attention( + input_hidden_size=16, + num_heads=2, + pruned_qk_hidden_size=24, + pruned_v_hidden_size=16, + ) + dir = "." model_path = os.path.join(dir, "attention_with_varied_qkv.onnx") onnx.save(model, model_path) - #wrong num_heads and hidden_size - optimized_model = optimize_model(model_path, 'bert', num_heads=8, hidden_size=8) + # wrong num_heads and hidden_size + optimized_model = optimize_model(model_path, "bert", num_heads=8, hidden_size=8) os.remove(model_path) - self.verify_fusion(optimized_model, 'attention_with_varied_qkv_opt.onnx') + self.verify_fusion(optimized_model, "attention_with_varied_qkv_opt.onnx") def test_3d_attention_fusion_tf2onnx_model(self): model = create_tf2onnx_attention_3d() - dir = '.' - model_path = os.path.join(dir, 'bert_3d_attention.onnx') + dir = "." + model_path = os.path.join(dir, "bert_3d_attention.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path, model_type='bert_tf', num_heads=4, hidden_size=16) + optimized_model = optimize_model(model_path, model_type="bert_tf", num_heads=4, hidden_size=16) os.remove(model_path) - self.verify_fusion(optimized_model, 'bert_3d_attention_opt.onnx') + self.verify_fusion(optimized_model, "bert_3d_attention_opt.onnx") def test_gpt2_attention_fusion(self): hidden_size = 64 num_heads = 4 for add_order in [False, True]: - model = create_gpt2_attention(hidden_size=hidden_size, num_heads=num_heads, switch_add_inputs=add_order) - dir = '.' + model = create_gpt2_attention( + hidden_size=hidden_size, + num_heads=num_heads, + switch_add_inputs=add_order, + ) + dir = "." model_path = os.path.join(dir, "gpt2_attention.onnx") onnx.save(model, model_path) - optimized_model = optimize_model(model_path, - model_type='gpt2', - num_heads=num_heads, - hidden_size=hidden_size) + optimized_model = optimize_model( + model_path, + model_type="gpt2", + num_heads=num_heads, + hidden_size=hidden_size, + ) optimized_model.topological_sort() os.remove(model_path) @@ -128,10 +143,10 @@ def test_gpt2_attention_fusion(self): def test_megatron_gpt2_attention_fusion(self): path = get_test_data_path("models", "gpt2_megatron.onnx") model = onnx.load(path) - optimized_model = optimize_by_fusion(model, model_type='gpt2') + optimized_model = optimize_by_fusion(model, model_type="gpt2") self.verify_fusion(optimized_model, "gpt2_megatron_opt.onnx") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index 64e57626372fc..56bf6a3b218e7 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -6,11 +6,12 @@ # license information. # -------------------------------------------------------------------------- -import unittest import os -import pytest +import unittest +import pytest from parity_utilities import find_transformers_source + if find_transformers_source(): from convert_beam_search import main as run else: @@ -18,21 +19,22 @@ class TestBeamSearch(unittest.TestCase): - def setUp(self): - #TODO: use a smaller model and enable tests in CI pipeline + # TODO: use a smaller model and enable tests in CI pipeline self.model_name = "gpt2" - self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx') - self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx') - self.cpu_params = f'-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0' + self.gpt2_onnx_path = os.path.join(".", "onnx_models", "gpt2_past_fp32_shape.onnx") + self.beam_search_onnx_path = os.path.join(".", "onnx_models", "gpt2_beam_search.onnx") + self.cpu_params = f"-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0" def run_beam_search(self, arguments: str, sentences=None): return run(arguments.split(), sentences=sentences) @pytest.mark.slow def test_cpu(self): - result = self.run_beam_search(self.cpu_params + " --num_return_sequences 2", - sentences=["The product is released"]) + result = self.run_beam_search( + self.cpu_params + " --num_return_sequences 2", + sentences=["The product is released"], + ) os.remove(self.gpt2_onnx_path) os.remove(self.beam_search_onnx_path) self.assertTrue(result["parity"], "ORT and PyTorch result is different") @@ -61,11 +63,11 @@ def test_length_penalty(self): @pytest.mark.slow def test_no_repeat_ngram(self): for ngram_size in [1, 2]: - result = self.run_beam_search(self.cpu_params + f' --no_repeat_ngram_size {ngram_size}') + result = self.run_beam_search(self.cpu_params + f" --no_repeat_ngram_size {ngram_size}") os.remove(self.gpt2_onnx_path) os.remove(self.beam_search_onnx_path) self.assertTrue(result["parity"], "ORT and PyTorch result is different") -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py b/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py index 6276fa83385b5..47145fc213a0d 100644 --- a/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py +++ b/onnxruntime/test/python/transformers/test_data/bert_squad_tensorflow2.1_keras2onnx_opset11/generate_tiny_keras2onnx_bert_models.py @@ -1,7 +1,7 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- """ Convert a Bert large model exported by Keras2Onnx to a tiny model for test purpose. The input model is generated like the following (need install keras2onnx from source): @@ -23,18 +23,20 @@ keras2onnx.save_model(onnx_model, output_model_path) """ -import onnx -import onnx.utils -import sys import argparse +import os +import random +import sys +import timeit +from pathlib import Path + import numpy as np +import onnx +import onnx.utils from onnx import ModelProto, TensorProto, numpy_helper from onnxruntime_tools.transformers.onnx_model import OnnxModel -import os + import onnxruntime -import random -from pathlib import Path -import timeit DICT_SIZE = 20 SEQ_LEN = 7 @@ -52,18 +54,20 @@ def resize_weight(self, initializer_name, target_shape): target_w = w if len(target_shape) == 1: - target_w = w[:target_shape[0]] + target_w = w[: target_shape[0]] elif len(target_shape) == 2: - target_w = w[:target_shape[0], :target_shape[1]] + target_w = w[: target_shape[0], : target_shape[1]] elif len(target_shape) == 3: - target_w = w[:target_shape[0], :target_shape[1], :target_shape[2]] + target_w = w[: target_shape[0], : target_shape[1], : target_shape[2]] else: print("at most 3 dimensions") - tensor = onnx.helper.make_tensor(name=initializer_name + '_resize', - data_type=TensorProto.FLOAT, - dims=target_shape, - vals=target_w.flatten().tolist()) + tensor = onnx.helper.make_tensor( + name=initializer_name + "_resize", + data_type=TensorProto.FLOAT, + dims=target_shape, + vals=target_w.flatten().tolist(), + ) return tensor @@ -78,7 +82,7 @@ def resize_model(self): "num_heads": 16, "size_per_head": 64, "word_dict_size": [28996, 30522], # list of supported dictionary size. - "max_word_position": 512 + "max_word_position": 512, } # parameters of output tiny model. @@ -88,11 +92,11 @@ def resize_model(self): "num_heads": 2, "size_per_head": 4, "word_dict_size": DICT_SIZE, - "max_word_position": 10 + "max_word_position": 10, } for input in graph.input: - if (input.type.tensor_type.shape.dim[1].dim_value == old_parameters["seq_len"]): + if input.type.tensor_type.shape.dim[1].dim_value == old_parameters["seq_len"]: print("input", input.name, input.type.tensor_type.shape) input.type.tensor_type.shape.dim[1].dim_value = new_parameters["seq_len"] print("=>", input.type.tensor_type.shape) @@ -103,84 +107,182 @@ def resize_model(self): dtype = np.float32 if initializer.data_type == 1 else np.int32 if len(tensor.shape) == 1 and tensor.shape[0] == 1: if tensor == old_parameters["num_heads"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["num_heads"], "=>[", new_parameters["num_heads"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["num_heads"], + "=>[", + new_parameters["num_heads"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["num_heads"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["num_heads"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["seq_len"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["seq_len"], "=>[", new_parameters["seq_len"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["seq_len"], + "=>[", + new_parameters["seq_len"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["seq_len"]], dtype=dtype), initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["seq_len"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["size_per_head"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["size_per_head"], "=>[", new_parameters["size_per_head"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["size_per_head"], + "=>[", + new_parameters["size_per_head"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["size_per_head"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["size_per_head"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["hidden_size"], "=>[", new_parameters["hidden_size"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["hidden_size"], + "=>[", + new_parameters["hidden_size"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["hidden_size"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["hidden_size"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == 4 * old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - 4 * old_parameters["hidden_size"], "=>[", 4 * new_parameters["hidden_size"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 4 * old_parameters["hidden_size"], + "=>[", + 4 * new_parameters["hidden_size"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([4 * new_parameters["hidden_size"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([4 * new_parameters["hidden_size"]], dtype=dtype), + initializer.name, + ) + ) elif len(tensor.shape) == 0: if tensor == old_parameters["num_heads"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["num_heads"], "=>", new_parameters["num_heads"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["num_heads"], + "=>", + new_parameters["num_heads"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["num_heads"], dtype=dtype), initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["num_heads"], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["seq_len"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["seq_len"], "=>", new_parameters["seq_len"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["seq_len"], + "=>", + new_parameters["seq_len"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["seq_len"], dtype=dtype), initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["seq_len"], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["size_per_head"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["size_per_head"], "=>", new_parameters["size_per_head"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["size_per_head"], + "=>", + new_parameters["size_per_head"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["size_per_head"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["size_per_head"], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["hidden_size"], "=>", new_parameters["hidden_size"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["hidden_size"], + "=>", + new_parameters["hidden_size"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["hidden_size"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["hidden_size"], dtype=dtype), + initializer.name, + ) + ) elif tensor == 4 * old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - 4 * old_parameters["hidden_size"], "=>", 4 * new_parameters["hidden_size"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 4 * old_parameters["hidden_size"], + "=>", + 4 * new_parameters["hidden_size"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(4 * new_parameters["hidden_size"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(4 * new_parameters["hidden_size"], dtype=dtype), + initializer.name, + ) + ) elif tensor == 1.0 / np.sqrt(old_parameters["size_per_head"]): - print("initializer type={}".format(initializer.data_type), initializer.name, - 1.0 / np.sqrt(old_parameters["size_per_head"]), "=>", - 1.0 / np.sqrt(new_parameters["size_per_head"])) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 1.0 / np.sqrt(old_parameters["size_per_head"]), + "=>", + 1.0 / np.sqrt(new_parameters["size_per_head"]), + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(1.0 / np.sqrt(new_parameters["size_per_head"]), dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray( + 1.0 / np.sqrt(new_parameters["size_per_head"]), + dtype=dtype, + ), + initializer.name, + ) + ) new_shape = [] shape_changed = False for dim in tensor.shape: - if (dim == old_parameters["hidden_size"]): + if dim == old_parameters["hidden_size"]: new_shape.append(new_parameters["hidden_size"]) shape_changed = True - elif (dim == 4 * old_parameters["hidden_size"]): + elif dim == 4 * old_parameters["hidden_size"]: new_shape.append(4 * new_parameters["hidden_size"]) shape_changed = True - elif (dim in old_parameters["word_dict_size"]): + elif dim in old_parameters["word_dict_size"]: new_shape.append(new_parameters["word_dict_size"]) shape_changed = True - elif (dim == old_parameters["max_word_position"]): + elif dim == old_parameters["max_word_position"]: new_shape.append(new_parameters["max_word_position"]) shape_changed = True else: @@ -190,13 +292,13 @@ def resize_model(self): print("initializer", initializer.name, tensor.shape, "=>", new_shape) for initializer_name in reshapes: - self.replace_input_of_all_nodes(initializer_name, initializer_name + '_resize') + self.replace_input_of_all_nodes(initializer_name, initializer_name + "_resize") tensor = self.resize_weight(initializer_name, reshapes[initializer_name]) self.model.graph.initializer.extend([tensor]) self.use_dynamic_axes() - def use_dynamic_axes(self, dynamic_batch_dim='batch_size', seq_len=7): + def use_dynamic_axes(self, dynamic_batch_dim="batch_size", seq_len=7): """ Update input and output shape to use dynamic axes. """ @@ -214,29 +316,31 @@ def use_dynamic_axes(self, dynamic_batch_dim='batch_size', seq_len=7): dim_proto.dim_value = seq_len -def generate_test_data(onnx_file, - output_path, - batch_size, - sequence_length, - use_cpu=True, - input_tensor_only=False, - dictionary_size=DICT_SIZE, - test_cases=3): +def generate_test_data( + onnx_file, + output_path, + batch_size, + sequence_length, + use_cpu=True, + input_tensor_only=False, + dictionary_size=DICT_SIZE, + test_cases=3, +): input_data_type = np.int32 for test_case in range(test_cases): input_1 = np.random.randint(dictionary_size, size=(batch_size, sequence_length), dtype=input_data_type) - tensor_1 = numpy_helper.from_array(input_1, 'input_ids') + tensor_1 = numpy_helper.from_array(input_1, "input_ids") actual_seq_len = random.randint(sequence_length - 3, sequence_length) input_2 = np.zeros((batch_size, sequence_length), dtype=input_data_type) temp = np.ones((batch_size, actual_seq_len), dtype=input_data_type) - input_2[:temp.shape[0], :temp.shape[1]] = temp - tensor_2 = numpy_helper.from_array(input_2, 'attention_mask') + input_2[: temp.shape[0], : temp.shape[1]] = temp + tensor_2 = numpy_helper.from_array(input_2, "attention_mask") input_3 = np.zeros((batch_size, sequence_length), dtype=input_data_type) - tensor_3 = numpy_helper.from_array(input_3, 'token_type_ids') + tensor_3 = numpy_helper.from_array(input_3, "token_type_ids") - path = os.path.join(output_path, 'test_data_set_' + str(test_case)) + path = os.path.join(output_path, "test_data_set_" + str(test_case)) try: os.mkdir(path) except OSError: @@ -249,58 +353,70 @@ def generate_test_data(onnx_file, sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - sess = onnxruntime.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider']) + sess = onnxruntime.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"]) input1_name = sess.get_inputs()[0].name output_names = [output.name for output in sess.get_outputs()] - inputs = {'input_ids': input_1, 'attention_mask': input_2, 'token_type_ids': input_3} + inputs = { + "input_ids": input_1, + "attention_mask": input_2, + "token_type_ids": input_3, + } print("inputs", inputs) result = sess.run(output_names, inputs) - with open(os.path.join(path, 'input_{}.pb'.format(0)), 'wb') as f: + with open(os.path.join(path, "input_{}.pb".format(0)), "wb") as f: f.write(tensor_1.SerializeToString()) - with open(os.path.join(path, 'input_{}.pb'.format(1)), 'wb') as f: + with open(os.path.join(path, "input_{}.pb".format(1)), "wb") as f: f.write(tensor_2.SerializeToString()) - with open(os.path.join(path, 'input_{}.pb'.format(2)), 'wb') as f: + with open(os.path.join(path, "input_{}.pb".format(2)), "wb") as f: f.write(tensor_3.SerializeToString()) for i, output_name in enumerate(output_names): tensor_result = numpy_helper.from_array( - np.asarray(result[i]).reshape((batch_size, sequence_length)), output_names[i]) - with open(os.path.join(path, 'output_{}.pb'.format(i)), 'wb') as f: + np.asarray(result[i]).reshape((batch_size, sequence_length)), + output_names[i], + ) + with open(os.path.join(path, "output_{}.pb".format(i)), "wb") as f: f.write(tensor_result.SerializeToString()) start_time = timeit.default_timer() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - path_prefix = onnx_file[:-5] #remove .onnx suffix + path_prefix = onnx_file[:-5] # remove .onnx suffix if use_cpu: sess_options.optimized_model_filepath = path_prefix + "_optimized_cpu.onnx" else: sess_options.optimized_model_filepath = path_prefix + "_optimized_gpu.onnx" - session = onnxruntime.InferenceSession(onnx_file, sess_options=sess_options, - providers=onnxruntime.get_available_providers()) + session = onnxruntime.InferenceSession( + onnx_file, + sess_options=sess_options, + providers=onnxruntime.get_available_providers(), + ) if use_cpu: - session.set_providers(['CPUExecutionProvider']) # use cpu + session.set_providers(["CPUExecutionProvider"]) # use cpu else: - if 'CUDAExecutionProvider' not in session.get_providers(): + if "CUDAExecutionProvider" not in session.get_providers(): print("Warning: GPU not found") continue outputs = session.run(None, inputs) evalTime = timeit.default_timer() - start_time if outputs[0].tolist() != result[0].tolist(): - print("Error: not same result after optimization. use_cpu={}, no_opt_output={}, opt_output={}".format( - use_cpu, result[0].tolist(), outputs[1].tolist())) + print( + "Error: not same result after optimization. use_cpu={}, no_opt_output={}, opt_output={}".format( + use_cpu, result[0].tolist(), outputs[1].tolist() + ) + ) print("** Evaluation done in total {} secs".format(evalTime)) def main(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, type=str) - parser.add_argument('--output', required=True, type=str) - parser.add_argument('--float16', required=False, action='store_true') + parser.add_argument("--input", required=True, type=str) + parser.add_argument("--output", required=True, type=str) + parser.add_argument("--float16", required=False, action="store_true") parser.set_defaults(float16=False) args = parser.parse_args() diff --git a/onnxruntime/test/python/transformers/test_data/gpt2_pytorch1.5_opset11/generate_tiny_gpt2_model.py b/onnxruntime/test/python/transformers/test_data/gpt2_pytorch1.5_opset11/generate_tiny_gpt2_model.py index bbeabcfd64112..7f613a8674989 100644 --- a/onnxruntime/test/python/transformers/test_data/gpt2_pytorch1.5_opset11/generate_tiny_gpt2_model.py +++ b/onnxruntime/test/python/transformers/test_data/gpt2_pytorch1.5_opset11/generate_tiny_gpt2_model.py @@ -1,22 +1,24 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- # This tool generates a tiny GPT2 model for testing fusion script. # You can use benchmark_gpt2.py to get a gpt2 ONNX model as input of this tool. -import onnx -import onnx.utils -import sys import argparse +import os +import random +import sys +import timeit +from pathlib import Path + import numpy as np +import onnx +import onnx.utils from onnx import ModelProto, TensorProto, numpy_helper from onnxruntime_tools.transformers.onnx_model import OnnxModel -import os + import onnxruntime -import random -from pathlib import Path -import timeit DICT_SIZE = 20 SEQ_LEN = 5 @@ -29,7 +31,7 @@ "num_heads": 12, "size_per_head": 64, "word_dict_size": [50257], # list of supported dictionary size. - "max_word_position": 1024 + "max_word_position": 1024, } # parameters of output tiny model. @@ -39,7 +41,7 @@ "num_heads": 2, "size_per_head": 2, "word_dict_size": DICT_SIZE, - "max_word_position": 8 + "max_word_position": 8, } @@ -54,20 +56,27 @@ def resize_weight(self, initializer_name, target_shape): target_w = w if len(target_shape) == 1: - target_w = w[:target_shape[0]] + target_w = w[: target_shape[0]] elif len(target_shape) == 2: - target_w = w[:target_shape[0], :target_shape[1]] + target_w = w[: target_shape[0], : target_shape[1]] elif len(target_shape) == 3: - target_w = w[:target_shape[0], :target_shape[1], :target_shape[2]] + target_w = w[: target_shape[0], : target_shape[1], : target_shape[2]] elif len(target_shape) == 4: - target_w = w[:target_shape[0], :target_shape[1], :target_shape[2], :target_shape[3]] + target_w = w[ + : target_shape[0], + : target_shape[1], + : target_shape[2], + : target_shape[3], + ] else: print("at most 3 dimensions") - tensor = onnx.helper.make_tensor(name=initializer_name + '_resize', - data_type=TensorProto.FLOAT, - dims=target_shape, - vals=target_w.flatten().tolist()) + tensor = onnx.helper.make_tensor( + name=initializer_name + "_resize", + data_type=TensorProto.FLOAT, + dims=target_shape, + vals=target_w.flatten().tolist(), + ) return tensor @@ -76,7 +85,7 @@ def resize_model(self): initializers = graph.initializer for input in graph.input: - if (input.type.tensor_type.shape.dim[1].dim_value == old_parameters["seq_len"]): + if input.type.tensor_type.shape.dim[1].dim_value == old_parameters["seq_len"]: print("input", input.name, input.type.tensor_type.shape) input.type.tensor_type.shape.dim[1].dim_value = new_parameters["seq_len"] print("=>", input.type.tensor_type.shape) @@ -95,105 +104,228 @@ def resize_model(self): if len(tensor.shape) == 1 and tensor.shape[0] == 1: if tensor == old_parameters["num_heads"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["num_heads"], "=>[", new_parameters["num_heads"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["num_heads"], + "=>[", + new_parameters["num_heads"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["num_heads"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["num_heads"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["seq_len"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["seq_len"], "=>[", new_parameters["seq_len"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["seq_len"], + "=>[", + new_parameters["seq_len"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["seq_len"]], dtype=dtype), initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["seq_len"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["size_per_head"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["size_per_head"], "=>[", new_parameters["size_per_head"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["size_per_head"], + "=>[", + new_parameters["size_per_head"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["size_per_head"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["size_per_head"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["hidden_size"], "=>[", new_parameters["hidden_size"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["hidden_size"], + "=>[", + new_parameters["hidden_size"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([new_parameters["hidden_size"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([new_parameters["hidden_size"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == 4 * old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - 4 * old_parameters["hidden_size"], "=>[", 4 * new_parameters["hidden_size"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 4 * old_parameters["hidden_size"], + "=>[", + 4 * new_parameters["hidden_size"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([4 * new_parameters["hidden_size"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([4 * new_parameters["hidden_size"]], dtype=dtype), + initializer.name, + ) + ) elif tensor == 3 * old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - 3 * old_parameters["hidden_size"], "=>[", 3 * new_parameters["hidden_size"], "]") + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 3 * old_parameters["hidden_size"], + "=>[", + 3 * new_parameters["hidden_size"], + "]", + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray([3 * new_parameters["hidden_size"]], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray([3 * new_parameters["hidden_size"]], dtype=dtype), + initializer.name, + ) + ) elif len(tensor.shape) == 0: if tensor == old_parameters["num_heads"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["num_heads"], "=>", new_parameters["num_heads"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["num_heads"], + "=>", + new_parameters["num_heads"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["num_heads"], dtype=dtype), initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["num_heads"], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["seq_len"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["seq_len"], "=>", new_parameters["seq_len"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["seq_len"], + "=>", + new_parameters["seq_len"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["seq_len"], dtype=dtype), initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["seq_len"], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["size_per_head"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["size_per_head"], "=>", new_parameters["size_per_head"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["size_per_head"], + "=>", + new_parameters["size_per_head"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["size_per_head"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["size_per_head"], dtype=dtype), + initializer.name, + ) + ) elif tensor == old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - old_parameters["hidden_size"], "=>", new_parameters["hidden_size"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + old_parameters["hidden_size"], + "=>", + new_parameters["hidden_size"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(new_parameters["hidden_size"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(new_parameters["hidden_size"], dtype=dtype), + initializer.name, + ) + ) elif tensor == 4 * old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - 4 * old_parameters["hidden_size"], "=>", 4 * new_parameters["hidden_size"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 4 * old_parameters["hidden_size"], + "=>", + 4 * new_parameters["hidden_size"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(4 * new_parameters["hidden_size"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(4 * new_parameters["hidden_size"], dtype=dtype), + initializer.name, + ) + ) elif tensor == 3 * old_parameters["hidden_size"]: - print("initializer type={}".format(initializer.data_type), initializer.name, - 3 * old_parameters["hidden_size"], "=>", 3 * new_parameters["hidden_size"]) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 3 * old_parameters["hidden_size"], + "=>", + 3 * new_parameters["hidden_size"], + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(3 * new_parameters["hidden_size"], dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(3 * new_parameters["hidden_size"], dtype=dtype), + initializer.name, + ) + ) elif tensor == 1.0 / np.sqrt(old_parameters["size_per_head"]): - print("initializer type={}".format(initializer.data_type), initializer.name, - 1.0 / np.sqrt(old_parameters["size_per_head"]), "=>", - 1.0 / np.sqrt(new_parameters["size_per_head"])) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + 1.0 / np.sqrt(old_parameters["size_per_head"]), + "=>", + 1.0 / np.sqrt(new_parameters["size_per_head"]), + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(1.0 / np.sqrt(new_parameters["size_per_head"]), dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray( + 1.0 / np.sqrt(new_parameters["size_per_head"]), + dtype=dtype, + ), + initializer.name, + ) + ) elif tensor == np.sqrt(old_parameters["size_per_head"]): - print("initializer type={}".format(initializer.data_type), initializer.name, - np.sqrt(old_parameters["size_per_head"]), "=>", np.sqrt(new_parameters["size_per_head"])) + print( + "initializer type={}".format(initializer.data_type), + initializer.name, + np.sqrt(old_parameters["size_per_head"]), + "=>", + np.sqrt(new_parameters["size_per_head"]), + ) initializer.CopyFrom( - numpy_helper.from_array(np.asarray(np.sqrt(new_parameters["size_per_head"]), dtype=dtype), - initializer.name)) + numpy_helper.from_array( + np.asarray(np.sqrt(new_parameters["size_per_head"]), dtype=dtype), + initializer.name, + ) + ) new_shape = [] shape_changed = False for dim in tensor.shape: - if (dim == old_parameters["hidden_size"]): + if dim == old_parameters["hidden_size"]: new_shape.append(new_parameters["hidden_size"]) shape_changed = True - elif (dim == 4 * old_parameters["hidden_size"]): + elif dim == 4 * old_parameters["hidden_size"]: new_shape.append(4 * new_parameters["hidden_size"]) shape_changed = True - elif (dim == 3 * old_parameters["hidden_size"]): + elif dim == 3 * old_parameters["hidden_size"]: new_shape.append(3 * new_parameters["hidden_size"]) shape_changed = True - elif (dim in old_parameters["word_dict_size"]): + elif dim in old_parameters["word_dict_size"]: new_shape.append(new_parameters["word_dict_size"]) shape_changed = True - elif (dim == old_parameters["max_word_position"]): + elif dim == old_parameters["max_word_position"]: new_shape.append(new_parameters["max_word_position"]) shape_changed = True else: @@ -203,7 +335,7 @@ def resize_model(self): print("initializer", initializer.name, tensor.shape, "=>", new_shape) for initializer_name in reshapes: - self.replace_input_of_all_nodes(initializer_name, initializer_name + '_resize') + self.replace_input_of_all_nodes(initializer_name, initializer_name + "_resize") tensor = self.resize_weight(initializer_name, reshapes[initializer_name]) self.model.graph.initializer.extend([tensor]) @@ -213,45 +345,73 @@ def resize_model(self): for i, node in enumerate(graph.node): if node.op_type == "Split": nodes_to_add.append( - onnx.helper.make_node('Split', - node.input, - node.output, - name="Split_{}".format(i), - axis=2, - split=[ - new_parameters["hidden_size"], new_parameters["hidden_size"], - new_parameters["hidden_size"] - ])) + onnx.helper.make_node( + "Split", + node.input, + node.output, + name="Split_{}".format(i), + axis=2, + split=[ + new_parameters["hidden_size"], + new_parameters["hidden_size"], + new_parameters["hidden_size"], + ], + ) + ) nodes_to_remove.append(node) - print("update split", - [new_parameters["hidden_size"], new_parameters["hidden_size"], new_parameters["hidden_size"]]) + print( + "update split", + [ + new_parameters["hidden_size"], + new_parameters["hidden_size"], + new_parameters["hidden_size"], + ], + ) if node.op_type == "Constant": for att in node.attribute: - if att.name == 'value': + if att.name == "value": if numpy_helper.to_array(att.t) == old_parameters["num_heads"]: nodes_to_add.append( - onnx.helper.make_node('Constant', - inputs=node.input, - outputs=node.output, - value=onnx.helper.make_tensor(name=att.t.name, - data_type=TensorProto.INT64, - dims=[], - vals=[new_parameters["num_heads"] - ]))) - print("constant", att.t.name, old_parameters["num_heads"], "=>", - new_parameters["num_heads"]) + onnx.helper.make_node( + "Constant", + inputs=node.input, + outputs=node.output, + value=onnx.helper.make_tensor( + name=att.t.name, + data_type=TensorProto.INT64, + dims=[], + vals=[new_parameters["num_heads"]], + ), + ) + ) + print( + "constant", + att.t.name, + old_parameters["num_heads"], + "=>", + new_parameters["num_heads"], + ) if numpy_helper.to_array(att.t) == np.sqrt(old_parameters["size_per_head"]): nodes_to_add.append( - onnx.helper.make_node('Constant', - inputs=node.input, - outputs=node.output, - value=onnx.helper.make_tensor( - name=att.t.name, - data_type=TensorProto.FLOAT, - dims=[], - vals=[np.sqrt(new_parameters["size_per_head"])]))) - print("constant", att.t.name, np.sqrt(old_parameters["size_per_head"]), "=>", - np.sqrt(new_parameters["size_per_head"])) + onnx.helper.make_node( + "Constant", + inputs=node.input, + outputs=node.output, + value=onnx.helper.make_tensor( + name=att.t.name, + data_type=TensorProto.FLOAT, + dims=[], + vals=[np.sqrt(new_parameters["size_per_head"])], + ), + ) + ) + print( + "constant", + att.t.name, + np.sqrt(old_parameters["size_per_head"]), + "=>", + np.sqrt(new_parameters["size_per_head"]), + ) else: node.name = node.op_type + "_" + str(i) for node in nodes_to_remove: @@ -276,21 +436,23 @@ def resize_model(self): dim_proto.dim_value = new_parameters["size_per_head"] -def generate_test_data(onnx_file, - output_path, - batch_size=1, - use_cpu=True, - input_tensor_only=False, - dictionary_size=DICT_SIZE, - test_cases=1, - output_optimized_model=False): +def generate_test_data( + onnx_file, + output_path, + batch_size=1, + use_cpu=True, + input_tensor_only=False, + dictionary_size=DICT_SIZE, + test_cases=1, + output_optimized_model=False, +): for test_case in range(test_cases): sequence_length = 3 input_1 = np.random.randint(dictionary_size, size=(batch_size, 1), dtype=np.int64) - tensor_1 = numpy_helper.from_array(input_1, 'input_ids') + tensor_1 = numpy_helper.from_array(input_1, "input_ids") - path = os.path.join(output_path, 'test_data_set_' + str(test_case)) + path = os.path.join(output_path, "test_data_set_" + str(test_case)) try: os.mkdir(path) except OSError: @@ -300,23 +462,28 @@ def generate_test_data(onnx_file, sess_options = onnxruntime.SessionOptions() sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL - sess = onnxruntime.InferenceSession(onnx_file, sess_options, providers=['CPUExecutionProvider']) + sess = onnxruntime.InferenceSession(onnx_file, sess_options, providers=["CPUExecutionProvider"]) input1_name = sess.get_inputs()[0].name output_names = [output.name for output in sess.get_outputs()] inputs = {input1_name: input_1} - with open(os.path.join(path, 'input_{}.pb'.format(0)), 'wb') as f: + with open(os.path.join(path, "input_{}.pb".format(0)), "wb") as f: f.write(tensor_1.SerializeToString()) for i in range(12): input_name = f"past_{i}" - input = np.random.rand(2, batch_size, new_parameters["num_heads"], sequence_length, - new_parameters["size_per_head"]).astype(np.float32) + input = np.random.rand( + 2, + batch_size, + new_parameters["num_heads"], + sequence_length, + new_parameters["size_per_head"], + ).astype(np.float32) tensor = numpy_helper.from_array(input, input_name) inputs.update({input_name: input}) - with open(os.path.join(path, 'input_{}.pb'.format(1 + i)), 'wb') as f: + with open(os.path.join(path, "input_{}.pb".format(1 + i)), "wb") as f: f.write(tensor.SerializeToString()) if input_tensor_only: @@ -329,9 +496,9 @@ def generate_test_data(onnx_file, def main(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, type=str) - parser.add_argument('--output', required=True, type=str) - parser.add_argument('--output_optimized_model', required=False, action='store_true') + parser.add_argument("--input", required=True, type=str) + parser.add_argument("--output", required=True, type=str) + parser.add_argument("--output_optimized_model", required=False, action="store_true") parser.set_defaults(output_optimized_model=False) args = parser.parse_args() @@ -352,11 +519,13 @@ def main(): p = Path(args.output) data_path = p.parent - generate_test_data(args.output, - data_path, - batch_size=1, - use_cpu=True, - output_optimized_model=args.output_optimized_model) + generate_test_data( + args.output, + data_path, + batch_size=1, + use_cpu=True, + output_optimized_model=args.output_optimized_model, + ) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/test_gelu_fusions.py b/onnxruntime/test/python/transformers/test_gelu_fusions.py index 6aa591baedb0d..feaba40ad88b1 100644 --- a/onnxruntime/test/python/transformers/test_gelu_fusions.py +++ b/onnxruntime/test/python/transformers/test_gelu_fusions.py @@ -1,9 +1,10 @@ +import math import os import unittest -import math -import torch +import torch from parity_utilities import find_transformers_source + if find_transformers_source(): from optimizer import optimize_model else: @@ -32,8 +33,12 @@ def forward(self, x): return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) -test_cases = [('huggingface', 'Gelu', HuggingfaceGelu), ('huggingface', 'FastGelu', HuggingfaceFastGelu), - ('megatron', 'Gelu', MegatronGelu), ('megatron', 'FastGelu', MegatronFastGelu)] +test_cases = [ + ("huggingface", "Gelu", HuggingfaceGelu), + ("huggingface", "FastGelu", HuggingfaceFastGelu), + ("megatron", "Gelu", MegatronGelu), + ("megatron", "FastGelu", MegatronFastGelu), +] class TestGeluFusions(unittest.TestCase): @@ -52,13 +57,19 @@ def test_fusions(self): dummy_input = torch.ones(3, dtype=torch.float32) test_name = f"{operator}_{source}" onnx_path = f"{test_name}.onnx" - torch.onnx.export(model, (dummy_input), onnx_path, input_names=['input'], output_names=['output']) - optimizer = optimize_model(onnx_path, 'bert') + torch.onnx.export( + model, + (dummy_input), + onnx_path, + input_names=["input"], + output_names=["output"], + ) + optimizer = optimize_model(onnx_path, "bert") # optimizer.save_model_to_file(f"{operator}_{source}_opt.onnx") os.remove(onnx_path) expected_node_count = {operator: 1} self.verify_node_count(optimizer, expected_node_count, test_name) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_gpt2.py b/onnxruntime/test/python/transformers/test_gpt2.py index 449393c5e998d..974c1f5a9765e 100644 --- a/onnxruntime/test/python/transformers/test_gpt2.py +++ b/onnxruntime/test/python/transformers/test_gpt2.py @@ -6,25 +6,25 @@ # license information. # -------------------------------------------------------------------------- -import unittest -import os import logging +import os +import unittest + import coloredlogs import pytest - from parity_utilities import find_transformers_source -if find_transformers_source(sub_dir_paths=['models', 'gpt2']): - from benchmark_gpt2 import parse_arguments, main +if find_transformers_source(sub_dir_paths=["models", "gpt2"]): + from benchmark_gpt2 import main, parse_arguments else: - from onnxruntime.transformers.models.gpt2.benchmark_gpt2 import parse_arguments, main + from onnxruntime.transformers.models.gpt2.benchmark_gpt2 import main, parse_arguments class TestGpt2(unittest.TestCase): - def setUp(self): from onnxruntime import get_available_providers - self.test_cuda = 'CUDAExecutionProvider' in get_available_providers() + + self.test_cuda = "CUDAExecutionProvider" in get_available_providers() def run_benchmark_gpt2(self, arguments: str): args = parse_arguments(arguments.split()) @@ -34,21 +34,22 @@ def run_benchmark_gpt2(self, arguments: str): @pytest.mark.slow def test_gpt2_fp32(self): - self.run_benchmark_gpt2('-m gpt2 --precision fp32 -v -b 1 --sequence_lengths 2 -s 3') + self.run_benchmark_gpt2("-m gpt2 --precision fp32 -v -b 1 --sequence_lengths 2 -s 3") @pytest.mark.slow def test_gpt2_fp16(self): if self.test_cuda: - self.run_benchmark_gpt2('-m gpt2 --precision fp16 -o -b 1 --sequence_lengths 2 -s 3 --use_gpu') + self.run_benchmark_gpt2("-m gpt2 --precision fp16 -o -b 1 --sequence_lengths 2 -s 3 --use_gpu") @pytest.mark.slow def test_gpt2_int8(self): - self.run_benchmark_gpt2('-m gpt2 --precision int8 -o -b 1 --sequence_lengths 2 -s 3') + self.run_benchmark_gpt2("-m gpt2 --precision int8 -o -b 1 --sequence_lengths 2 -s 3") @pytest.mark.slow def test_gpt2_beam_search_step_fp32(self): self.run_benchmark_gpt2( - '-m gpt2 --model_class=GPT2LMHeadModel_BeamSearchStep --precision fp32 -v -b 1 --sequence_lengths 5 -s 3') + "-m gpt2 --model_class=GPT2LMHeadModel_BeamSearchStep --precision fp32 -v -b 1 --sequence_lengths 5 -s 3" + ) # @pytest.mark.slow # def test_gpt2_beam_search_step_fp16(self): @@ -59,29 +60,30 @@ def test_gpt2_beam_search_step_fp32(self): @pytest.mark.slow def test_gpt2_beam_search_step_int8(self): self.run_benchmark_gpt2( - '-m gpt2 --model_class=GPT2LMHeadModel_BeamSearchStep --precision int8 -o -b 1 --sequence_lengths 5 -s 3') + "-m gpt2 --model_class=GPT2LMHeadModel_BeamSearchStep --precision int8 -o -b 1 --sequence_lengths 5 -s 3" + ) @pytest.mark.slow def test_gpt2_configurable_one_step_search_fp32(self): self.run_benchmark_gpt2( - '-m gpt2 --model_class=GPT2LMHeadModel_ConfigurableOneStepSearch --precision fp32 -v -b 1 --sequence_lengths 5 --past_sequence_lengths 3 --use_gpu' + "-m gpt2 --model_class=GPT2LMHeadModel_ConfigurableOneStepSearch --precision fp32 -v -b 1 --sequence_lengths 5 --past_sequence_lengths 3 --use_gpu" ) @pytest.mark.slow def test_gpt2_configurable_one_step_search_fp16(self): if self.test_cuda: self.run_benchmark_gpt2( - '-m gpt2 --model_class=GPT2LMHeadModel_ConfigurableOneStepSearch --precision fp16 -o -b 1 --sequence_lengths 5 -s 3 --use_gpu' + "-m gpt2 --model_class=GPT2LMHeadModel_ConfigurableOneStepSearch --precision fp16 -o -b 1 --sequence_lengths 5 -s 3 --use_gpu" ) @pytest.mark.slow def test_gpt2_configurable_one_step_search_int8(self): self.run_benchmark_gpt2( - '-m gpt2 --model_class=GPT2LMHeadModel_ConfigurableOneStepSearch --precision int8 -o -b 1 --sequence_lengths 5 -s 3' + "-m gpt2 --model_class=GPT2LMHeadModel_ConfigurableOneStepSearch --precision int8 -o -b 1 --sequence_lengths 5 -s 3" ) -if __name__ == '__main__': - coloredlogs.install(fmt='%(message)s') +if __name__ == "__main__": + coloredlogs.install(fmt="%(message)s") logging.getLogger("transformers").setLevel(logging.ERROR) unittest.main() diff --git a/onnxruntime/test/python/transformers/test_optimizer.py b/onnxruntime/test/python/transformers/test_optimizer.py index b4d944caf998f..979158f512723 100644 --- a/onnxruntime/test/python/transformers/test_optimizer.py +++ b/onnxruntime/test/python/transformers/test_optimizer.py @@ -8,33 +8,40 @@ # For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG -import unittest import os +import unittest + import pytest +from model_loader import get_fusion_test_model, get_test_data_path from onnx import TensorProto, load_model -from model_loader import get_test_data_path, get_fusion_test_model - from parity_utilities import find_transformers_source + if find_transformers_source(): - from optimizer import optimize_model - from onnx_model import OnnxModel - from onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt + from benchmark_helper import OptimizerInfo, Precision from huggingface_models import MODELS - from benchmark_helper import Precision, OptimizerInfo + from onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf + from onnx_model import OnnxModel + from optimizer import optimize_model else: - from onnxruntime.transformers.optimizer import optimize_model - from onnxruntime.transformers.onnx_model import OnnxModel - from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_tf, export_onnx_model_from_pt + from onnxruntime.transformers.benchmark_helper import OptimizerInfo, Precision from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.benchmark_helper import Precision, OptimizerInfo + from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt, export_onnx_model_from_tf + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model BERT_TEST_MODELS = { - "bert_keras_0": ('models', 'TFBertForSequenceClassification_1.onnx'), # bert_mrpc_tensorflow2.1_opset10 - "bert_keras_squad": ('models', 'TFBertForQuestionAnswering.onnx'), # bert_squad_tensorflow2.1_keras2onnx_opset11 - "gpt2_past": ('models', 'gpt2_past.onnx'), # gpt2_pytorch1.5_opset11 - "gpt2_past_mask": ('FUSION', 'gpt2_past_mask_one_layer.onnx'), - "multiple_embed": ('FUSION', 'embed_layer_norm_multiple.onnx'), - "bert_tf2onnx_0": ('models', 'bert_tf2onnx_0.onnx') + "bert_keras_0": ( + "models", + "TFBertForSequenceClassification_1.onnx", + ), # bert_mrpc_tensorflow2.1_opset10 + "bert_keras_squad": ( + "models", + "TFBertForQuestionAnswering.onnx", + ), # bert_squad_tensorflow2.1_keras2onnx_opset11 + "gpt2_past": ("models", "gpt2_past.onnx"), # gpt2_pytorch1.5_opset11 + "gpt2_past_mask": ("FUSION", "gpt2_past_mask_one_layer.onnx"), + "multiple_embed": ("FUSION", "embed_layer_norm_multiple.onnx"), + "bert_tf2onnx_0": ("models", "bert_tf2onnx_0.onnx"), } @@ -57,15 +64,18 @@ def verify_node_count(self, bert_model, expected_node_count, test_name): self.assertEqual(len(bert_model.get_nodes_by_op_type(op_type)), count) # add test function for huggingface pytorch model - def _test_optimizer_on_huggingface_model(self, - model_name, - expected_fusion_result_list, - inputs_count=1, - validate_model=True): + def _test_optimizer_on_huggingface_model( + self, + model_name, + expected_fusion_result_list, + inputs_count=1, + validate_model=True, + ): # Remove cached model so that CI machine will have space import shutil - shutil.rmtree('./cache_models', ignore_errors=True) - shutil.rmtree('./onnx_models', ignore_errors=True) + + shutil.rmtree("./cache_models", ignore_errors=True) + shutil.rmtree("./onnx_models", ignore_errors=True) # expect fusion result list have the following keys # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization model_fusion_statistics = {} @@ -73,13 +83,25 @@ def _test_optimizer_on_huggingface_model(self, input_names = MODELS[model_name][0] import torch + with torch.no_grad(): - _, is_valid_onnx_model, _, _ = export_onnx_model_from_pt(model_name, MODELS[model_name][1], - MODELS[model_name][2], MODELS[model_name][3], None, - './cache_models', './onnx_models', - input_names[:inputs_count], False, - Precision.FLOAT32, OptimizerInfo.BYSCRIPT, True, True, True, - model_fusion_statistics) + _, is_valid_onnx_model, _, _ = export_onnx_model_from_pt( + model_name, + MODELS[model_name][1], + MODELS[model_name][2], + MODELS[model_name][3], + None, + "./cache_models", + "./onnx_models", + input_names[:inputs_count], + False, + Precision.FLOAT32, + OptimizerInfo.BYSCRIPT, + True, + True, + True, + model_fusion_statistics, + ) onnx_model = list(model_fusion_statistics.keys())[0] fusion_result_list = list(model_fusion_statistics[onnx_model].values()) @@ -91,8 +113,9 @@ def _test_optimizer_on_huggingface_model(self, def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, inputs_count, validate_model=True): # Remove cached model so that CI machine will have space import shutil - shutil.rmtree('./cache_models', ignore_errors=True) - shutil.rmtree('./onnx_models', ignore_errors=True) + + shutil.rmtree("./cache_models", ignore_errors=True) + shutil.rmtree("./onnx_models", ignore_errors=True) # expect fusion result list have the following keys # EmbedLayerNormalization, Attention, Gelu, FastGelu, BiasGelu, LayerNormalization, SkipLayerNormalization @@ -102,13 +125,25 @@ def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, i input_names = MODELS[model_name][0] import torch + with torch.no_grad(): - _, is_valid_onnx_model, _, _ = export_onnx_model_from_tf(model_name, MODELS[model_name][1], - MODELS[model_name][2], MODELS[model_name][3], None, - './cache_models', './onnx_models', - input_names[:inputs_count], False, - Precision.FLOAT32, True, True, True, True, - model_fusion_statistics) + _, is_valid_onnx_model, _, _ = export_onnx_model_from_tf( + model_name, + MODELS[model_name][1], + MODELS[model_name][2], + MODELS[model_name][3], + None, + "./cache_models", + "./onnx_models", + input_names[:inputs_count], + False, + Precision.FLOAT32, + True, + True, + True, + True, + model_fusion_statistics, + ) onnx_model = list(model_fusion_statistics.keys())[0] fusion_result_list = list(model_fusion_statistics[onnx_model].values()) @@ -143,22 +178,22 @@ def _test_optimizer_on_tf_model(self, model_name, expected_fusion_result_list, i # self.assertTrue(bert_model.is_fully_optimized()) def test_gpt2_past(self): - input = _get_test_model_path('gpt2_past') - model = optimize_model(input, 'gpt2', num_heads=2, hidden_size=4) + input = _get_test_model_path("gpt2_past") + model = optimize_model(input, "gpt2", num_heads=2, hidden_size=4) expected_node_count = { - 'EmbedLayerNormalization': 0, - 'Attention': 12, - 'Gelu': 0, - 'FastGelu': 12, - 'BiasGelu': 0, - 'LayerNormalization': 25, - 'SkipLayerNormalization': 0 + "EmbedLayerNormalization": 0, + "Attention": 12, + "Gelu": 0, + "FastGelu": 12, + "BiasGelu": 0, + "LayerNormalization": 25, + "SkipLayerNormalization": 0, } - self.verify_node_count(model, expected_node_count, 'test_gpt2_past') + self.verify_node_count(model, expected_node_count, "test_gpt2_past") def test_gpt2_past_fp16(self): - input_model_path = _get_test_model_path('gpt2_past') + input_model_path = _get_test_model_path("gpt2_past") model = OnnxModel(load_model(input_model_path, format=None, load_external_data=True)) model.convert_float_to_float16(keep_io_types=False, use_symbolic_shape_infer=False) for input in model.graph().input[1:]: @@ -167,45 +202,49 @@ def test_gpt2_past_fp16(self): self.assertEqual(output.type.tensor_type.elem_type, TensorProto.FLOAT16) def test_gpt2_past_mask(self): - input = _get_test_model_path('gpt2_past_mask') - model = optimize_model(input, 'gpt2', num_heads=2, hidden_size=4) + input = _get_test_model_path("gpt2_past_mask") + model = optimize_model(input, "gpt2", num_heads=2, hidden_size=4) expected_node_count = { - 'EmbedLayerNormalization': 1, - 'Attention': 1, - 'Gelu': 0, - 'FastGelu': 1, - 'BiasGelu': 0, - 'LayerNormalization': 1, - 'SkipLayerNormalization': 0 + "EmbedLayerNormalization": 1, + "Attention": 1, + "Gelu": 0, + "FastGelu": 1, + "BiasGelu": 0, + "LayerNormalization": 1, + "SkipLayerNormalization": 0, } - self.verify_node_count(model, expected_node_count, 'test_gpt2_past_mask') + self.verify_node_count(model, expected_node_count, "test_gpt2_past_mask") def test_multiple_embed(self): - input_model_path = _get_test_model_path('multiple_embed') - model = optimize_model(input_model_path, 'bert', num_heads=2, hidden_size=4) + input_model_path = _get_test_model_path("multiple_embed") + model = optimize_model(input_model_path, "bert", num_heads=2, hidden_size=4) expected_node_count = { - 'EmbedLayerNormalization': 2, - 'Attention': 2, - 'Gelu': 0, - 'FastGelu': 0, - 'BiasGelu': 0, - 'LayerNormalization': 0, - 'SkipLayerNormalization': 0 + "EmbedLayerNormalization": 2, + "Attention": 2, + "Gelu": 0, + "FastGelu": 0, + "BiasGelu": 0, + "LayerNormalization": 0, + "SkipLayerNormalization": 0, } - self.verify_node_count(model, expected_node_count, 'test_multiple_embed') + self.verify_node_count(model, expected_node_count, "test_multiple_embed") def test_embed_layer_norm_fusion(self): onnx_files = [] for i in [3, 8, 9]: onnx_files.append(f"embed_layer_norm_format{i}.onnx") onnx_files.append(f"embed_layer_norm_format{i}_opset13.onnx") - onnx_files.append('embed_layer_norm_format3_no_cast.onnx') - onnx_files.append('embed_layer_norm_format3_no_cast_opset13.onnx') + onnx_files.append("embed_layer_norm_format3_no_cast.onnx") + onnx_files.append("embed_layer_norm_format3_no_cast_opset13.onnx") for file in onnx_files: input_model_path = get_fusion_test_model(file) - model = optimize_model(input_model_path, 'bert') - expected_node_count = {'EmbedLayerNormalization': 1, 'Attention': 1, 'ReduceSum': 0} + model = optimize_model(input_model_path, "bert") + expected_node_count = { + "EmbedLayerNormalization": 1, + "Attention": 1, + "ReduceSum": 0, + } self.verify_node_count(model, expected_node_count, file) # def test_bert_tf2onnx_0(self): @@ -275,10 +314,16 @@ def test_huggingface_xlmroberta_fusion(self): @pytest.mark.slow def test_huggingface_flaubert_fusion(self): # output not close issue - self._test_optimizer_on_huggingface_model("flaubert/flaubert_base_cased", [0, 12, 0, 0, 12, 0, 25], - validate_model=False) - self._test_optimizer_on_huggingface_model("flaubert/flaubert_small_cased", [0, 6, 0, 0, 6, 12, 1], - validate_model=False) + self._test_optimizer_on_huggingface_model( + "flaubert/flaubert_base_cased", + [0, 12, 0, 0, 12, 0, 25], + validate_model=False, + ) + self._test_optimizer_on_huggingface_model( + "flaubert/flaubert_small_cased", + [0, 6, 0, 0, 6, 12, 1], + validate_model=False, + ) # @pytest.mark.slow # def test_huggingface_dialogpt_fusion(self): @@ -325,5 +370,5 @@ def test_huggingface_xlm_from_tf2onnx(self): self._test_optimizer_on_tf_model("xlm-mlm-ende-1024", [0, 0, 0, 0, 0, 1, 12], 1, validate_model=False) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_decoder_attention.py b/onnxruntime/test/python/transformers/test_parity_decoder_attention.py index 79934fbd944c0..6c05e321f7618 100644 --- a/onnxruntime/test/python/transformers/test_parity_decoder_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_decoder_attention.py @@ -11,12 +11,13 @@ # ------------------------------------------------------------------------- import math +import os +from typing import Dict, List, Optional, Tuple + import numpy import torch from torch import Tensor, nn from torch.nn import functional as F -from typing import Dict, List, Optional, Tuple -import os torch.manual_seed(0) @@ -83,6 +84,7 @@ def my_bart_attention_forward( return attn_output, None, layer_state """ + class Config: batch_size = 0 sequence_length = 0 @@ -99,6 +101,7 @@ def __init__(self, b, s, s2, n, h): self.head_size = h self.embed_dim = self.num_heads * self.head_size + class AttentionProjection(nn.Module): def __init__(self, num_heads, head_dim, embed_dim, bias=True): super().__init__() @@ -148,6 +151,7 @@ def forward( return k, v + class AttentionForONNX(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -165,7 +169,7 @@ def __init__( self.dropout = dropout self.head_dim = embed_dim // num_heads assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" - self.scaling = self.head_dim ** -0.5 + self.scaling = self.head_dim**-0.5 self.encoder_decoder_attention = encoder_decoder_attention self.k_v_proj = torch.jit.script(AttentionProjection(num_heads, self.head_dim, embed_dim, bias)) @@ -183,9 +187,9 @@ def forward( key_padding_mask: Optional[Tensor] = None, layer_state: Optional[List[Tensor]] = None, attn_mask: Optional[Tensor] = None, - output_attentions: bool=False, + output_attentions: bool = False, use_past=torch.tensor(False), - has_key_padding_mask: bool=False + has_key_padding_mask: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time(SeqLen) x Batch x Channel""" static_kv: bool = self.encoder_decoder_attention @@ -198,7 +202,12 @@ def forward( # Update cache if layer_state is not None: - cached_shape = (bsz, self.num_heads, -1, self.head_dim) # bsz must be first for reorder_cache + cached_shape = ( + bsz, + self.num_heads, + -1, + self.head_dim, + ) # bsz must be first for reorder_cache if static_kv: # cross-attn new_key_cache = k.view(*cached_shape) @@ -237,9 +246,9 @@ def ORT_forward( key_padding_mask: Optional[Tensor] = None, layer_state: Optional[List[Tensor]] = None, attn_mask: Optional[Tensor] = None, - output_attentions: bool=False, + output_attentions: bool = False, use_past=torch.tensor(False), - has_key_padding_mask: bool=False + has_key_padding_mask: bool = False, ) -> Tuple[Tensor, Optional[Tensor]]: """Input shape: Time(SeqLen) x Batch x Channel""" # For readability @@ -247,16 +256,36 @@ def ORT_forward( has_layer_state = True if layer_state is not None else False use_past_cache = True if use_past else False - q_weight = self.q_proj.weight.transpose(0,1) + q_weight = self.q_proj.weight.transpose(0, 1) q_weight = q_weight.reshape(self.embed_dim, self.embed_dim) - kv_weight = torch.stack((self.k_v_proj.k_proj.weight.transpose(0,1), self.k_v_proj.v_proj.weight.transpose(0,1)), dim=1) + kv_weight = torch.stack( + ( + self.k_v_proj.k_proj.weight.transpose(0, 1), + self.k_v_proj.v_proj.weight.transpose(0, 1), + ), + dim=1, + ) kv_weight = kv_weight.reshape(self.embed_dim, 2 * self.embed_dim) - bias = torch.stack((self.q_proj.bias, self.k_v_proj.k_proj.bias, self.k_v_proj.v_proj.bias), dim=0) + bias = torch.stack( + (self.q_proj.bias, self.k_v_proj.k_proj.bias, self.k_v_proj.v_proj.bias), + dim=0, + ) bias = bias.reshape(3 * self.embed_dim) - onnx_model_str = create_decoder_attention_graph(query, key, q_weight, kv_weight, bias, self.num_heads, static_kv, use_past_cache, has_layer_state, has_key_padding_mask) + onnx_model_str = create_decoder_attention_graph( + query, + key, + q_weight, + kv_weight, + bias, + self.num_heads, + static_kv, + use_past_cache, + has_layer_state, + has_key_padding_mask, + ) self_p_k, self_p_v, enc_dec_p_k, enc_dec_p_v = layer_state if self.encoder_decoder_attention: @@ -265,16 +294,17 @@ def ORT_forward( key_cache, value_cache = self_p_k, self_p_v ort_inputs = { - 'query': numpy.ascontiguousarray(query.cpu().numpy()), - 'key': numpy.ascontiguousarray(key.cpu().numpy()), - 'key_padding_mask': numpy.ascontiguousarray(key_padding_mask.cpu().numpy()), - 'key_cache': numpy.ascontiguousarray(key_cache.detach().cpu().numpy()), - 'value_cache': numpy.ascontiguousarray(value_cache.detach().cpu().numpy()) + "query": numpy.ascontiguousarray(query.cpu().numpy()), + "key": numpy.ascontiguousarray(key.cpu().numpy()), + "key_padding_mask": numpy.ascontiguousarray(key_padding_mask.cpu().numpy()), + "key_cache": numpy.ascontiguousarray(key_cache.detach().cpu().numpy()), + "value_cache": numpy.ascontiguousarray(value_cache.detach().cpu().numpy()), } - from onnxruntime import SessionOptions, InferenceSession + from onnxruntime import InferenceSession, SessionOptions + sess_options = SessionOptions() - ort_session = InferenceSession(onnx_model_str, sess_options, providers=['CUDAExecutionProvider']) + ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) ort_output = ort_session.run(None, ort_inputs) output, new_key_cache, new_value_cache = ort_output @@ -284,8 +314,19 @@ def ORT_forward( return attn_output, torch.tensor(new_key_cache), torch.tensor(new_value_cache) -def create_decoder_attention_graph(query, key, q_weight, kv_weight, bias, num_heads_, static_kv, use_past, has_layer_state, has_key_padding_mask): - from onnx import helper, TensorProto +def create_decoder_attention_graph( + query, + key, + q_weight, + kv_weight, + bias, + num_heads_, + static_kv, + use_past, + has_layer_state, + has_key_padding_mask, +): + from onnx import TensorProto, helper S, B, NH = query.size() S2 = key.size()[0] @@ -293,60 +334,77 @@ def create_decoder_attention_graph(query, key, q_weight, kv_weight, bias, num_he H = int(NH / N) nodes = [ - helper.make_node("DecoderAttention", - ["query", "key", "q_weight", "kv_weight", "bias", "key_padding_mask", "key_cache", "value_cache", "static_kv", "use_past", "has_layer_state", "has_key_padding_mask"], - ["output", "new_key_cache", "new_value_cache"], - "DecoderAttention_0", - num_heads=num_heads_, - domain="com.microsoft"), + helper.make_node( + "DecoderAttention", + [ + "query", + "key", + "q_weight", + "kv_weight", + "bias", + "key_padding_mask", + "key_cache", + "value_cache", + "static_kv", + "use_past", + "has_layer_state", + "has_key_padding_mask", + ], + ["output", "new_key_cache", "new_value_cache"], + "DecoderAttention_0", + num_heads=num_heads_, + domain="com.microsoft", + ), ] initializers = [ - helper.make_tensor('q_weight', TensorProto.FLOAT, [NH, NH], - q_weight.flatten().tolist()), - helper.make_tensor('kv_weight', TensorProto.FLOAT, [NH, 2 * NH], - kv_weight.flatten().tolist()), - helper.make_tensor('bias', TensorProto.FLOAT, [3 * NH], - bias.flatten().tolist()), - helper.make_tensor('static_kv', TensorProto.BOOL, [1], - [static_kv]), - helper.make_tensor('use_past', TensorProto.BOOL, [1], - [use_past]), - helper.make_tensor('has_layer_state', TensorProto.BOOL, [1], - [has_layer_state]), - helper.make_tensor('has_key_padding_mask', TensorProto.BOOL, [1], - [has_key_padding_mask]), + helper.make_tensor("q_weight", TensorProto.FLOAT, [NH, NH], q_weight.flatten().tolist()), + helper.make_tensor("kv_weight", TensorProto.FLOAT, [NH, 2 * NH], kv_weight.flatten().tolist()), + helper.make_tensor("bias", TensorProto.FLOAT, [3 * NH], bias.flatten().tolist()), + helper.make_tensor("static_kv", TensorProto.BOOL, [1], [static_kv]), + helper.make_tensor("use_past", TensorProto.BOOL, [1], [use_past]), + helper.make_tensor("has_layer_state", TensorProto.BOOL, [1], [has_layer_state]), + helper.make_tensor("has_key_padding_mask", TensorProto.BOOL, [1], [has_key_padding_mask]), ] - graph = helper.make_graph(nodes, "DecoderAttention_Graph", [ - helper.make_tensor_value_info('query', TensorProto.FLOAT, [S, B, NH]), - helper.make_tensor_value_info('key', TensorProto.FLOAT, [S2, B, NH]), - helper.make_tensor_value_info('key_padding_mask', TensorProto.BOOL, [B, "mask_len"]), - helper.make_tensor_value_info('key_cache', TensorProto.FLOAT, [B, N, "cache_len", H]), - helper.make_tensor_value_info('value_cache', TensorProto.FLOAT, [B, N, "cache_len", H]), - ], [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, [S, B, NH]), - helper.make_tensor_value_info('new_key_cache', TensorProto.FLOAT, [B, N, "new_cache_len", H]), - helper.make_tensor_value_info('new_value_cache', TensorProto.FLOAT, [B, N, "new_cache_len", H]), - ], initializers) + graph = helper.make_graph( + nodes, + "DecoderAttention_Graph", + [ + helper.make_tensor_value_info("query", TensorProto.FLOAT, [S, B, NH]), + helper.make_tensor_value_info("key", TensorProto.FLOAT, [S2, B, NH]), + helper.make_tensor_value_info("key_padding_mask", TensorProto.BOOL, [B, "mask_len"]), + helper.make_tensor_value_info("key_cache", TensorProto.FLOAT, [B, N, "cache_len", H]), + helper.make_tensor_value_info("value_cache", TensorProto.FLOAT, [B, N, "cache_len", H]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, [S, B, NH]), + helper.make_tensor_value_info("new_key_cache", TensorProto.FLOAT, [B, N, "new_cache_len", H]), + helper.make_tensor_value_info("new_value_cache", TensorProto.FLOAT, [B, N, "new_cache_len", H]), + ], + initializers, + ) model = helper.make_model(graph) return model.SerializeToString() -def create_inputs(config: Config, has_layer_state: bool, use_past: bool, encoder_decoder_attention:bool): - query = torch.normal(mean=0.0, - std=0.1, - size=(config.sequence_length, - config.batch_size, - config.embed_dim) - ).to(torch.float32) - key = torch.normal(mean=0.0, - std=0.1, - size=(config.kv_sequence_length, - config.batch_size, - config.embed_dim) - ).to(torch.float32) +def create_inputs( + config: Config, + has_layer_state: bool, + use_past: bool, + encoder_decoder_attention: bool, +): + query = torch.normal( + mean=0.0, + std=0.1, + size=(config.sequence_length, config.batch_size, config.embed_dim), + ).to(torch.float32) + key = torch.normal( + mean=0.0, + std=0.1, + size=(config.kv_sequence_length, config.batch_size, config.embed_dim), + ).to(torch.float32) key_length = None if not has_layer_state or not use_past: @@ -360,64 +418,175 @@ def create_inputs(config: Config, has_layer_state: bool, use_past: bool, encoder else: key_length = config.kv_sequence_length - key_padding_mask = torch.normal(mean=0.0, - std=0.1, - size=(config.batch_size, - key_length) - ) > 0 + key_padding_mask = torch.normal(mean=0.0, std=0.1, size=(config.batch_size, key_length)) > 0 # The following line ensure not all the mask are true key_padding_mask[0][0] = False - cache = torch.normal(mean=0.0, - std=0.1, - size=(config.batch_size, - config.num_heads, - config.kv_sequence_length, - config.head_size) - ).to(torch.float32) + cache = torch.normal( + mean=0.0, + std=0.1, + size=( + config.batch_size, + config.num_heads, + config.kv_sequence_length, + config.head_size, + ), + ).to(torch.float32) layer_state = [cache, cache, cache, cache] return query, key, key_padding_mask, layer_state, torch.tensor(use_past) -def parity_check(config, has_layer_state, use_past, static_kv, has_key_padding_mask, rtol = 1e-4, atol = 1e-4): - query, key, key_padding_mask, layer_state, use_past = create_inputs(config, - has_layer_state, - use_past, - static_kv) - attn = AttentionForONNX(config.embed_dim, - config.num_heads, - encoder_decoder_attention = static_kv) - attn_output, new_key_cache, new_value_cache = attn.forward(query, key, key_padding_mask, layer_state, None, False, use_past, has_key_padding_mask) - attn_output_ort, new_key_cache_ort, new_value_cache_ort = attn.ORT_forward(query, key, key_padding_mask, layer_state, None, False, use_past, has_key_padding_mask) - attn_output_ort_1, _, _ = attn.ORT_forward(query, key, key_padding_mask, layer_state, None, False, use_past, has_key_padding_mask) - print(" B:", config.batch_size, - " S:", config.sequence_length, - " S*:", config.kv_sequence_length, - " h:", config.embed_dim, - " has_layer_state:", has_layer_state, - " use_past:", use_past, - " static_kv:", static_kv, - " has_key_padding_mask:", has_key_padding_mask, - "[attn_output, randomness, key, value] parity:", - numpy.allclose(attn_output.detach().numpy(), attn_output_ort.detach().numpy(), rtol = rtol, atol = atol, equal_nan = True), - numpy.allclose(attn_output_ort_1.detach().numpy(), attn_output_ort.detach().numpy(), rtol = rtol, atol = atol, equal_nan = True), - numpy.allclose(new_key_cache.detach().numpy(), new_key_cache_ort.detach().numpy(), rtol = rtol, atol = atol, equal_nan = True), - numpy.allclose(new_value_cache.detach().numpy(), new_value_cache_ort.detach().numpy(), rtol = rtol, atol = atol, equal_nan = True)) - - -if __name__ == '__main__': +def parity_check( + config, + has_layer_state, + use_past, + static_kv, + has_key_padding_mask, + rtol=1e-4, + atol=1e-4, +): + query, key, key_padding_mask, layer_state, use_past = create_inputs(config, has_layer_state, use_past, static_kv) + attn = AttentionForONNX(config.embed_dim, config.num_heads, encoder_decoder_attention=static_kv) + attn_output, new_key_cache, new_value_cache = attn.forward( + query, + key, + key_padding_mask, + layer_state, + None, + False, + use_past, + has_key_padding_mask, + ) + attn_output_ort, new_key_cache_ort, new_value_cache_ort = attn.ORT_forward( + query, + key, + key_padding_mask, + layer_state, + None, + False, + use_past, + has_key_padding_mask, + ) + attn_output_ort_1, _, _ = attn.ORT_forward( + query, + key, + key_padding_mask, + layer_state, + None, + False, + use_past, + has_key_padding_mask, + ) + print( + " B:", + config.batch_size, + " S:", + config.sequence_length, + " S*:", + config.kv_sequence_length, + " h:", + config.embed_dim, + " has_layer_state:", + has_layer_state, + " use_past:", + use_past, + " static_kv:", + static_kv, + " has_key_padding_mask:", + has_key_padding_mask, + "[attn_output, randomness, key, value] parity:", + numpy.allclose( + attn_output.detach().numpy(), + attn_output_ort.detach().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ), + numpy.allclose( + attn_output_ort_1.detach().numpy(), + attn_output_ort.detach().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ), + numpy.allclose( + new_key_cache.detach().numpy(), + new_key_cache_ort.detach().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ), + numpy.allclose( + new_value_cache.detach().numpy(), + new_value_cache_ort.detach().numpy(), + rtol=rtol, + atol=atol, + equal_nan=True, + ), + ) + + +if __name__ == "__main__": for b in [1, 32, 128]: for s in [1, 2, 128]: for s2 in [1, 64, 256]: for n in [8]: for h in [64]: config = Config(b, s, s2, n, h) - parity_check(config, has_layer_state = True, use_past = True, static_kv = True, has_key_padding_mask = False) - parity_check(config, has_layer_state = True, use_past = True, static_kv = False, has_key_padding_mask = False) - parity_check(config, has_layer_state = True, use_past = False, static_kv = True, has_key_padding_mask = False) - parity_check(config, has_layer_state = True, use_past = False, static_kv = False, has_key_padding_mask = False) - parity_check(config, has_layer_state = True, use_past = True, static_kv = True, has_key_padding_mask = True) - parity_check(config, has_layer_state = True, use_past = True, static_kv = False, has_key_padding_mask = True) - parity_check(config, has_layer_state = True, use_past = False, static_kv = True, has_key_padding_mask = True) - parity_check(config, has_layer_state = True, use_past = False, static_kv = False, has_key_padding_mask = True) \ No newline at end of file + parity_check( + config, + has_layer_state=True, + use_past=True, + static_kv=True, + has_key_padding_mask=False, + ) + parity_check( + config, + has_layer_state=True, + use_past=True, + static_kv=False, + has_key_padding_mask=False, + ) + parity_check( + config, + has_layer_state=True, + use_past=False, + static_kv=True, + has_key_padding_mask=False, + ) + parity_check( + config, + has_layer_state=True, + use_past=False, + static_kv=False, + has_key_padding_mask=False, + ) + parity_check( + config, + has_layer_state=True, + use_past=True, + static_kv=True, + has_key_padding_mask=True, + ) + parity_check( + config, + has_layer_state=True, + use_past=True, + static_kv=False, + has_key_padding_mask=True, + ) + parity_check( + config, + has_layer_state=True, + use_past=False, + static_kv=True, + has_key_padding_mask=True, + ) + parity_check( + config, + has_layer_state=True, + use_past=False, + static_kv=False, + has_key_padding_mask=True, + ) diff --git a/onnxruntime/test/python/transformers/test_parity_gelu.py b/onnxruntime/test/python/transformers/test_parity_gelu.py index e7e4d04da9b30..7fe42dc76f193 100644 --- a/onnxruntime/test/python/transformers/test_parity_gelu.py +++ b/onnxruntime/test/python/transformers/test_parity_gelu.py @@ -23,12 +23,13 @@ For comparison, CPU has MaxDiff=4.77E-07 for each formula. """ -import unittest -import torch -from torch import nn import math import os +import unittest + +import torch from parity_utilities import * +from torch import nn class Gelu(nn.Module): @@ -63,10 +64,10 @@ def forward(self, x): if self.fp32_gelu_op and x.dtype == torch.float16: # This test only evaluates FP32 kernels so add data type cast for input and output. casted_output = self.gelu(x.to(torch.float32)).to(torch.float16) - return (casted_output, ) + return (casted_output,) else: output = self.gelu(x) - return (output, ) + return (output,) def get_output_names(): @@ -74,15 +75,17 @@ def get_output_names(): return outputs -def run(batch_size, - float16, - optimized, - hidden_size, - device, - test_cases, - formula=0, - sequence_length=2, - fp32_gelu_op=True): +def run( + batch_size, + float16, + optimized, + hidden_size, + device, + test_cases, + formula=0, + sequence_length=2, + fp32_gelu_op=True, +): test_name = f"device={device}, float16={float16}, optimized={optimized}, batch_size={batch_size}, sequence_length={sequence_length}, hidden_size={hidden_size}, formula={formula}, fp32_gelu_op={fp32_gelu_op}" print(f"\nTesting: {test_name}") @@ -93,31 +96,35 @@ def run(batch_size, model.half() # Do not re-use onnx file from previous test since weights of model are random. - onnx_model_path = './temp/gelu_{}_{}.onnx'.format(formula, "fp16" if float16 else "fp32") + onnx_model_path = "./temp/gelu_{}_{}.onnx".format(formula, "fp16" if float16 else "fp32") export_onnx(model, onnx_model_path, float16, hidden_size, device) if optimized: - optimized_onnx_path = './temp/gelu_{}_opt_{}.onnx'.format(formula, "fp16" if float16 else "fp32") + optimized_onnx_path = "./temp/gelu_{}_opt_{}.onnx".format(formula, "fp16" if float16 else "fp32") use_gpu = float16 and not fp32_gelu_op - optimize_onnx(onnx_model_path, - optimized_onnx_path, - Gelu.get_fused_op(formula), - use_gpu=use_gpu, - opt_level=2 if use_gpu else None) + optimize_onnx( + onnx_model_path, + optimized_onnx_path, + Gelu.get_fused_op(formula), + use_gpu=use_gpu, + opt_level=2 if use_gpu else None, + ) onnx_path = optimized_onnx_path else: onnx_path = onnx_model_path - num_failure = run_parity(model, - onnx_path, - batch_size, - hidden_size, - sequence_length, - float16, - device, - optimized, - test_cases, - verbose=False) + num_failure = run_parity( + model, + onnx_path, + batch_size, + hidden_size, + sequence_length, + float16, + device, + optimized, + test_cases, + verbose=False, + ) # clean up onnx file os.remove(onnx_model_path) @@ -134,66 +141,90 @@ def setUp(self): self.sequence_length = 2 self.hidden_size = 768 self.formula_to_test = [0, 1, 2, 3, 4, 5] - self.formula_must_pass = [0, 1, 3, 4, 5] # formula 2 cannot pass precision test. - - def run_test(self, - batch_size, - float16, - optimized, - hidden_size, - device, - formula, - enable_assert=True, - fp32_gelu_op=True): - if float16 and device.type == 'cpu': # CPU does not support FP16 + self.formula_must_pass = [ + 0, + 1, + 3, + 4, + 5, + ] # formula 2 cannot pass precision test. + + def run_test( + self, + batch_size, + float16, + optimized, + hidden_size, + device, + formula, + enable_assert=True, + fp32_gelu_op=True, + ): + if float16 and device.type == "cpu": # CPU does not support FP16 return - num_failure, test_name = run(batch_size, float16, optimized, hidden_size, device, self.test_cases, formula, - self.sequence_length, fp32_gelu_op) + num_failure, test_name = run( + batch_size, + float16, + optimized, + hidden_size, + device, + self.test_cases, + formula, + self.sequence_length, + fp32_gelu_op, + ) if enable_assert: self.assertTrue(num_failure == 0, "Failed: " + test_name) def run_one(self, optimized, device, hidden_size=768, formula=0): for batch_size in [4]: - self.run_test(batch_size, - float16=False, - optimized=optimized, - hidden_size=hidden_size, - device=device, - formula=formula, - enable_assert=formula in self.formula_must_pass) - - self.run_test(batch_size, - float16=True, - optimized=optimized, - hidden_size=hidden_size, - device=device, - formula=formula, - enable_assert=formula in self.formula_must_pass, - fp32_gelu_op=True) - - self.run_test(batch_size, - float16=True, - optimized=optimized, - hidden_size=hidden_size, - device=device, - formula=formula, - enable_assert=formula in self.formula_must_pass, - fp32_gelu_op=False) + self.run_test( + batch_size, + float16=False, + optimized=optimized, + hidden_size=hidden_size, + device=device, + formula=formula, + enable_assert=formula in self.formula_must_pass, + ) + + self.run_test( + batch_size, + float16=True, + optimized=optimized, + hidden_size=hidden_size, + device=device, + formula=formula, + enable_assert=formula in self.formula_must_pass, + fp32_gelu_op=True, + ) + + self.run_test( + batch_size, + float16=True, + optimized=optimized, + hidden_size=hidden_size, + device=device, + formula=formula, + enable_assert=formula in self.formula_must_pass, + fp32_gelu_op=False, + ) def test_cpu(self): - cpu = torch.device('cpu') + cpu = torch.device("cpu") for i in self.formula_to_test: self.run_one(self.optimized, cpu, hidden_size=self.hidden_size, formula=i) def test_cuda(self): if not torch.cuda.is_available(): import pytest - pytest.skip('test requires GPU and torch+cuda') + + pytest.skip("test requires GPU and torch+cuda") else: - gpu = torch.device('cuda') + gpu = torch.device("cuda") for i in self.formula_to_test: self.run_one(self.optimized, gpu, hidden_size=self.hidden_size, formula=i) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py index 16172792cbf38..c29cf969734c4 100644 --- a/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_huggingface_gpt_attention.py @@ -10,17 +10,18 @@ # license information. # ------------------------------------------------------------------------- +import os +import random import unittest + +import numpy +import onnx import pytest import torch -from torch import nn -import random from onnx import helper -import onnx -import numpy -import os +from parity_utilities import compare_outputs, create_ort_session, diff_outputs +from torch import nn from transformers.modeling_utils import Conv1D -from parity_utilities import diff_outputs, create_ort_session, compare_outputs DEBUG_OUTPUTS = ["qk", "norm_qk", "softmax", "attn_weights"] @@ -30,19 +31,23 @@ class MyGPT2Attention(nn.Module): This module is modifed from Gpt2Attention of huggingface transformers v4.9.1. Code related to crosss attention, c_proj, attn_dropout and head_mask etc are removed. """ - def __init__(self, - max_position_embeddings=1024, - hidden_size=768, - num_attention_heads=12, - use_cache=True, - debug=False, - fix_onnx_export=True): + + def __init__( + self, + max_position_embeddings=1024, + hidden_size=768, + num_attention_heads=12, + use_cache=True, + debug=False, + fix_onnx_export=True, + ): super().__init__() max_positions = max_position_embeddings self.register_buffer( "bias", - torch.tril(torch.ones((max_positions, max_positions), - dtype=torch.uint8)).view(1, 1, max_positions, max_positions), + torch.tril(torch.ones((max_positions, max_positions), dtype=torch.uint8)).view( + 1, 1, max_positions, max_positions + ), ) self.register_buffer("masked_bias", torch.tensor(-1e4)) self.embed_dim = hidden_size @@ -53,7 +58,7 @@ def __init__(self, self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim) # Use random bias instead of zeros for parity test. - self.c_attn.bias = nn.Parameter(torch.normal(0.0, 0.1, (3 * self.embed_dim, ))) + self.c_attn.bias = nn.Parameter(torch.normal(0.0, 0.1, (3 * self.embed_dim,))) self.use_cache = use_cache self.debug = debug @@ -69,15 +74,15 @@ def _attn(self, query, key, value, attention_mask=None): # This walkaround is not needed when attention fusion will be applied later since the subgraph will be replaced by an Attention node. if self.fix_onnx_export and torch.onnx.is_in_onnx_export(): if qk.dtype == torch.float16: - norm_qk = qk.to(torch.float32) * (1.0 / (float(value.size(-1))**0.5)) + norm_qk = qk.to(torch.float32) * (1.0 / (float(value.size(-1)) ** 0.5)) norm_qk = norm_qk.to(torch.float16) else: - norm_qk = qk * (1.0 / (float(value.size(-1))**0.5)) + norm_qk = qk * (1.0 / (float(value.size(-1)) ** 0.5)) else: - norm_qk = qk / (float(value.size(-1))**0.5) + norm_qk = qk / (float(value.size(-1)) ** 0.5) query_length, key_length = query.size(-2), key.size(-2) - causal_mask = self.bias[:, :, key_length - query_length:key_length, :key_length].bool() + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length].bool() attn_weights = torch.where(causal_mask, norm_qk, self.masked_bias.to(norm_qk.dtype)) if attention_mask is not None: @@ -99,7 +104,7 @@ def _split_heads(self, tensor, num_heads, attn_head_size): def _merge_heads(self, tensor, num_heads, attn_head_size): tensor = tensor.permute(0, 2, 1, 3).contiguous() - new_shape = tensor.size()[:-2] + (num_heads * attn_head_size, ) + new_shape = tensor.size()[:-2] + (num_heads * attn_head_size,) return tensor.view(new_shape) @staticmethod @@ -144,28 +149,35 @@ def forward(self, hidden_states, attention_mask=None, layer_past=None): outputs = (attn_output, present) if self.debug: if "qk" in DEBUG_OUTPUTS: - outputs += (qk, ) + outputs += (qk,) if "norm_qk" in DEBUG_OUTPUTS: - outputs += (norm_qk, ) + outputs += (norm_qk,) if "softmax" in DEBUG_OUTPUTS: - outputs += (softmax, ) + outputs += (softmax,) if "attn_weights" in DEBUG_OUTPUTS: - outputs += (attn_weights, ) + outputs += (attn_weights,) return outputs -def create_inputs(batch_size=1, - hidden_size=768, - num_attention_heads=12, - sequence_length=1, - past_sequence_length=5, - float16=False, - device=torch.device('cuda'), - padding_length=0): +def create_inputs( + batch_size=1, + hidden_size=768, + num_attention_heads=12, + sequence_length=1, + past_sequence_length=5, + float16=False, + device=torch.device("cuda"), + padding_length=0, +): float_type = torch.float16 if float16 else torch.float32 - past_shape = [batch_size, num_attention_heads, past_sequence_length, int(hidden_size / num_attention_heads)] + past_shape = [ + batch_size, + num_attention_heads, + past_sequence_length, + int(hidden_size / num_attention_heads), + ] past_key = torch.rand(past_shape, dtype=float_type, device=device) past_value = torch.rand(past_shape, dtype=float_type, device=device) layer_past = MyGPT2Attention.concat_key_value(past_key, past_value) @@ -181,8 +193,9 @@ def create_inputs(batch_size=1, padding_position = random.randint(0, total_sequence_length - 1) attention_mask[i, padding_position] = 0 - input_hidden_states = torch.normal(mean=0.0, std=0.1, - size=(batch_size, sequence_length, hidden_size)).to(float_type).to(device) + input_hidden_states = ( + torch.normal(mean=0.0, std=0.1, size=(batch_size, sequence_length, hidden_size)).to(float_type).to(device) + ) return input_hidden_states, attention_mask, layer_past @@ -195,88 +208,75 @@ def get_output_names(debug=False): def export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device): from pathlib import Path + Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - input_hidden_states, attention_mask, layer_past = create_inputs(float16=float16, - hidden_size=hidden_size, - num_attention_heads=num_attention_heads, - device=device) + input_hidden_states, attention_mask, layer_past = create_inputs( + float16=float16, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + device=device, + ) with torch.no_grad(): outputs = model(input_hidden_states, attention_mask=attention_mask, layer_past=layer_past) dynamic_axes = { - 'input_hidden_states': { - 0: 'batch_size', - 1: 'seq_len' - }, - "attn_output": { - 0: 'batch_size', - 1: 'seq_len' - }, - "past": { - 1: 'batch_size', - 3: 'past_seq_len' - }, - "present": { - 1: 'batch_size', - 3: 'total_seq_len' - }, - "attention_mask": { - 0: 'batch_size', - 1: 'total_seq_len' - } + "input_hidden_states": {0: "batch_size", 1: "seq_len"}, + "attn_output": {0: "batch_size", 1: "seq_len"}, + "past": {1: "batch_size", 3: "past_seq_len"}, + "present": {1: "batch_size", 3: "total_seq_len"}, + "attention_mask": {0: "batch_size", 1: "total_seq_len"}, } if debug: debug_dynamic_axes = { - "qk": { - 0: 'batch_size', - 1: 'seq_len' - }, - "norm_qk": { - 0: 'batch_size', - 1: 'seq_len' - }, - "softmax": { - 0: 'batch_size', - 1: 'seq_len' - }, - "attn_weights": { - 0: 'batch_size', - 1: 'seq_len' - } + "qk": {0: "batch_size", 1: "seq_len"}, + "norm_qk": {0: "batch_size", 1: "seq_len"}, + "softmax": {0: "batch_size", 1: "seq_len"}, + "attn_weights": {0: "batch_size", 1: "seq_len"}, } for name in DEBUG_OUTPUTS: dynamic_axes[name] = debug_dynamic_axes[name] - torch.onnx.export(model, - args=(input_hidden_states, { - 'attention_mask': attention_mask, - 'layer_past': layer_past - }), - f=onnx_model_path, - input_names=['input_hidden_states', 'attention_mask', 'past'], - output_names=get_output_names(debug), - dynamic_axes=dynamic_axes, - opset_version=11, - do_constant_folding=True) + torch.onnx.export( + model, + args=( + input_hidden_states, + {"attention_mask": attention_mask, "layer_past": layer_past}, + ), + f=onnx_model_path, + input_names=["input_hidden_states", "attention_mask", "past"], + output_names=get_output_names(debug), + dynamic_axes=dynamic_axes, + opset_version=11, + do_constant_folding=True, + ) print("exported:", onnx_model_path) def optimize_onnx(input_onnx_path, optimized_onnx_path, num_heads, debug): from onnxruntime.transformers.onnx_model import OnnxModel + m = onnx.load(input_onnx_path) onnx_model = OnnxModel(m) nodes_to_remove = onnx_model.nodes() output_names = ["attn_output", "present"] + DEBUG_OUTPUTS if debug else ["attn_output", "present"] - node_to_add = helper.make_node("Attention", - ["input_hidden_states", "c_attn.weight", "c_attn.bias", "attention_mask", "past"], - output_names, - "gpt2_attention", - num_heads=num_heads, - unidirectional=1, - domain="com.microsoft") + node_to_add = helper.make_node( + "Attention", + [ + "input_hidden_states", + "c_attn.weight", + "c_attn.bias", + "attention_mask", + "past", + ], + output_names, + "gpt2_attention", + num_heads=num_heads, + unidirectional=1, + domain="com.microsoft", + ) onnx_model.remove_nodes(nodes_to_remove) onnx_model.add_node(node_to_add) @@ -286,41 +286,54 @@ def optimize_onnx(input_onnx_path, optimized_onnx_path, num_heads, debug): def onnxruntime_inference(ort_session, input_hidden_states, attention_mask, past): ort_inputs = { - 'past': numpy.ascontiguousarray(past.cpu().numpy()), - 'attention_mask': numpy.ascontiguousarray(attention_mask.cpu().numpy()), - 'input_hidden_states': numpy.ascontiguousarray(input_hidden_states.cpu().numpy()), + "past": numpy.ascontiguousarray(past.cpu().numpy()), + "attention_mask": numpy.ascontiguousarray(attention_mask.cpu().numpy()), + "input_hidden_states": numpy.ascontiguousarray(input_hidden_states.cpu().numpy()), } ort_outputs = ort_session.run(None, ort_inputs) return ort_outputs -def verify_attention(model, - onnx_model_path, - batch_size, - hidden_size, - num_attention_heads, - sequence_length, - past_sequence_length, - float16, - device, - padding_length, - optimized, - test_cases=100): +def verify_attention( + model, + onnx_model_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases=100, +): print( f"optimized={optimized}, batch_size={batch_size}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}, sequence_length={sequence_length}, past_sequence_length={past_sequence_length}, float16={float16}, padding_length={padding_length}, device={device}" ) passed_cases = 0 max_diffs = [] - ort_session = create_ort_session(onnx_model_path, device.type == 'cuda') + ort_session = create_ort_session(onnx_model_path, device.type == "cuda") for i in range(test_cases): - input_hidden_states, attention_mask, layer_past = create_inputs(batch_size, hidden_size, num_attention_heads, - sequence_length, past_sequence_length, float16, - device, padding_length) + input_hidden_states, attention_mask, layer_past = create_inputs( + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + ) with torch.no_grad(): - torch_outputs = model(input_hidden_states, layer_past=layer_past, attention_mask=attention_mask) + torch_outputs = model( + input_hidden_states, + layer_past=layer_past, + attention_mask=attention_mask, + ) ort_outputs = onnxruntime_inference(ort_session, input_hidden_states, attention_mask, layer_past) @@ -341,8 +354,9 @@ def run(batch_size, float16, optimized, hidden_size, num_attention_heads, device test_name = f"batch_size={batch_size}, float16={float16}, optimized={optimized}, hidden_size={hidden_size}, num_attention_heads={num_attention_heads}" print(f"\nTesting ONNX parity: {test_name}") - debug = (not optimized - ) # or DEBUG_OUTPUTS==["softmax"] when you add an extra output for softmax result in Attention operator + debug = ( + not optimized + ) # or DEBUG_OUTPUTS==["softmax"] when you add an extra output for softmax result in Attention operator model = MyGPT2Attention(hidden_size=hidden_size, num_attention_heads=num_attention_heads, debug=debug) model.eval() model.to(device) @@ -350,11 +364,11 @@ def run(batch_size, float16, optimized, hidden_size, num_attention_heads, device model.half() # Do not re-use onnx file from previous test since weights of model are random. - onnx_model_path = './temp/gpt_attention_{}.onnx'.format("fp16" if float16 else "fp32") + onnx_model_path = "./temp/gpt_attention_{}.onnx".format("fp16" if float16 else "fp32") export_onnx(model, onnx_model_path, float16, hidden_size, num_attention_heads, debug, device) if optimized: - optimized_onnx_path = './temp/gpt_attention_opt_{}.onnx'.format("fp16" if float16 else "fp32") + optimized_onnx_path = "./temp/gpt_attention_opt_{}.onnx".format("fp16" if float16 else "fp32") optimize_onnx(onnx_model_path, optimized_onnx_path, num_attention_heads, debug) onnx_path = optimized_onnx_path else: @@ -365,22 +379,58 @@ def run(batch_size, float16, optimized, hidden_size, num_attention_heads, device past_sequence_length = 0 padding_length = 0 num_failure = 0 - num_failure += verify_attention(model, onnx_path, batch_size, hidden_size, num_attention_heads, sequence_length, - past_sequence_length, float16, device, padding_length, optimized, test_cases) + num_failure += verify_attention( + model, + onnx_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases, + ) # Test Case: with past state and padding last 2 words sequence_length = 3 past_sequence_length = 5 padding_length = 2 - num_failure += verify_attention(model, onnx_path, batch_size, hidden_size, num_attention_heads, sequence_length, - past_sequence_length, float16, device, padding_length, optimized, test_cases) + num_failure += verify_attention( + model, + onnx_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases, + ) # Test Case: random mask one word sequence_length = 1 past_sequence_length = 128 padding_length = -1 - num_failure += verify_attention(model, onnx_path, batch_size, hidden_size, num_attention_heads, sequence_length, - past_sequence_length, float16, device, padding_length, optimized, test_cases) + num_failure += verify_attention( + model, + onnx_path, + batch_size, + hidden_size, + num_attention_heads, + sequence_length, + past_sequence_length, + float16, + device, + padding_length, + optimized, + test_cases, + ) # clean up onnx file os.remove(onnx_model_path) @@ -396,63 +446,80 @@ def setUp(self): self.test_cases = 10 # Number of test cases per test run def run_test(self, batch_size, float16, optimized, hidden_size, num_attention_heads, device): - if float16 and device.type == 'cpu': # CPU does not support FP16 + if float16 and device.type == "cpu": # CPU does not support FP16 return - num_failure, test_name = run(batch_size, float16, optimized, hidden_size, num_attention_heads, device, - self.test_cases) + num_failure, test_name = run( + batch_size, + float16, + optimized, + hidden_size, + num_attention_heads, + device, + self.test_cases, + ) self.assertTrue(num_failure == 0, test_name) def run_small(self, optimized, device): for batch_size in [64]: - self.run_test(batch_size, - float16=False, - optimized=optimized, - hidden_size=768, - num_attention_heads=12, - device=device) - self.run_test(batch_size, - float16=True, - optimized=optimized, - hidden_size=768, - num_attention_heads=12, - device=device) + self.run_test( + batch_size, + float16=False, + optimized=optimized, + hidden_size=768, + num_attention_heads=12, + device=device, + ) + self.run_test( + batch_size, + float16=True, + optimized=optimized, + hidden_size=768, + num_attention_heads=12, + device=device, + ) def run_large(self, optimized, device): for batch_size in [2]: - self.run_test(batch_size, - float16=False, - optimized=optimized, - hidden_size=4096, - num_attention_heads=32, - device=device) - self.run_test(batch_size, - float16=True, - optimized=optimized, - hidden_size=4096, - num_attention_heads=32, - device=device) + self.run_test( + batch_size, + float16=False, + optimized=optimized, + hidden_size=4096, + num_attention_heads=32, + device=device, + ) + self.run_test( + batch_size, + float16=True, + optimized=optimized, + hidden_size=4096, + num_attention_heads=32, + device=device, + ) def test_cpu(self): - cpu = torch.device('cpu') + cpu = torch.device("cpu") self.run_small(self.optimized, cpu) def test_cuda(self): if not torch.cuda.is_available(): import pytest - pytest.skip('test requires GPU and torch+cuda') + + pytest.skip("test requires GPU and torch+cuda") else: - gpu = torch.device('cuda') + gpu = torch.device("cuda") self.run_small(self.optimized, gpu) @pytest.mark.slow def test_large_cuda(self): if not torch.cuda.is_available(): import pytest - pytest.skip('test requires GPU and torch+cuda') + + pytest.skip("test requires GPU and torch+cuda") else: - gpu = torch.device('cuda') + gpu = torch.device("cuda") self.run_large(self.optimized, gpu) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_layernorm.py b/onnxruntime/test/python/transformers/test_parity_layernorm.py index 0245cb45abec2..01122b4830bfa 100644 --- a/onnxruntime/test/python/transformers/test_parity_layernorm.py +++ b/onnxruntime/test/python/transformers/test_parity_layernorm.py @@ -4,12 +4,13 @@ # license information. # -------------------------------------------------------------------------- -import unittest -import torch -from torch import nn import os +import unittest + import onnx +import torch from parity_utilities import * +from torch import nn if find_transformers_source(): from onnx_model import OnnxModel @@ -50,11 +51,11 @@ def forward(self, x): else: y = self.my_layer_norm(x) - return (y, ) + return (y,) def get_weight(onnx_model): - last_mul_node = onnx_model.get_nodes_by_op_type('Mul')[-1] + last_mul_node = onnx_model.get_nodes_by_op_type("Mul")[-1] i, value = onnx_model.get_constant_input(last_mul_node) assert value is not None weight_name = last_mul_node.input[i] @@ -62,7 +63,7 @@ def get_weight(onnx_model): def get_bias(onnx_model): - last_add_node = onnx_model.get_nodes_by_op_type('Add')[-1] + last_add_node = onnx_model.get_nodes_by_op_type("Add")[-1] i, value = onnx_model.get_constant_input(last_add_node) assert value is not None bias_name = last_add_node.input[i] @@ -79,11 +80,14 @@ def optimize_fp16_onnx_with_cast(input_onnx_path, optimized_onnx_path, epsilon): onnx.helper.make_node("Cast", ["input"], ["fp32_input"], "cast_input", to=1), onnx.helper.make_node("Cast", [weight_name], ["fp32_layer_norm.weight"], "cast_weight", to=1), onnx.helper.make_node("Cast", [bias_name], ["fp32_layer_norm.bias"], "cast_bias", to=1), - onnx.helper.make_node("LayerNormalization", ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"], - ["fp32_output"], - "layer_norm", - epsilon=epsilon), # use fp32 epsilon - onnx.helper.make_node("Cast", ["fp32_output"], ["output"], "cast_output", to=10) + onnx.helper.make_node( + "LayerNormalization", + ["fp32_input", "fp32_layer_norm.weight", "fp32_layer_norm.bias"], + ["fp32_output"], + "layer_norm", + epsilon=epsilon, + ), # use fp32 epsilon + onnx.helper.make_node("Cast", ["fp32_output"], ["output"], "cast_output", to=10), ] onnx_model.remove_nodes(nodes_to_remove) @@ -101,9 +105,13 @@ def optimize_fp16_onnx_no_cast(input_onnx_path, optimized_onnx_path, epsilon): nodes_to_remove = [n for n in onnx_model.nodes() if n.output[0] != weight_name and n.output[0] != bias_name] nodes_to_remove = onnx_model.nodes() - node_to_add = onnx.helper.make_node("LayerNormalization", ["input", weight_name, bias_name], ["output"], - "layer_norm", - epsilon=epsilon) + node_to_add = onnx.helper.make_node( + "LayerNormalization", + ["input", weight_name, bias_name], + ["output"], + "layer_norm", + epsilon=epsilon, + ) onnx_model.remove_nodes(nodes_to_remove) onnx_model.add_node(node_to_add) @@ -116,18 +124,20 @@ def get_output_names(): return outputs -def run(batch_size, - float16, - optimized, - hidden_size, - device, - test_cases, - sequence_length=2, - epsilon=0.00001, - cast_fp16=True, - cast_onnx_only=False, - formula=0, - verbose=False): +def run( + batch_size, + float16, + optimized, + hidden_size, + device, + test_cases, + sequence_length=2, + epsilon=0.00001, + cast_fp16=True, + cast_onnx_only=False, + formula=0, + verbose=False, +): test_name = f"device={device}, float16={float16}, optimized={optimized}, batch_size={batch_size}, sequence_length={sequence_length}, hidden_size={hidden_size}, epsilon={epsilon}, cast_fp16={cast_fp16}, cast_onnx_only={cast_onnx_only}, formula={formula}" print(f"\nTesting: {test_name}") @@ -139,13 +149,17 @@ def run(batch_size, model.half() # Do not re-use onnx file from previous test since weights of model are random. - onnx_model_path = './temp/layer_norm_{}_formula{}.onnx'.format("fp16" if float16 else "fp32", formula) + onnx_model_path = "./temp/layer_norm_{}_formula{}.onnx".format("fp16" if float16 else "fp32", formula) export_onnx(model, onnx_model_path, float16, hidden_size, device) if optimized: - optimized_onnx_path = './temp/layer_norm_{}_formula{}_opt.onnx'.format("fp16" if float16 else "fp32", formula) + optimized_onnx_path = "./temp/layer_norm_{}_formula{}_opt.onnx".format("fp16" if float16 else "fp32", formula) if (not float16) or cast_fp16: - optimize_onnx(onnx_model_path, optimized_onnx_path, expected_op=LayerNorm.get_fused_op()) + optimize_onnx( + onnx_model_path, + optimized_onnx_path, + expected_op=LayerNorm.get_fused_op(), + ) else: if cast_onnx_only: optimize_fp16_onnx_with_cast(onnx_model_path, optimized_onnx_path, epsilon=epsilon) @@ -156,16 +170,18 @@ def run(batch_size, else: onnx_path = onnx_model_path - num_failure = run_parity(model, - onnx_path, - batch_size, - hidden_size, - sequence_length, - float16, - device, - optimized, - test_cases, - verbose=verbose) + num_failure = run_parity( + model, + onnx_path, + batch_size, + hidden_size, + sequence_length, + float16, + device, + optimized, + test_cases, + verbose=verbose, + ) # clean up onnx file os.remove(onnx_model_path) @@ -183,46 +199,52 @@ def setUp(self): self.hidden_size = 768 self.verbose = False - def run_test(self, - batch_size, - float16, - optimized, - hidden_size, - device, - cast_fp16=True, - cast_onnx_only=False, - formula=0, - epsilon=0.00001, - enable_assert=True): - if float16 and device.type == 'cpu': # CPU does not support FP16 + def run_test( + self, + batch_size, + float16, + optimized, + hidden_size, + device, + cast_fp16=True, + cast_onnx_only=False, + formula=0, + epsilon=0.00001, + enable_assert=True, + ): + if float16 and device.type == "cpu": # CPU does not support FP16 return - num_failure, test_name = run(batch_size, - float16, - optimized, - hidden_size, - device, - self.test_cases, - self.sequence_length, - epsilon, - cast_fp16, - cast_onnx_only, - formula, - verbose=self.verbose) + num_failure, test_name = run( + batch_size, + float16, + optimized, + hidden_size, + device, + self.test_cases, + self.sequence_length, + epsilon, + cast_fp16, + cast_onnx_only, + formula, + verbose=self.verbose, + ) if enable_assert: self.assertTrue(num_failure == 0, "Failed: " + test_name) def run_one(self, optimized, device, hidden_size=768, run_extra_tests=False): for batch_size in [4]: for formula in [0, 1]: - for epsilon in [1e-5]: #[1e-5, 1e-12] - self.run_test(batch_size, - float16=False, - optimized=optimized, - hidden_size=hidden_size, - device=device, - formula=formula, - epsilon=epsilon) + for epsilon in [1e-5]: # [1e-5, 1e-12] + self.run_test( + batch_size, + float16=False, + optimized=optimized, + hidden_size=hidden_size, + device=device, + formula=formula, + epsilon=epsilon, + ) self.run_test( batch_size, @@ -234,13 +256,13 @@ def run_one(self, optimized, device, hidden_size=768, run_extra_tests=False): cast_onnx_only=False, formula=formula, epsilon=epsilon, - enable_assert=False # This setting has small chance to exceed tollerance threshold 0.001 + enable_assert=False, # This setting has small chance to exceed tollerance threshold 0.001 ) if not run_extra_tests: continue - if device.type != 'cuda' or formula != 1: + if device.type != "cuda" or formula != 1: self.run_test( batch_size, float16=True, @@ -251,7 +273,7 @@ def run_one(self, optimized, device, hidden_size=768, run_extra_tests=False): cast_onnx_only=False, formula=formula, epsilon=epsilon, - enable_assert=False # This setting cannot pass tollerance threshold + enable_assert=False, # This setting cannot pass tollerance threshold ) self.run_test( @@ -264,21 +286,22 @@ def run_one(self, optimized, device, hidden_size=768, run_extra_tests=False): cast_onnx_only=True, formula=formula, epsilon=epsilon, - enable_assert=False # This setting cannot pass tollerance threshold + enable_assert=False, # This setting cannot pass tollerance threshold ) def test_cpu(self): - cpu = torch.device('cpu') + cpu = torch.device("cpu") self.run_one(self.optimized, cpu, hidden_size=self.hidden_size) def test_cuda(self): if not torch.cuda.is_available(): import pytest - pytest.skip('test requires GPU and torch+cuda') + + pytest.skip("test requires GPU and torch+cuda") else: - gpu = torch.device('cuda') + gpu = torch.device("cuda") self.run_one(self.optimized, gpu, hidden_size=self.hidden_size, run_extra_tests=True) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/python/transformers/test_profiler.py b/onnxruntime/test/python/transformers/test_profiler.py index 35e9b85d4c729..c8397bc3f7bb8 100644 --- a/onnxruntime/test/python/transformers/test_profiler.py +++ b/onnxruntime/test/python/transformers/test_profiler.py @@ -8,38 +8,41 @@ # For live logging, use the command: pytest -o log_cli=true --log-cli-level=DEBUG -import unittest import os -import pytest +import unittest +import pytest from test_optimizer import _get_test_model_path class TestBertProfiler(unittest.TestCase): def setUp(self): from onnxruntime import get_available_providers - self.test_cuda = 'CUDAExecutionProvider' in get_available_providers() + + self.test_cuda = "CUDAExecutionProvider" in get_available_providers() def run_profile(self, arguments: str): from onnxruntime.transformers.profiler import parse_arguments, run + args = parse_arguments(arguments.split()) results = run(args) self.assertTrue(len(results) > 1) @pytest.mark.slow def test_profiler_gpu(self): - input_model_path = _get_test_model_path('bert_keras_squad') + input_model_path = _get_test_model_path("bert_keras_squad") if self.test_cuda: - self.run_profile(f'--model {input_model_path} --batch_size 1 --sequence_length 7 --use_gpu') + self.run_profile(f"--model {input_model_path} --batch_size 1 --sequence_length 7 --use_gpu") @pytest.mark.slow def test_profiler_cpu(self): - input_model_path = _get_test_model_path('bert_keras_squad') - self.run_profile(f'--model {input_model_path} --batch_size 1 --sequence_length 7 --dummy_inputs default') + input_model_path = _get_test_model_path("bert_keras_squad") + self.run_profile(f"--model {input_model_path} --batch_size 1 --sequence_length 7 --dummy_inputs default") -if __name__ == '__main__': +if __name__ == "__main__": import sys - sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + + sys.path.append(os.path.join(os.path.dirname(__file__), "..")) unittest.main() diff --git a/onnxruntime/test/python/transformers/test_shape_infer_helper.py b/onnxruntime/test/python/transformers/test_shape_infer_helper.py index 1a9f22d1477cd..6429b4cee5843 100644 --- a/onnxruntime/test/python/transformers/test_shape_infer_helper.py +++ b/onnxruntime/test/python/transformers/test_shape_infer_helper.py @@ -1,16 +1,17 @@ import unittest -import pytest +import pytest from parity_utilities import find_transformers_source + if find_transformers_source(): - from onnx_exporter import export_onnx_model_from_pt + from benchmark_helper import OptimizerInfo, Precision from huggingface_models import MODELS - from benchmark_helper import Precision, OptimizerInfo + from onnx_exporter import export_onnx_model_from_pt from shape_infer_helper import SymbolicShapeInferenceHelper else: - from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt + from onnxruntime.transformers.benchmark_helper import OptimizerInfo, Precision from onnxruntime.transformers.huggingface_models import MODELS - from onnxruntime.transformers.benchmark_helper import Precision, OptimizerInfo + from onnxruntime.transformers.onnx_exporter import export_onnx_model_from_pt from onnxruntime.transformers.shape_infer_helper import SymbolicShapeInferenceHelper @@ -19,15 +20,31 @@ def _load_onnx(self, model_name): input_names = MODELS[model_name][0] base_path = "../onnx_models/" import torch + with torch.no_grad(): - export_onnx_model_from_pt(model_name, MODELS[model_name][1], MODELS[model_name][2], MODELS[model_name][3], - None, '../cache_models', base_path, input_names[:1], False, Precision.FLOAT32, - OptimizerInfo.BYSCRIPT, True, True, False, {}) - model_path = base_path + model_name.replace('-', '_') + "_1.onnx" + export_onnx_model_from_pt( + model_name, + MODELS[model_name][1], + MODELS[model_name][2], + MODELS[model_name][3], + None, + "../cache_models", + base_path, + input_names[:1], + False, + Precision.FLOAT32, + OptimizerInfo.BYSCRIPT, + True, + True, + False, + {}, + ) + model_path = base_path + model_name.replace("-", "_") + "_1.onnx" import onnx + return onnx.load_model(model_path) - #TODO: use a static lightweight model for test + # TODO: use a static lightweight model for test @pytest.mark.slow def test_bert_shape_infer_helper(self): model = self._load_onnx("bert-base-cased") @@ -36,9 +53,15 @@ def test_bert_shape_infer_helper(self): self.assertEqual(shape_infer_helper.get_edge_shape("802"), []) self.assertEqual(shape_infer_helper.get_edge_shape("804"), [4, 16, 3072]) self.assertEqual(shape_infer_helper.get_edge_shape("1748"), [1]) - self.assertEqual(shape_infer_helper.get_edge_shape("encoder.layer.4.attention.output.LayerNorm.weight"), [768]) + self.assertEqual( + shape_infer_helper.get_edge_shape("encoder.layer.4.attention.output.LayerNorm.weight"), + [768], + ) self.assertEqual(shape_infer_helper.get_edge_shape("817"), [4, 16, 1]) - self.assertEqual(shape_infer_helper.get_edge_shape("encoder.layer.4.intermediate.dense.bias"), [3072]) + self.assertEqual( + shape_infer_helper.get_edge_shape("encoder.layer.4.intermediate.dense.bias"), + [3072], + ) self.assertEqual(shape_infer_helper.get_edge_shape("880"), [4, 12, 16, 16]) self.assertEqual(shape_infer_helper.compare_shape("329", "253"), False) @@ -47,5 +70,5 @@ def test_bert_shape_infer_helper(self): self.assertEqual(shape_infer_helper.compare_shape("447", "853"), False) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/onnxruntime/test/testdata/CNTK/gen.py b/onnxruntime/test/testdata/CNTK/gen.py index 246f45c81051b..db9022d3d50f8 100644 --- a/onnxruntime/test/testdata/CNTK/gen.py +++ b/onnxruntime/test/testdata/CNTK/gen.py @@ -1,23 +1,27 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import os + import cntk as C import numpy as np import onnx -import os from onnx import numpy_helper -model_file = 'model.onnx' -data_dir = 'test_data_set_0' +model_file = "model.onnx" +data_dir = "test_data_set_0" def SaveTensorProto(file_path, variable, data, name): # ONNX input shape always has sequence axis as the first dimension, if sequence axis exists if len(variable.dynamic_axes) == 2: - data = data.transpose(( - 1, - 0, - ) + tuple(range(2, len(data.shape)))) + data = data.transpose( + ( + 1, + 0, + ) + + tuple(range(2, len(data.shape))) + ) tp = numpy_helper.from_array(data, name if name else variable.uid) onnx.save_tensor(tp, file_path) @@ -26,8 +30,12 @@ def SaveData(test_data_dir, prefix, variables, data_list, name_replacements=None if isinstance(data_list, np.ndarray): data_list = [data_list] for (i, d), v in zip(enumerate(data_list), variables): - SaveTensorProto(os.path.join(test_data_dir, '{0}_{1}.pb'.format(prefix, i)), v, d, - name_replacements[v.uid] if name_replacements else None) + SaveTensorProto( + os.path.join(test_data_dir, "{0}_{1}.pb".format(prefix, i)), + v, + d, + name_replacements[v.uid] if name_replacements else None, + ) def Save(dir, func, feed, outputs): @@ -56,31 +64,41 @@ def Save(dir, func, feed, outputs): if not os.path.exists(test_data_dir): os.makedirs(test_data_dir) - SaveData(test_data_dir, 'input', func.arguments, [feed[var] for var in func.arguments], cntk_to_actual_names) - SaveData(test_data_dir, 'output', func.outputs, [outputs[var] for var in func.outputs]) + SaveData( + test_data_dir, + "input", + func.arguments, + [feed[var] for var in func.arguments], + cntk_to_actual_names, + ) + SaveData(test_data_dir, "output", func.outputs, [outputs[var] for var in func.outputs]) def GenSimple(): - x = C.input_variable(( - 1, - 3, - )) # TODO: fix CNTK exporter bug with shape (3,) + x = C.input_variable( + ( + 1, + 3, + ) + ) # TODO: fix CNTK exporter bug with shape (3,) y = C.layers.Embedding(2)(x) + C.parameter((-1,)) data_x = np.random.rand(1, *x.shape).astype(np.float32) data_y = y.eval(data_x) - Save('test_simple', y, data_x, data_y) + Save("test_simple", y, data_x, data_y) def GenSharedWeights(): - x = C.input_variable(( - 1, - 3, - )) + x = C.input_variable( + ( + 1, + 3, + ) + ) y = C.layers.Embedding(2)(x) y = y + y.parameters[0] data_x = np.random.rand(1, *x.shape).astype(np.float32) data_y = y.eval(data_x) - Save('test_shared_weights', y, data_x, data_y) + Save("test_shared_weights", y, data_x, data_y) def GenSimpleMNIST(): @@ -93,28 +111,36 @@ def GenSimpleMNIST(): scaled_input = C.element_times(C.constant(0.00390625, shape=(input_dim,)), feature) - z = C.layers.Sequential([ - C.layers.For(range(num_hidden_layers), lambda i: C.layers.Dense(hidden_layers_dim, activation=C.relu)), - C.layers.Dense(num_output_classes) - ])(scaled_input) + z = C.layers.Sequential( + [ + C.layers.For( + range(num_hidden_layers), + lambda i: C.layers.Dense(hidden_layers_dim, activation=C.relu), + ), + C.layers.Dense(num_output_classes), + ] + )(scaled_input) model = C.softmax(z) data_feature = np.random.rand(1, *feature.shape).astype(np.float32) data_output = model.eval(data_feature) - Save('test_simpleMNIST', model, data_feature, data_output) + Save("test_simpleMNIST", model, data_feature, data_output) def GenMatMul_1k(): - feature = C.input_variable(( - 1024, - 1024, - ), np.float32) + feature = C.input_variable( + ( + 1024, + 1024, + ), + np.float32, + ) model = C.times(feature, C.parameter((1024, 1024), init=C.glorot_uniform())) data_feature = np.random.rand(1, *feature.shape).astype(np.float32) data_output = model.eval(data_feature) - Save('test_MatMul_1k', model, data_feature, data_output) + Save("test_MatMul_1k", model, data_feature, data_output) def LSTM(cell_dim, use_scan=True): @@ -145,11 +171,11 @@ def GenLSTMx4(use_scan): lstm4 = C.layers.Recurrence(LSTM(512, use_scan))(lstm3) model = lstm4 - postfix = 'Scan' if use_scan else 'LSTM' + postfix = "Scan" if use_scan else "LSTM" data_feature = np.random.rand(1, 64, 128).astype(np.float32) data_output = np.asarray(model.eval(data_feature)) - Save('test_LSTMx4_' + postfix, model, data_feature, data_output) + Save("test_LSTMx4_" + postfix, model, data_feature, data_output) def GenScan(): @@ -160,13 +186,13 @@ def GenScan(): data_feature = np.random.rand(2, 5, 3).astype(np.float32) data_output = np.asarray(model.eval(data_feature)) - Save('test_Scan', model, data_feature, data_output) + Save("test_Scan", model, data_feature, data_output) # Currently CNTK only outputs batch == 1, do some editing - in_mp = onnx.load('test_Scan/model.onnx') + in_mp = onnx.load("test_Scan/model.onnx") out_mp = onnx.ModelProto() out_mp.CopyFrom(in_mp) - out_mp.graph.ClearField('initializer') + out_mp.graph.ClearField("initializer") # change LSTM init_c/h into inputs to support truncated sequence # as batch dimension is unknown on those data when building model @@ -179,7 +205,7 @@ def GenScan(): shape[0] = 2 aa = np.zeros(shape, dtype=np.float32) tp = numpy_helper.from_array(aa, i.name) - with open('test_Scan/test_data_set_0/input_' + str(num_inputs) + '.pb', 'wb') as ff: + with open("test_Scan/test_data_set_0/input_" + str(num_inputs) + ".pb", "wb") as ff: ff.write(tp.SerializeToString()) num_inputs = num_inputs + 1 else: @@ -187,16 +213,16 @@ def GenScan(): for vi in list(out_mp.graph.input) + list(out_mp.graph.output) + list(out_mp.graph.value_info): dim = vi.type.tensor_type.shape.dim - dim[len(dim) - 2].dim_param = 'batch' + dim[len(dim) - 2].dim_param = "batch" for n in out_mp.graph.node: - if n.op_type == 'Scan': - body = [attr for attr in n.attribute if attr.name == 'body'][0] + if n.op_type == "Scan": + body = [attr for attr in n.attribute if attr.name == "body"][0] for vi in list(body.g.input) + list(body.g.output) + list(body.g.value_info): dim = vi.type.tensor_type.shape.dim - dim[0].dim_param = 'batch' + dim[0].dim_param = "batch" - onnx.save(out_mp, 'test_Scan/model.onnx', 'wb') + onnx.save(out_mp, "test_Scan/model.onnx", "wb") def GenSimpleScan(): @@ -206,7 +232,7 @@ def GenSimpleScan(): model = C.sequence.reduce_sum(scan) data_feature = np.random.rand(1, 64, 128).astype(np.float32) data_output = np.asarray(model.eval(data_feature), dtype=np.float32) - Save('test_SimpleScan', model, data_feature, data_output) + Save("test_SimpleScan", model, data_feature, data_output) def GenGRU(): @@ -216,21 +242,31 @@ def GenGRU(): model = C.splice(gru_fw, gru_bw, axis=0) data_feature = np.random.rand(1, 16, 64).astype(np.float32) data_output = np.asarray(model.eval(data_feature)) - Save('test_GRU', model, data_feature, data_output) + Save("test_GRU", model, data_feature, data_output) def GenRNN(): feature = C.sequence.input_variable((64,), np.float32) - model = C.optimized_rnnstack(feature, C.parameter(( - C.InferredDimension, - 64, - ), init=C.glorot_uniform()), 128, 2, True, 'rnnReLU') + model = C.optimized_rnnstack( + feature, + C.parameter( + ( + C.InferredDimension, + 64, + ), + init=C.glorot_uniform(), + ), + 128, + 2, + True, + "rnnReLU", + ) data_feature = np.random.rand(1, 16, 64).astype(np.float32) data_output = np.asarray(model.eval(data_feature)) - Save('test_RNN', model, data_feature, data_output) + Save("test_RNN", model, data_feature, data_output) -if __name__ == '__main__': +if __name__ == "__main__": np.random.seed(0) GenSimple() GenSharedWeights() diff --git a/onnxruntime/test/testdata/capi_symbolic_dims.py b/onnxruntime/test/testdata/capi_symbolic_dims.py index 62f2bcbe422a6..69024717cf1ba 100644 --- a/onnxruntime/test/testdata/capi_symbolic_dims.py +++ b/onnxruntime/test/testdata/capi_symbolic_dims.py @@ -1,26 +1,25 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import shape_inference +from onnx import TensorProto, helper, shape_inference # create output with rank but unnamed symbolic dim -output = helper.make_tensor_value_info('C', TensorProto.FLOAT, [1]) +output = helper.make_tensor_value_info("C", TensorProto.FLOAT, [1]) output.type.tensor_type.shape.Clear() dim = output.type.tensor_type.shape.dim.add() print(dim) graph_def = helper.make_graph( nodes=[ - helper.make_node(op_type="Reshape", inputs=['A', 'B'], outputs=['C'], name='reshape'), + helper.make_node(op_type="Reshape", inputs=["A", "B"], outputs=["C"], name="reshape"), ], - name='test-model', + name="test-model", inputs=[ # create inputs with symbolic dims - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['n', 2]), - helper.make_tensor_value_info("B", TensorProto.INT64, ['m']), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["n", 2]), + helper.make_tensor_value_info("B", TensorProto.INT64, ["m"]), ], outputs=[output], - initializer=[]) + initializer=[], +) model = helper.make_model(graph_def, opset_imports=[helper.make_operatorsetid("", 11)]) onnx.checker.check_model(model) @@ -31,6 +30,7 @@ onnx.save_model(model, "capi_symbolic_dims.onnx") import onnxruntime as rt + sess = rt.InferenceSession("capi_symbolic_dims.onnx") print([i.shape for i in sess.get_inputs()]) print([i.shape for i in sess.get_outputs()]) diff --git a/onnxruntime/test/testdata/coreml_argmax_cast_test.py b/onnxruntime/test/testdata/coreml_argmax_cast_test.py index 9656609becf0b..37e99d9b88ef6 100644 --- a/onnxruntime/test/testdata/coreml_argmax_cast_test.py +++ b/onnxruntime/test/testdata/coreml_argmax_cast_test.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # CoreML EP currently handles a special case for supporting ArgMax op # Please see in /onnxruntime/core/providers/coreml/builders/impl/argmax_op_builder.cc and @@ -10,21 +9,19 @@ def GenerateModel(model_name): nodes = [ - helper.make_node("ArgMax", ["X"], [ - "argmax_output_int64"], "argmax", axis=1, keepdims=1), - helper.make_node("Cast", ["argmax_output_int64"], [ - "Y"], "cast", to=6), # cast to int32 type + helper.make_node("ArgMax", ["X"], ["argmax_output_int64"], "argmax", axis=1, keepdims=1), + helper.make_node("Cast", ["argmax_output_int64"], ["Y"], "cast", to=6), # cast to int32 type ] graph = helper.make_graph( nodes, "CoreML_ArgMax_Cast_Test", [ # input - helper.make_tensor_value_info('X', TensorProto.FLOAT, [3, 2, 2]), + helper.make_tensor_value_info("X", TensorProto.FLOAT, [3, 2, 2]), ], [ # output - helper.make_tensor_value_info('Y', TensorProto.INT32, [3, 1, 2]), - ] + helper.make_tensor_value_info("Y", TensorProto.INT32, [3, 1, 2]), + ], ) model = helper.make_model(graph) @@ -32,4 +29,4 @@ def GenerateModel(model_name): if __name__ == "__main__": - GenerateModel('coreml_argmax_cast_test.onnx') + GenerateModel("coreml_argmax_cast_test.onnx") diff --git a/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py b/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py index 79c999b648fb9..d681723810e65 100644 --- a/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py +++ b/onnxruntime/test/testdata/dynamic_quantize_matmul_test.py @@ -1,52 +1,70 @@ -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum -def GenerateModel(model_name, sign, b_zp = True, bias = False): - nodes = [ # DynamicQuantizeMatMul subgraph - helper.make_node("DynamicQuantizeLinear", ["A"], ["a_quantized", "a_scale", "a_zp"], "DynamicQuantizeLinear"), +import onnx +from onnx import TensorProto, helper + +def GenerateModel(model_name, sign, b_zp=True, bias=False): + nodes = [ # DynamicQuantizeMatMul subgraph + helper.make_node( + "DynamicQuantizeLinear", + ["A"], + ["a_quantized", "a_scale", "a_zp"], + "DynamicQuantizeLinear", + ), helper.make_node( "MatMulInteger", ["a_quantized", "B", "a_zp", "b_zero_point"] if b_zp else ["a_quantized", "B", "a_zp"], ["matmul_output_int32"], - "MatMulInteger"), - + "MatMulInteger", + ), helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"), - helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1), - - helper.make_node("Mul", ["matmul_output_float", "multiplier"], ["mul_bottom_output" if bias else "Y"], "mul_bottom"), + helper.make_node( + "Mul", + ["matmul_output_float", "multiplier"], + ["mul_bottom_output" if bias else "Y"], + "mul_bottom", + ), ] inputs = [ - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('B', TensorProto.INT8 if sign else TensorProto.UINT8, ['K', 'N']), - helper.make_tensor_value_info('b_scale', TensorProto.FLOAT, ['C']), - ] + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.INT8 if sign else TensorProto.UINT8, ["K", "N"]), + helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, ["C"]), + ] if b_zp: - inputs.extend([helper.make_tensor_value_info('b_zero_point', TensorProto.INT8 if sign else TensorProto.UINT8, ['C'])]) + inputs.extend( + [ + helper.make_tensor_value_info( + "b_zero_point", + TensorProto.INT8 if sign else TensorProto.UINT8, + ["C"], + ) + ] + ) if bias: nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend([helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['N'])]) + inputs.extend([helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["N"])]) graph = helper.make_graph( nodes, - "DynamicQuantizeMatMul_fusion", #name + "DynamicQuantizeMatMul_fusion", # name inputs, [ # outputs - helper.make_tensor_value_info('Y', TensorProto.FLOAT, ['M', 'N']), - ]) + helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"]), + ], + ) model = helper.make_model(graph) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('dynamic_quantize_matmul_int8.onnx', True) - GenerateModel('dynamic_quantize_matmul_uint8.onnx', False) - GenerateModel('dynamic_quantize_matmul_int8_bias.onnx', True, False, True) - GenerateModel('dynamic_quantize_matmul_uint8_bias.onnx', False, False, True) \ No newline at end of file + GenerateModel("dynamic_quantize_matmul_int8.onnx", True) + GenerateModel("dynamic_quantize_matmul_uint8.onnx", False) + GenerateModel("dynamic_quantize_matmul_int8_bias.onnx", True, False, True) + GenerateModel("dynamic_quantize_matmul_uint8_bias.onnx", False, False, True) diff --git a/onnxruntime/test/testdata/ep_dynamic_graph_input_test.py b/onnxruntime/test/testdata/ep_dynamic_graph_input_test.py index d04f8a8884d3d..b6ef5f78330f1 100644 --- a/onnxruntime/test/testdata/ep_dynamic_graph_input_test.py +++ b/onnxruntime/test/testdata/ep_dynamic_graph_input_test.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # Since NNAPI EP does not support dynamic shape input and we now switch from the approach of immediately rejecting @@ -9,39 +8,39 @@ # Please see BaseOpBuilder::HasSupportedInputs in /onnxruntime/core/providers/nnapi/nnapi_builtin/builders/op_support_checker.cc def GenerateModel(model_name): nodes = [ - helper.make_node("Resize", ["X", "", "", "Resize_1_sizes"], [ - "Resize_1_output"], "resize_1", mode="cubic"), helper.make_node( - "Add", ["Resize_1_output", "Add_2_input"], ["Y"], "add"), + "Resize", + ["X", "", "", "Resize_1_sizes"], + ["Resize_1_output"], + "resize_1", + mode="cubic", + ), + helper.make_node("Add", ["Resize_1_output", "Add_2_input"], ["Y"], "add"), ] initializers = [ - helper.make_tensor('Resize_1_sizes', TensorProto.INT64, [ - 4], [1, 1, 3, 3]), - helper.make_tensor('Add_2_input', TensorProto.FLOAT, [1, 1, 3, 3], [ - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]) + helper.make_tensor("Resize_1_sizes", TensorProto.INT64, [4], [1, 1, 3, 3]), + helper.make_tensor( + "Add_2_input", + TensorProto.FLOAT, + [1, 1, 3, 3], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], + ), ] inputs = [ - helper.make_tensor_value_info( - 'X', TensorProto.FLOAT, ["1", "1", "N", "N"]), # used dim_param here + helper.make_tensor_value_info("X", TensorProto.FLOAT, ["1", "1", "N", "N"]), # used dim_param here ] outputs = [ - helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 3, 3]), + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 3, 3]), ] - graph = helper.make_graph( - nodes, - "EP_Dynamic_Graph_Input_Test", - inputs, - outputs, - initializers - ) + graph = helper.make_graph(nodes, "EP_Dynamic_Graph_Input_Test", inputs, outputs, initializers) model = helper.make_model(graph) onnx.save(model, model_name) if __name__ == "__main__": - GenerateModel('ep_dynamic_graph_input_test.onnx') + GenerateModel("ep_dynamic_graph_input_test.onnx") diff --git a/onnxruntime/test/testdata/ep_partitioning_tests.py b/onnxruntime/test/testdata/ep_partitioning_tests.py index fe68f6cc5e303..a85b9bda6c187 100644 --- a/onnxruntime/test/testdata/ep_partitioning_tests.py +++ b/onnxruntime/test/testdata/ep_partitioning_tests.py @@ -1,7 +1,6 @@ import numpy as np import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # Create graph with Add and Sub nodes that can be used to test partitioning when one of the operators @@ -32,29 +31,26 @@ def create_model_1(): # So if this model is loaded in a partitioning test, there should only be one partition running on the EP regardless # of whether Add or Sub is supported by it. graph = helper.make_graph( - nodes= - [ - helper.make_node("Add", ['input0', 'input1'], ['1'], "A1"), - helper.make_node("Sub", ['input0', 'input1'], ["2"], "S1"), - helper.make_node("Add", ['2', 'input1'], ['3_out'], "A2"), - helper.make_node("Sub", ['2', 'input1'], ['4'], "S2"), - helper.make_node("Add", ['1', '4'], ['5_out'], "A3"), - helper.make_node("Add", ['input1', 'input2'], ['6_out'], "A4"), + nodes=[ + helper.make_node("Add", ["input0", "input1"], ["1"], "A1"), + helper.make_node("Sub", ["input0", "input1"], ["2"], "S1"), + helper.make_node("Add", ["2", "input1"], ["3_out"], "A2"), + helper.make_node("Sub", ["2", "input1"], ["4"], "S2"), + helper.make_node("Add", ["1", "4"], ["5_out"], "A3"), + helper.make_node("Add", ["input1", "input2"], ["6_out"], "A4"), ], name="graph", - inputs= - [ - helper.make_tensor_value_info('input0', TensorProto.INT64, [1]), - helper.make_tensor_value_info('input1', TensorProto.INT64, [1]), - helper.make_tensor_value_info('input2', TensorProto.INT64, [1]), + inputs=[ + helper.make_tensor_value_info("input0", TensorProto.INT64, [1]), + helper.make_tensor_value_info("input1", TensorProto.INT64, [1]), + helper.make_tensor_value_info("input2", TensorProto.INT64, [1]), ], - outputs= - [ - helper.make_tensor_value_info('3_out', TensorProto.INT64, [1]), - helper.make_tensor_value_info('5_out', TensorProto.INT64, [1]), - helper.make_tensor_value_info('6_out', TensorProto.INT64, [1]), + outputs=[ + helper.make_tensor_value_info("3_out", TensorProto.INT64, [1]), + helper.make_tensor_value_info("5_out", TensorProto.INT64, [1]), + helper.make_tensor_value_info("6_out", TensorProto.INT64, [1]), ], - initializer=[] + initializer=[], ) model = helper.make_model(graph) @@ -79,38 +75,35 @@ def create_model_2(): # \ / # a7 graph = helper.make_graph( - nodes= - [ - helper.make_node("Sub", ['input0', 'input1'], ['s1_out'], "S1"), - helper.make_node("Add", ['s1_out', 'input2'], ['a1_out'], "A1"), - helper.make_node("Add", ['a1_out', 'input0'], ['a2_out'], "A2"), - helper.make_node("Sub", ['a1_out', 'input1'], ['s2_out'], "S2"), - helper.make_node("Add", ['a2_out', 'input2'], ['a3_out'], "A3"), - helper.make_node("Add", ['a2_out', 's2_out'], ['a4_out'], "A4"), - helper.make_node("Add", ['a3_out', 'input0'], ['a5_out'], "A5"), - helper.make_node("Add", ['a4_out', 'input1'], ['a6_out'], "A6"), - helper.make_node("Add", ['a5_out', 'a6_out'], ['a7_out'], "A7"), + nodes=[ + helper.make_node("Sub", ["input0", "input1"], ["s1_out"], "S1"), + helper.make_node("Add", ["s1_out", "input2"], ["a1_out"], "A1"), + helper.make_node("Add", ["a1_out", "input0"], ["a2_out"], "A2"), + helper.make_node("Sub", ["a1_out", "input1"], ["s2_out"], "S2"), + helper.make_node("Add", ["a2_out", "input2"], ["a3_out"], "A3"), + helper.make_node("Add", ["a2_out", "s2_out"], ["a4_out"], "A4"), + helper.make_node("Add", ["a3_out", "input0"], ["a5_out"], "A5"), + helper.make_node("Add", ["a4_out", "input1"], ["a6_out"], "A6"), + helper.make_node("Add", ["a5_out", "a6_out"], ["a7_out"], "A7"), ], name="graph", - inputs= - [ - helper.make_tensor_value_info('input0', TensorProto.INT64, [1]), - helper.make_tensor_value_info('input1', TensorProto.INT64, [1]), - helper.make_tensor_value_info('input2', TensorProto.INT64, [1]), + inputs=[ + helper.make_tensor_value_info("input0", TensorProto.INT64, [1]), + helper.make_tensor_value_info("input1", TensorProto.INT64, [1]), + helper.make_tensor_value_info("input2", TensorProto.INT64, [1]), ], - outputs= - [ - helper.make_tensor_value_info('a7_out', TensorProto.INT64, [1]), + outputs=[ + helper.make_tensor_value_info("a7_out", TensorProto.INT64, [1]), ], - initializer=[] + initializer=[], ) model = helper.make_model(graph) return model -if __name__ == '__main__': +if __name__ == "__main__": model = create_model_1() - onnx.save(model, 'ep_partitioning_test_1.onnx') + onnx.save(model, "ep_partitioning_test_1.onnx") model = create_model_2() - onnx.save(model, 'ep_partitioning_test_2.onnx') + onnx.save(model, "ep_partitioning_test_2.onnx") diff --git a/onnxruntime/test/testdata/matmul_integer_to_float.py b/onnxruntime/test/testdata/matmul_integer_to_float.py index 13b38bb11dc2a..6b126fb3a2a1f 100644 --- a/onnxruntime/test/testdata/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/matmul_integer_to_float.py @@ -1,58 +1,73 @@ -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum -def GenerateModel(model_name, sign_i, sign_w, has_zp = True, bias = False): +import onnx +from onnx import TensorProto, helper + + +def GenerateModel(model_name, sign_i, sign_w, has_zp=True, bias=False): nodes = [ # subgraph helper.make_node( "MatMulInteger", ["A", "B", "a_zero_point", "b_zero_point"] if has_zp else ["A", "B"], ["matmul_output_int32"], - "MatMulInteger"), - + "MatMulInteger", + ), helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"), - helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1), - - helper.make_node("Mul", ["matmul_output_float", "multiplier"], ["mul_bottom_output" if bias else "Y"], "mul_bottom"), + helper.make_node( + "Mul", + ["matmul_output_float", "multiplier"], + ["mul_bottom_output" if bias else "Y"], + "mul_bottom", + ), ] inputs = [ # inputs - helper.make_tensor_value_info('A', TensorProto.INT8 if sign_i else TensorProto.UINT8, ['M', 'K']), - helper.make_tensor_value_info('B', TensorProto.INT8 if sign_w else TensorProto.UINT8, ['K', 'N']), - helper.make_tensor_value_info('a_scale', TensorProto.FLOAT, [1]), - helper.make_tensor_value_info('b_scale', TensorProto.FLOAT, ['C']), - - ] + helper.make_tensor_value_info("A", TensorProto.INT8 if sign_i else TensorProto.UINT8, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.INT8 if sign_w else TensorProto.UINT8, ["K", "N"]), + helper.make_tensor_value_info("a_scale", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, ["C"]), + ] if has_zp: - inputs.extend([ - helper.make_tensor_value_info('a_zero_point', TensorProto.INT8 if sign_i else TensorProto.UINT8, [1]), - helper.make_tensor_value_info('b_zero_point', TensorProto.INT8 if sign_w else TensorProto.UINT8, ['C']), - ]) + inputs.extend( + [ + helper.make_tensor_value_info( + "a_zero_point", + TensorProto.INT8 if sign_i else TensorProto.UINT8, + [1], + ), + helper.make_tensor_value_info( + "b_zero_point", + TensorProto.INT8 if sign_w else TensorProto.UINT8, + ["C"], + ), + ] + ) if bias: nodes.extend([helper.make_node("Add", ["mul_bottom_output", "bias"], ["Y"], "add")]) - inputs.extend([helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['N'])]) + inputs.extend([helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["N"])]) graph = helper.make_graph( nodes, - "DynamicQuantizeMatMul_fusion", #name + "DynamicQuantizeMatMul_fusion", # name inputs, [ # outputs - helper.make_tensor_value_info('Y', TensorProto.FLOAT, ['M', 'N']), - ]) + helper.make_tensor_value_info("Y", TensorProto.FLOAT, ["M", "N"]), + ], + ) model = helper.make_model(graph) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('matmul_integer_to_float_int8.onnx', False, True) - GenerateModel('matmul_integer_to_float_uint8.onnx', False, False) - GenerateModel('matmul_integer_to_float_int8_bias.onnx', False, True, False, True) - GenerateModel('matmul_integer_to_float_uint8_bias.onnx', False, False, False, True) + GenerateModel("matmul_integer_to_float_int8.onnx", False, True) + GenerateModel("matmul_integer_to_float_uint8.onnx", False, False) + GenerateModel("matmul_integer_to_float_int8_bias.onnx", False, True, False, True) + GenerateModel("matmul_integer_to_float_uint8_bias.onnx", False, False, False, True) - GenerateModel('matmul_integer_to_float_int8_int8.onnx', True, True) - GenerateModel('matmul_integer_to_float_int8_int8_bias.onnx', True, True, False, True) \ No newline at end of file + GenerateModel("matmul_integer_to_float_int8_int8.onnx", True, True) + GenerateModel("matmul_integer_to_float_int8_int8_bias.onnx", True, True, False, True) diff --git a/onnxruntime/test/testdata/model_with_external_initializer_come_from_user.py b/onnxruntime/test/testdata/model_with_external_initializer_come_from_user.py index baad806b7e85a..b6b622e20e248 100644 --- a/onnxruntime/test/testdata/model_with_external_initializer_come_from_user.py +++ b/onnxruntime/test/testdata/model_with_external_initializer_come_from_user.py @@ -1,12 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import onnx import numpy as np -from onnx import helper -from onnx import TensorProto -from onnx.numpy_helper import from_array +import onnx +from onnx import TensorProto, helper from onnx.external_data_helper import set_external_data +from onnx.numpy_helper import from_array def create_external_data_tensor(value, tensor_name): # type: (List[Any], Text) -> TensorProto @@ -14,27 +13,27 @@ def create_external_data_tensor(value, tensor_name): # type: (List[Any], Text) tensor.name = tensor_name tensor_filename = "{}.bin".format(tensor_name) set_external_data(tensor, location=tensor_filename) - tensor.ClearField('raw_data') + tensor.ClearField("raw_data") tensor.data_location = onnx.TensorProto.EXTERNAL return tensor def GenerateModel(model_name): # Create one input (ValueInfoProto) - X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2]) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2]) # Create second input (ValueInfoProto) - Pads = helper.make_tensor_value_info('Pads_not_on_disk', TensorProto.INT64, [4]) + Pads = helper.make_tensor_value_info("Pads_not_on_disk", TensorProto.INT64, [4]) # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]) # Create a node (NodeProto) node_def = helper.make_node( - 'Pad', # node name - ['X', 'Pads_not_on_disk'], # inputs - ['Y'], # outputs - mode='constant', # Attributes + "Pad", # node name + ["X", "Pads_not_on_disk"], # inputs + ["Y"], # outputs + mode="constant", # Attributes ) # Create the graph (GraphProto) @@ -44,19 +43,18 @@ def GenerateModel(model_name): "test-model", [X, Pads], [Y], - [create_external_data_tensor(initializer_data, "Pads_not_on_disk")] + [create_external_data_tensor(initializer_data, "Pads_not_on_disk")], ) # Create the model (ModelProto) - model_def = helper.make_model(graph_def, - producer_name='onnx-example') + model_def = helper.make_model(graph_def, producer_name="onnx-example") - print('The ir_version in model: {}\n'.format(model_def.ir_version)) - print('The producer_name in model: {}\n'.format(model_def.producer_name)) - print('The graph in model:\n{}'.format(model_def.graph)) + print("The ir_version in model: {}\n".format(model_def.ir_version)) + print("The producer_name in model: {}\n".format(model_def.producer_name)) + print("The graph in model:\n{}".format(model_def.graph)) with open(model_name, "wb") as model_file: model_file.write(model_def.SerializeToString()) if __name__ == "__main__": - GenerateModel('model_with_external_initializer_come_from_user.onnx') + GenerateModel("model_with_external_initializer_come_from_user.onnx") diff --git a/onnxruntime/test/testdata/model_with_external_initializers.py b/onnxruntime/test/testdata/model_with_external_initializers.py index a5ca89fda7864..8b591549963fd 100644 --- a/onnxruntime/test/testdata/model_with_external_initializers.py +++ b/onnxruntime/test/testdata/model_with_external_initializers.py @@ -1,12 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import onnx import numpy as np -from onnx import helper -from onnx import TensorProto -from onnx.numpy_helper import from_array +import onnx +from onnx import TensorProto, helper from onnx.external_data_helper import set_external_data +from onnx.numpy_helper import from_array def create_external_data_tensor(value, tensor_name): # type: (List[Any], Text) -> TensorProto @@ -15,29 +14,29 @@ def create_external_data_tensor(value, tensor_name): # type: (List[Any], Text) tensor_filename = "{}.bin".format(tensor_name) set_external_data(tensor, location=tensor_filename) - with open(os.path.join(tensor_filename), 'wb') as data_file: + with open(os.path.join(tensor_filename), "wb") as data_file: data_file.write(tensor.raw_data) - tensor.ClearField('raw_data') + tensor.ClearField("raw_data") tensor.data_location = onnx.TensorProto.EXTERNAL return tensor def GenerateModel(model_name): # Create one input (ValueInfoProto) - X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 2]) + X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 2]) # Create second input (ValueInfoProto) - Pads = helper.make_tensor_value_info('Pads', TensorProto.INT64, [4]) + Pads = helper.make_tensor_value_info("Pads", TensorProto.INT64, [4]) # Create one output (ValueInfoProto) - Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 4]) + Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]) # Create a node (NodeProto) node_def = helper.make_node( - 'Pad', # node name - ['X', 'Pads'], # inputs - ['Y'], # outputs - mode='constant', # Attributes + "Pad", # node name + ["X", "Pads"], # inputs + ["Y"], # outputs + mode="constant", # Attributes ) # Create the graph (GraphProto) @@ -46,21 +45,30 @@ def GenerateModel(model_name): "test-model", [X, Pads], [Y], - [create_external_data_tensor([0, 0, 1, 1, ], "Pads")] + [ + create_external_data_tensor( + [ + 0, + 0, + 1, + 1, + ], + "Pads", + ) + ], ) # Create the model (ModelProto) - model_def = helper.make_model(graph_def, - producer_name='onnx-example') + model_def = helper.make_model(graph_def, producer_name="onnx-example") - print('The ir_version in model: {}\n'.format(model_def.ir_version)) - print('The producer_name in model: {}\n'.format(model_def.producer_name)) - print('The graph in model:\n{}'.format(model_def.graph)) + print("The ir_version in model: {}\n".format(model_def.ir_version)) + print("The producer_name in model: {}\n".format(model_def.producer_name)) + print("The graph in model:\n{}".format(model_def.graph)) onnx.checker.check_model(model_def) - print('The model is checked!') + print("The model is checked!") with open(model_name, "wb") as model_file: model_file.write(model_def.SerializeToString()) if __name__ == "__main__": - GenerateModel('model_with_external_initializers.onnx') + GenerateModel("model_with_external_initializers.onnx") diff --git a/onnxruntime/test/testdata/model_with_metadata.py b/onnxruntime/test/testdata/model_with_metadata.py index 0d085d6deb34d..7e01b6a38f54c 100644 --- a/onnxruntime/test/testdata/model_with_metadata.py +++ b/onnxruntime/test/testdata/model_with_metadata.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # Create a model with metadata to test ORT conversion @@ -12,27 +11,27 @@ def GenerateModel(model_name): graph = helper.make_graph( nodes, "NNAPI_Internal_uint8_Test", - [helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 3])], - [helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 3])], + [helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 3])], + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3])], ) model = helper.make_model(graph) # Add meta data - model.doc_string = 'This is doc_string' - model.producer_name = 'TensorTorch' + model.doc_string = "This is doc_string" + model.producer_name = "TensorTorch" model.model_version = 12345 - model.domain = 'ai.onnx.ml' + model.domain = "ai.onnx.ml" helper.set_model_props( model, { - 'I am key 1!': 'I am value 1!', - '': 'Value for empty key!', - 'Key for empty value!': '', - } + "I am key 1!": "I am value 1!", + "": "Value for empty key!", + "Key for empty value!": "", + }, ) onnx.save(model, model_name) if __name__ == "__main__": - GenerateModel('model_with_metadata.onnx') + GenerateModel("model_with_metadata.onnx") diff --git a/onnxruntime/test/testdata/nnapi_internal_uint8_support.py b/onnxruntime/test/testdata/nnapi_internal_uint8_support.py index d7d2c18cf57e6..0956ba3bd13fa 100644 --- a/onnxruntime/test/testdata/nnapi_internal_uint8_support.py +++ b/onnxruntime/test/testdata/nnapi_internal_uint8_support.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # This is to test the operators without "Qlinear" support but still support uint8 input @@ -8,28 +7,56 @@ # def GenerateModel(model_name): def GenerateModel(model_name): nodes = [ - helper.make_node("QuantizeLinear", ["X", "Scale", "Zero_point"], ["X_quantized"], "quantize_0"), - helper.make_node("Concat", ["X_quantized", "X_quantized"], ["X_concat"], axis=-2, name="concat_0"), - helper.make_node("MaxPool", ["X_concat"], ["X_maxpool"], kernel_shape=[2, 2], name="maxpool_0"), - helper.make_node("Transpose", ["X_maxpool"], ["X_transposed"], perm=[0, 1, 3, 2], name="transpose_0"), - helper.make_node("DequantizeLinear", ["X_transposed", "Scale", "Zero_point"], ["Y"], "dequantize_0"), + helper.make_node( + "QuantizeLinear", + ["X", "Scale", "Zero_point"], + ["X_quantized"], + "quantize_0", + ), + helper.make_node( + "Concat", + ["X_quantized", "X_quantized"], + ["X_concat"], + axis=-2, + name="concat_0", + ), + helper.make_node( + "MaxPool", + ["X_concat"], + ["X_maxpool"], + kernel_shape=[2, 2], + name="maxpool_0", + ), + helper.make_node( + "Transpose", + ["X_maxpool"], + ["X_transposed"], + perm=[0, 1, 3, 2], + name="transpose_0", + ), + helper.make_node( + "DequantizeLinear", + ["X_transposed", "Scale", "Zero_point"], + ["Y"], + "dequantize_0", + ), ] initializers = [ - helper.make_tensor('Scale', TensorProto.FLOAT, [1], [256.0]), - helper.make_tensor('Zero_point', TensorProto.UINT8, [1], [0]), + helper.make_tensor("Scale", TensorProto.FLOAT, [1], [256.0]), + helper.make_tensor("Zero_point", TensorProto.UINT8, [1], [0]), ] inputs = [ - helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 1, 3]), + helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 1, 3]), ] graph = helper.make_graph( nodes, "NNAPI_Internal_uint8_Test", inputs, - [helper.make_tensor_value_info('Y', TensorProto.FLOAT, [1, 1, 2, 1])], - initializers + [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 1, 2, 1])], + initializers, ) model = helper.make_model(graph) @@ -37,4 +64,4 @@ def GenerateModel(model_name): if __name__ == "__main__": - GenerateModel('nnapi_internal_uint8_support.onnx') + GenerateModel("nnapi_internal_uint8_support.onnx") diff --git a/onnxruntime/test/testdata/nnapi_reshape_flatten_test.py b/onnxruntime/test/testdata/nnapi_reshape_flatten_test.py index 40050a04efc5b..27cfd1304f392 100644 --- a/onnxruntime/test/testdata/nnapi_reshape_flatten_test.py +++ b/onnxruntime/test/testdata/nnapi_reshape_flatten_test.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # Since NNAPI EP handles Reshape and Flatten differently, @@ -18,23 +17,28 @@ def GenerateModel(model_name): ] initializers = [ - helper.make_tensor('Reshape_1_shape', TensorProto.INT64, [2], [3, 4]), - helper.make_tensor('Reshape_2_shape', TensorProto.INT64, [2], [1, 6]), - helper.make_tensor('Gemm_B', TensorProto.FLOAT, [4, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]), - helper.make_tensor('MatMul_B', TensorProto.FLOAT, [2, 3], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), + helper.make_tensor("Reshape_1_shape", TensorProto.INT64, [2], [3, 4]), + helper.make_tensor("Reshape_2_shape", TensorProto.INT64, [2], [1, 6]), + helper.make_tensor( + "Gemm_B", + TensorProto.FLOAT, + [4, 2], + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], + ), + helper.make_tensor("MatMul_B", TensorProto.FLOAT, [2, 3], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]), ] inputs = [ - helper.make_tensor_value_info('X', TensorProto.FLOAT, [2, 1, 2]), - helper.make_tensor_value_info('Y', TensorProto.FLOAT, [3, 2, 2]) + helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 1, 2]), + helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2, 2]), ] graph = helper.make_graph( nodes, "NNAPI_Reshape_Flatten_Test", inputs, - [helper.make_tensor_value_info('Z', TensorProto.FLOAT, [1, 6])], - initializers + [helper.make_tensor_value_info("Z", TensorProto.FLOAT, [1, 6])], + initializers, ) model = helper.make_model(graph) @@ -42,4 +46,4 @@ def GenerateModel(model_name): if __name__ == "__main__": - GenerateModel('nnapi_reshape_flatten_test.onnx') + GenerateModel("nnapi_reshape_flatten_test.onnx") diff --git a/onnxruntime/test/testdata/ort_github_issue_10305.py b/onnxruntime/test/testdata/ort_github_issue_10305.py index e8d4829448b7a..92f90d9bd5edb 100644 --- a/onnxruntime/test/testdata/ort_github_issue_10305.py +++ b/onnxruntime/test/testdata/ort_github_issue_10305.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # Loop is so the Tranpose output is used in a subgraph loop_body = helper.make_graph( @@ -9,50 +8,65 @@ ], "Loop_body", [ - helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]), - helper.make_tensor_value_info('subgraph_keep_going_in', TensorProto.BOOL, [1]), - helper.make_tensor_value_info('loop_state_in', TensorProto.FLOAT, [1]) + helper.make_tensor_value_info("iteration_num", TensorProto.INT64, [1]), + helper.make_tensor_value_info("subgraph_keep_going_in", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_in", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor_value_info('subgraph_keep_going_in', TensorProto.BOOL, [1]), - helper.make_tensor_value_info('loop_state_out', TensorProto.FLOAT, [2, 2, 2]), + helper.make_tensor_value_info("subgraph_keep_going_in", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_out", TensorProto.FLOAT, [2, 2, 2]), ], - [ - ] + [], ) # Create the main graph graph_proto = helper.make_graph( [ # add a Transpose that can be moved past the Slice - helper.make_node('Transpose', inputs=['input:0'], outputs=['transpose:0'], name='transpose0', perm=[1, 0, 2]), - helper.make_node('Slice', - inputs=['transpose:0', 'start', 'end'], - outputs=['strided_slice:0'], name='slice0'), - helper.make_node('Squeeze', - inputs=['strided_slice:0', 'start'], - outputs=['out:0'], - name='squeeze0'), - helper.make_node("Loop", ["max_trip_count", "subgraph_keep_going_in", "state_var_in"], ["out:1"], "Loop1", - body=loop_body) + helper.make_node( + "Transpose", + inputs=["input:0"], + outputs=["transpose:0"], + name="transpose0", + perm=[1, 0, 2], + ), + helper.make_node( + "Slice", + inputs=["transpose:0", "start", "end"], + outputs=["strided_slice:0"], + name="slice0", + ), + helper.make_node( + "Squeeze", + inputs=["strided_slice:0", "start"], + outputs=["out:0"], + name="squeeze0", + ), + helper.make_node( + "Loop", + ["max_trip_count", "subgraph_keep_going_in", "state_var_in"], + ["out:1"], + "Loop1", + body=loop_body, + ), ], "Main_graph", [ - helper.make_tensor_value_info('input:0', TensorProto.FLOAT, [2, 2, 2]), - helper.make_tensor_value_info('state_var_in', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("input:0", TensorProto.FLOAT, [2, 2, 2]), + helper.make_tensor_value_info("state_var_in", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor_value_info('out:0', TensorProto.FLOAT, [2, 2]), - helper.make_tensor_value_info('out:1', TensorProto.FLOAT, [2, 2, 2]), + helper.make_tensor_value_info("out:0", TensorProto.FLOAT, [2, 2]), + helper.make_tensor_value_info("out:1", TensorProto.FLOAT, [2, 2, 2]), ], [ - helper.make_tensor('start', TensorProto.INT64, [1], [0]), - helper.make_tensor('end', TensorProto.INT64, [1], [1]), - helper.make_tensor('max_trip_count', TensorProto.INT64, [1], [1]), - helper.make_tensor('subgraph_keep_going_in', TensorProto.BOOL, [1], [1]), - ] + helper.make_tensor("start", TensorProto.INT64, [1], [0]), + helper.make_tensor("end", TensorProto.INT64, [1], [1]), + helper.make_tensor("max_trip_count", TensorProto.INT64, [1], [1]), + helper.make_tensor("subgraph_keep_going_in", TensorProto.BOOL, [1], [1]), + ], ) model = helper.make_model(graph_proto) onnx.checker.check_model(model, True) -onnx.save(model, 'ort_github_issue_10305.onnx') \ No newline at end of file +onnx.save(model, "ort_github_issue_10305.onnx") diff --git a/onnxruntime/test/testdata/ort_github_issue_4031.py b/onnxruntime/test/testdata/ort_github_issue_4031.py index f991b410242bb..0f5da031f3638 100644 --- a/onnxruntime/test/testdata/ort_github_issue_4031.py +++ b/onnxruntime/test/testdata/ort_github_issue_4031.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper if_body = helper.make_graph( [ @@ -13,54 +12,69 @@ # no explicit inputs ], [ - helper.make_tensor_value_info('output', TensorProto.BOOL, [1]), # how is this getting a type of float? - ]) + helper.make_tensor_value_info("output", TensorProto.BOOL, [1]), # how is this getting a type of float? + ], +) # Loop body graph with If node and usage of main_graph_initializer on this level body = helper.make_graph( [ # Add node that can be constant folded. Creates NodeArg when created but that implicit usage of an outer scope # value main_graph_initializer goes away after constant folding - helper.make_node("Add", ["sub_graph_initializer", "main_graph_initializer"], ["initializer_sum"], "Add1"), + helper.make_node( + "Add", + ["sub_graph_initializer", "main_graph_initializer"], + ["initializer_sum"], + "Add1", + ), helper.make_node("Add", ["initializer_sum", "loop_state_in"], ["loop_state_out"], "Add2"), # If node to create usage of main_graph_initializer another level down - helper.make_node("If", ["subgraph_keep_going_in"], ["subgraph_keep_going_out"], "If1", - then_branch=if_body, else_branch=if_body), + helper.make_node( + "If", + ["subgraph_keep_going_in"], + ["subgraph_keep_going_out"], + "If1", + then_branch=if_body, + else_branch=if_body, + ), ], "Loop_body", [ - helper.make_tensor_value_info('iteration_num', TensorProto.INT64, [1]), - helper.make_tensor_value_info('subgraph_keep_going_in', TensorProto.BOOL, [1]), - helper.make_tensor_value_info('loop_state_in', TensorProto.FLOAT, [1]) + helper.make_tensor_value_info("iteration_num", TensorProto.INT64, [1]), + helper.make_tensor_value_info("subgraph_keep_going_in", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_in", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor_value_info('subgraph_keep_going_out', TensorProto.BOOL, [1]), - helper.make_tensor_value_info('loop_state_out', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("subgraph_keep_going_out", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_out", TensorProto.FLOAT, [1]), ], - [ - helper.make_tensor('sub_graph_initializer', TensorProto.FLOAT, [1], [1.]) - ] + [helper.make_tensor("sub_graph_initializer", TensorProto.FLOAT, [1], [1.0])], ) # Create the main graph graph_proto = helper.make_graph( [ - helper.make_node("Loop", ["max_trip_count", "keep_going", "state_var_in"], - ["state_var_out"], "Loop1", body=body) + helper.make_node( + "Loop", + ["max_trip_count", "keep_going", "state_var_in"], + ["state_var_out"], + "Loop1", + body=body, + ) ], "Main_graph", [ - helper.make_tensor_value_info('state_var_in', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("state_var_in", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor_value_info('state_var_out', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("state_var_out", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor('max_trip_count', TensorProto.INT64, [1], [1]), - helper.make_tensor('main_graph_initializer', TensorProto.FLOAT, [1], [1.]), - helper.make_tensor('keep_going', TensorProto.BOOL, [1], [True]), - ] + helper.make_tensor("max_trip_count", TensorProto.INT64, [1], [1]), + helper.make_tensor("main_graph_initializer", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("keep_going", TensorProto.BOOL, [1], [True]), + ], ) model = helper.make_model(graph_proto) -onnx.save(model, 'ort_github_issue_4031.onnx') \ No newline at end of file +onnx.save(model, "ort_github_issue_4031.onnx") diff --git a/onnxruntime/test/testdata/sparse_initializer_as_output.py b/onnxruntime/test/testdata/sparse_initializer_as_output.py index 8b6c964b44f00..741ed6439e815 100644 --- a/onnxruntime/test/testdata/sparse_initializer_as_output.py +++ b/onnxruntime/test/testdata/sparse_initializer_as_output.py @@ -1,31 +1,41 @@ -import onnx -import numpy as np +import argparse import os import sys -import argparse -from onnx import helper, numpy_helper, mapping, utils -from onnx.helper import make_opsetid -from onnx import AttributeProto, SparseTensorProto, TensorProto, GraphProto, ValueInfoProto import traceback +from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Tuple, TypeVar, Union, cast -from typing import Text, Sequence, Any, Optional, Dict, Union, TypeVar, Callable, Tuple, List, cast +import numpy as np +import onnx +from onnx import ( + AttributeProto, + GraphProto, + SparseTensorProto, + TensorProto, + ValueInfoProto, + helper, + mapping, + numpy_helper, + utils, +) +from onnx.helper import make_opsetid def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--node_name', required=True, type=str, help='Constant Node name') + parser.add_argument("--node_name", required=True, type=str, help="Constant Node name") parser.add_argument("--output_file", required=True, type=str, help="Model file name to save") - + args = parser.parse_args() return args + # This is availabl in ONNX at a later commit def make_sparse_tensor_value_info( - name, # type: Text - elem_type, # type: int - shape, # type: Optional[Sequence[Union[Text, int]]] - doc_string="", # type: Text - shape_denotation=None, # type: Optional[List[Text]] + name, # type: Text + elem_type, # type: int + shape, # type: Optional[Sequence[Union[Text, int]]] + doc_string="", # type: Text + shape_denotation=None, # type: Optional[List[Text]] ): # type: (...) -> ValueInfoProto """Makes a ValueInfoProto based on the data type and shape.""" value_info_proto = ValueInfoProto() @@ -50,9 +60,7 @@ def make_sparse_tensor_value_info( if shape_denotation: if len(shape_denotation) != len(shape): - raise ValueError( - 'Invalid shape_denotation. ' - 'Must be of the same length as shape.') + raise ValueError("Invalid shape_denotation. " "Must be of the same length as shape.") for i, d in enumerate(shape): dim = sparse_tensor_shape_proto.dim.add() @@ -63,9 +71,7 @@ def make_sparse_tensor_value_info( elif isinstance(d, str): dim.dim_param = d else: - raise ValueError( - 'Invalid item in shape: {}. ' - 'Needs to be one of `int` or `str`.'.format(d)) + raise ValueError("Invalid item in shape: {}. " "Needs to be one of `int` or `str`.".format(d)) if shape_denotation: dim.denotation = shape_denotation[i] @@ -74,40 +80,56 @@ def make_sparse_tensor_value_info( def create_model(constant_node_name, output_file_name): - dense_shape = [3,3] + dense_shape = [3, 3] sparse_values = [1.764052391052246, 0.40015721321105957, 0.978738009929657] - values_tensor = helper.make_tensor(name='Constant', data_type=TensorProto.FLOAT, - dims=[len(sparse_values)], - vals=np.array(sparse_values).astype(np.float32), raw=False) + values_tensor = helper.make_tensor( + name="Constant", + data_type=TensorProto.FLOAT, + dims=[len(sparse_values)], + vals=np.array(sparse_values).astype(np.float32), + raw=False, + ) linear_indicies = [2, 3, 5] - indicies_tensor = helper.make_tensor(name='indicies', data_type=TensorProto.INT64, - dims=[len(linear_indicies)], - vals=np.array(linear_indicies).astype(np.int64), raw=False) + indicies_tensor = helper.make_tensor( + name="indicies", + data_type=TensorProto.INT64, + dims=[len(linear_indicies)], + vals=np.array(linear_indicies).astype(np.int64), + raw=False, + ) sparse_tensor = helper.make_sparse_tensor(values_tensor, indicies_tensor, dense_shape) # Nodes - #sparse_attribute = helper.make_attribute('value', sparse_tensor) - constant_node = helper.make_node(constant_node_name, inputs=[], outputs=['values'], - name='Constant', domain='', value=sparse_tensor) - - # Outputs, a square matrix - Values_info = make_sparse_tensor_value_info('values', TensorProto.FLOAT, dense_shape) - - graph_def = helper.make_graph(nodes=[constant_node], - name='ConstantNodeOutput', - inputs=[], - outputs=[Values_info]) - - model_def = helper.make_model(graph_def, producer_name='dmitrism', - opset_imports=[make_opsetid('', 12)]) + # sparse_attribute = helper.make_attribute('value', sparse_tensor) + constant_node = helper.make_node( + constant_node_name, + inputs=[], + outputs=["values"], + name="Constant", + domain="", + value=sparse_tensor, + ) + + # Outputs, a square matrix + Values_info = make_sparse_tensor_value_info("values", TensorProto.FLOAT, dense_shape) + + graph_def = helper.make_graph( + nodes=[constant_node], + name="ConstantNodeOutput", + inputs=[], + outputs=[Values_info], + ) + + model_def = helper.make_model(graph_def, producer_name="dmitrism", opset_imports=[make_opsetid("", 12)]) onnx.save(model_def, output_file_name) + if __name__ == "__main__": - try: - args = parse_arguments() - sys.exit(create_model(args.node_name, args.output_file)) - except Exception as inst : + try: + args = parse_arguments() + sys.exit(create_model(args.node_name, args.output_file)) + except Exception as inst: print("Exception thrown: ", str(inst)) print(traceback.format_exc()) diff --git a/onnxruntime/test/testdata/sparse_to_dense_matmul.py b/onnxruntime/test/testdata/sparse_to_dense_matmul.py index 68c24651dddee..26fb426968c39 100644 --- a/onnxruntime/test/testdata/sparse_to_dense_matmul.py +++ b/onnxruntime/test/testdata/sparse_to_dense_matmul.py @@ -1,28 +1,39 @@ -import onnx -import numpy as np +import argparse import os import sys -import argparse -from onnx import helper, numpy_helper, mapping, utils -from onnx.helper import make_opsetid -from onnx import AttributeProto, SparseTensorProto, TensorProto, GraphProto, ValueInfoProto import traceback +from typing import Any, Callable, Dict, List, Optional, Sequence, Text, Tuple, TypeVar, Union, cast + +import numpy as np +import onnx +from onnx import ( + AttributeProto, + GraphProto, + SparseTensorProto, + TensorProto, + ValueInfoProto, + helper, + mapping, + numpy_helper, + utils, +) +from onnx.helper import make_opsetid -from typing import Text, Sequence, Any, Optional, Dict, Union, TypeVar, Callable, Tuple, List, cast def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--output_file", required=True, type=str, help="Model file name to save") return parser.parse_args() + # This function is now available in ONNX def make_sparse_tensor_value_info( - name, # type: Text - elem_type, # type: int - shape, # type: Optional[Sequence[Union[Text, int]]] - doc_string="", # type: Text - shape_denotation=None, # type: Optional[List[Text]] - ): # type: (...) -> ValueInfoProto + name, # type: Text + elem_type, # type: int + shape, # type: Optional[Sequence[Union[Text, int]]] + doc_string="", # type: Text + shape_denotation=None, # type: Optional[List[Text]] +): # type: (...) -> ValueInfoProto """Makes a ValueInfoProto based on the data type and shape.""" value_info_proto = ValueInfoProto() value_info_proto.name = name @@ -46,9 +57,7 @@ def make_sparse_tensor_value_info( if shape_denotation: if len(shape_denotation) != len(shape): - raise ValueError( - 'Invalid shape_denotation. ' - 'Must be of the same length as shape.') + raise ValueError("Invalid shape_denotation. " "Must be of the same length as shape.") for i, d in enumerate(shape): dim = sparse_tensor_shape_proto.dim.add() @@ -59,38 +68,47 @@ def make_sparse_tensor_value_info( elif isinstance(d, str): dim.dim_param = d else: - raise ValueError( - 'Invalid item in shape: {}. ' - 'Needs to be one of `int` or `text`.'.format(d)) + raise ValueError("Invalid item in shape: {}. " "Needs to be one of `int` or `text`.".format(d)) if shape_denotation: dim.denotation = shape_denotation[i] return value_info_proto -def create_model(output_file_name): - matmul_node = helper.make_node("SparseToDenseMatMul", inputs=['sparse_A', 'dense_B'], outputs=['dense_Y'], - name='SpMM', domain='com.microsoft') - - value_info_A = make_sparse_tensor_value_info('sparse_A', TensorProto.FLOAT, [9, 9]) - value_info_B = helper.make_tensor_value_info('dense_B', TensorProto.FLOAT, [9, 9]) - value_info_Y = helper.make_tensor_value_info('dense_Y', TensorProto.FLOAT, [9, 9]) - graph_def = helper.make_graph(nodes=[matmul_node], - name='SpMM', - inputs=[value_info_A, value_info_B], - outputs=[value_info_Y]) +def create_model(output_file_name): + matmul_node = helper.make_node( + "SparseToDenseMatMul", + inputs=["sparse_A", "dense_B"], + outputs=["dense_Y"], + name="SpMM", + domain="com.microsoft", + ) + + value_info_A = make_sparse_tensor_value_info("sparse_A", TensorProto.FLOAT, [9, 9]) + value_info_B = helper.make_tensor_value_info("dense_B", TensorProto.FLOAT, [9, 9]) + value_info_Y = helper.make_tensor_value_info("dense_Y", TensorProto.FLOAT, [9, 9]) + + graph_def = helper.make_graph( + nodes=[matmul_node], + name="SpMM", + inputs=[value_info_A, value_info_B], + outputs=[value_info_Y], + ) + + model_def = helper.make_model( + graph_def, + producer_name="dmitrism", + opset_imports=[make_opsetid("com.microsoft", 1)], + ) - model_def = helper.make_model(graph_def, producer_name='dmitrism', - opset_imports=[make_opsetid('com.microsoft', 1)]) - onnx.save(model_def, output_file_name) + if __name__ == "__main__": - try: - args = parse_arguments() - sys.exit(create_model(args.output_file)) - except Exception as inst : + try: + args = parse_arguments() + sys.exit(create_model(args.output_file)) + except Exception as inst: print("Exception thrown: ", str(inst)) print(traceback.format_exc()) - diff --git a/onnxruntime/test/testdata/transform/approximation/gelu_approximation_gen.py b/onnxruntime/test/testdata/transform/approximation/gelu_approximation_gen.py index b3189e7437677..4dd60291ddf27 100644 --- a/onnxruntime/test/testdata/transform/approximation/gelu_approximation_gen.py +++ b/onnxruntime/test/testdata/transform/approximation/gelu_approximation_gen.py @@ -1,43 +1,42 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( [ # nodes # Add node before Gelu helper.make_node("Gelu", ["A"], ["C"], "Gelu_1", domain="com.microsoft"), ], - "Gelu_NoBias", #name + "Gelu_NoBias", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['batch', 'seq_len', 3072]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["batch", "seq_len", 3072]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT, ['batch', 'seq_len', 3072]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["batch", "seq_len", 3072]), ], - [ # initializers - ]) + [], # initializers +) model = helper.make_model(graph) -onnx.save(model, r'gelu.onnx') +onnx.save(model, r"gelu.onnx") graph = helper.make_graph( [ # nodes # Add node before Gelu helper.make_node("BiasGelu", ["A", "B"], ["C"], "AddGeluFusion_1", domain="com.microsoft"), ], - "Gelu_AddBias", #name + "Gelu_AddBias", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['batch', 'seq_len', 3072]), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["batch", "seq_len", 3072]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3072]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT, ['batch', 'seq_len', 3072]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["batch", "seq_len", 3072]), ], - [ # initializers - ]) + [], # initializers +) model = helper.make_model(graph) -onnx.save(model, r'gelu_add_bias.onnx') +onnx.save(model, r"gelu_add_bias.onnx") graph = helper.make_graph( [ # nodes @@ -45,17 +44,17 @@ helper.make_node("MatMul", ["A", "B"], ["C"], "MatMul_1"), helper.make_node("BiasGelu", ["C", "D"], ["E"], "AddGeluFusion_1", domain="com.microsoft"), ], - "MatMul_AddGeluFusion", #name + "MatMul_AddGeluFusion", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['batch', 'seq_len', 'x']), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [128, 3072]), - helper.make_tensor_value_info('D', TensorProto.FLOAT, [3072]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["batch", "seq_len", "x"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [128, 3072]), + helper.make_tensor_value_info("D", TensorProto.FLOAT, [3072]), ], [ # outputs - helper.make_tensor_value_info('E', TensorProto.FLOAT, ['batch', 'seq_len', 3072]), + helper.make_tensor_value_info("E", TensorProto.FLOAT, ["batch", "seq_len", 3072]), ], - [ # initializers - ]) + [], # initializers +) model = helper.make_model(graph) -onnx.save(model, r'gelu_add_matmul.onnx') +onnx.save(model, r"gelu_add_matmul.onnx") diff --git a/onnxruntime/test/testdata/transform/cast_elimination.py b/onnxruntime/test/testdata/transform/cast_elimination.py index 3d546c85e9b66..fbf0932dcaa0d 100644 --- a/onnxruntime/test/testdata/transform/cast_elimination.py +++ b/onnxruntime/test/testdata/transform/cast_elimination.py @@ -1,48 +1,46 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [4, 4]) -X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [4, 1]) -X3 = helper.make_tensor_value_info('x3', TensorProto.INT64, [4, 1]) -Y = helper.make_tensor_value_info('output', TensorProto.INT64, [4, 4]) +X1 = helper.make_tensor_value_info("x1", TensorProto.INT64, [4, 4]) +X2 = helper.make_tensor_value_info("x2", TensorProto.INT64, [4, 1]) +X3 = helper.make_tensor_value_info("x3", TensorProto.INT64, [4, 1]) +Y = helper.make_tensor_value_info("output", TensorProto.INT64, [4, 4]) -less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') -less2 = helper.make_node('Less', ['x1', 'x3'], ['less2'], name='less2') -cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=9) -and_node = helper.make_node('And', ['cast1', 'less2'], ['and_node'], name='and_node') -cast2 = helper.make_node('Cast', ['and_node'], ['cast2'], name='cast2', to=9) -cast3 = helper.make_node('Cast', ['cast2'], ['cast3'], name='cast3', to=1) -cast4 = helper.make_node('Cast', ['x1'], ['cast4'], name='cast4', to=7) -cast5 = helper.make_node('Cast', ['cast4'], ['cast5'], name='cast5', to=1) -matmul = helper.make_node('MatMul', ['cast3', 'cast5'], ['matmul'], name='matmul') -cast6 = helper.make_node('Cast', ['matmul'], ['cast6'], name='cast6', to=7) -cast7 = helper.make_node('Cast', ['cast6'], ['output'], name='cast7', to=7) +less1 = helper.make_node("Less", ["x1", "x2"], ["less1"], name="less1") +less2 = helper.make_node("Less", ["x1", "x3"], ["less2"], name="less2") +cast1 = helper.make_node("Cast", ["less1"], ["cast1"], name="cast1", to=9) +and_node = helper.make_node("And", ["cast1", "less2"], ["and_node"], name="and_node") +cast2 = helper.make_node("Cast", ["and_node"], ["cast2"], name="cast2", to=9) +cast3 = helper.make_node("Cast", ["cast2"], ["cast3"], name="cast3", to=1) +cast4 = helper.make_node("Cast", ["x1"], ["cast4"], name="cast4", to=7) +cast5 = helper.make_node("Cast", ["cast4"], ["cast5"], name="cast5", to=1) +matmul = helper.make_node("MatMul", ["cast3", "cast5"], ["matmul"], name="matmul") +cast6 = helper.make_node("Cast", ["matmul"], ["cast6"], name="cast6", to=7) +cast7 = helper.make_node("Cast", ["cast6"], ["output"], name="cast7", to=7) # Create the graph (GraphProto) graph_def = helper.make_graph( [less1, less2, cast1, and_node, cast2, cast3, cast4, cast5, matmul, cast6, cast7], - 'cast_elimination_model', + "cast_elimination_model", [X1, X2, X3], - [Y] + [Y], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'cast_elimination.onnx') +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "cast_elimination.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction.py b/onnxruntime/test/testdata/transform/computation_reduction.py index 897c5a58c600f..7d33c9cc66c89 100644 --- a/onnxruntime/test/testdata/transform/computation_reduction.py +++ b/onnxruntime/test/testdata/transform/computation_reduction.py @@ -1,22 +1,27 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -vocab_size=256 #30258 +vocab_size = 256 # 30258 -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", vocab_size]) -Gather_Y = helper.make_tensor_value_info('gather_output', TensorProto.FLOAT, ["batch", 128]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", vocab_size]) +Gather_Y = helper.make_tensor_value_info("gather_output", TensorProto.FLOAT, ["batch", 128]) layer_norm1_weight_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm1_weight_initializer = numpy_helper.from_array(layer_norm1_weight_np_vals, "bert.encoder.layer.2.output.LayerNorm.weight") +layer_norm1_weight_initializer = numpy_helper.from_array( + layer_norm1_weight_np_vals, "bert.encoder.layer.2.output.LayerNorm.weight" +) layer_norm1_bias_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm1_bias_initializer = numpy_helper.from_array(layer_norm1_bias_np_vals, "bert.encoder.layer.2.output.LayerNorm.bias") +layer_norm1_bias_initializer = numpy_helper.from_array( + layer_norm1_bias_np_vals, "bert.encoder.layer.2.output.LayerNorm.bias" +) matmul1_np_vals = np.random.uniform(0.0, 1.0, (128, 128)).astype(np.float32).reshape((128, 128)) matmul1_initializer = numpy_helper.from_array(matmul1_np_vals, "matmul1_initializer") @@ -25,10 +30,14 @@ add1_initializer = numpy_helper.from_array(add1_np_vals, "add1_initializerr") layer_norm2_weight_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm2_weight_initializer = numpy_helper.from_array(layer_norm2_weight_np_vals, "cls.predictions.transform.LayerNorm.weight") +layer_norm2_weight_initializer = numpy_helper.from_array( + layer_norm2_weight_np_vals, "cls.predictions.transform.LayerNorm.weight" +) layer_norm2_bias_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm2_bias_initializer = numpy_helper.from_array(layer_norm2_bias_np_vals, "cls.predictions.transform.LayerNorm.bias") +layer_norm2_bias_initializer = numpy_helper.from_array( + layer_norm2_bias_np_vals, "cls.predictions.transform.LayerNorm.bias" +) matmul2_np_vals = np.random.uniform(0.0, 1.0, (128, vocab_size)).astype(np.float32).reshape((128, vocab_size)) matmul2_initializer = numpy_helper.from_array(matmul2_np_vals, "bert.embeddings.word_embeddings.weight_transposed") @@ -39,48 +48,82 @@ gather_indice_np_vals = np.asarray([0]).astype(np.int64).reshape(()) gather_indice_initializer = numpy_helper.from_array(gather_indice_np_vals, "gather_indice_initializer") -nodes=[] -layer_norm1 = helper.make_node('LayerNormalization', - ['input', layer_norm1_weight_initializer.name, layer_norm1_bias_initializer.name], - ['layer_norm1', 'saved_mean1', 'saved_inv_std_var1'], - name='layer_norm_1', epsilon=9.999999960041972e-13, axis=-1) +nodes = [] +layer_norm1 = helper.make_node( + "LayerNormalization", + ["input", layer_norm1_weight_initializer.name, layer_norm1_bias_initializer.name], + ["layer_norm1", "saved_mean1", "saved_inv_std_var1"], + name="layer_norm_1", + epsilon=9.999999960041972e-13, + axis=-1, +) nodes.append(layer_norm1) -gather1 = helper.make_node('Gather', ['layer_norm1', gather_indice_initializer.name], ['gather_output'], name="gather_output", axis=1) +gather1 = helper.make_node( + "Gather", + ["layer_norm1", gather_indice_initializer.name], + ["gather_output"], + name="gather_output", + axis=1, +) nodes.append(gather1) -matmul1 = helper.make_node('MatMul', ['layer_norm1', matmul1_initializer.name], ['matmul1'], name="matmul_1") +matmul1 = helper.make_node("MatMul", ["layer_norm1", matmul1_initializer.name], ["matmul1"], name="matmul_1") nodes.append(matmul1) -add1 = helper.make_node('Add', [add1_initializer.name, 'matmul1'], ['add1'], name="add_1") +add1 = helper.make_node("Add", [add1_initializer.name, "matmul1"], ["add1"], name="add_1") nodes.append(add1) -gelu1 = helper.make_node('Gelu', ['add1'], ['gelu1'], name='gelu_1', domain='com.microsoft') +gelu1 = helper.make_node("Gelu", ["add1"], ["gelu1"], name="gelu_1", domain="com.microsoft") nodes.append(gelu1) -layer_norm2 = helper.make_node('LayerNormalization', - ['gelu1', layer_norm2_weight_initializer.name, layer_norm2_bias_initializer.name], - ['layer_norm2', 'saved_mean2', 'saved_inv_std_var2'], - name='layer_norm_2', epsilon=9.999999960041972e-13, axis=-1) +layer_norm2 = helper.make_node( + "LayerNormalization", + ["gelu1", layer_norm2_weight_initializer.name, layer_norm2_bias_initializer.name], + ["layer_norm2", "saved_mean2", "saved_inv_std_var2"], + name="layer_norm_2", + epsilon=9.999999960041972e-13, + axis=-1, +) nodes.append(layer_norm2) -matmul2 = helper.make_node('MatMul', ['layer_norm2', matmul2_initializer.name], ['matmul2'], name="matmul_2") +matmul2 = helper.make_node("MatMul", ["layer_norm2", matmul2_initializer.name], ["matmul2"], name="matmul_2") nodes.append(matmul2) -add2 = helper.make_node('Add', ['matmul2', add2_initializer.name], ['add2'], name="add_2") +add2 = helper.make_node("Add", ["matmul2", add2_initializer.name], ["add2"], name="add_2") nodes.append(add2) -gathernd1 = helper.make_node('GatherND', ['add2', 'unsqueezed_masked_lm_positions'], ['gathernd1'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["add2", "unsqueezed_masked_lm_positions"], + ["gathernd1"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) -identity1 = helper.make_node('Identity', ['gathernd1'], ['output'], name="output") +identity1 = helper.make_node("Identity", ["gathernd1"], ["output"], name="output") nodes.append(identity1) -initializers=[layer_norm1_weight_initializer, layer_norm1_bias_initializer, matmul1_initializer, add1_initializer, - layer_norm2_weight_initializer, layer_norm2_bias_initializer, matmul2_initializer, add2_initializer, - gather_indice_initializer] +initializers = [ + layer_norm1_weight_initializer, + layer_norm1_bias_initializer, + matmul1_initializer, + add1_initializer, + layer_norm2_weight_initializer, + layer_norm2_bias_initializer, + matmul2_initializer, + add2_initializer, + gather_indice_initializer, +] # Create the graph (GraphProto) -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y, Gather_Y], initializers) +graph_def = helper.make_graph( + nodes, + "test-model", + [X, unsqueezed_masked_lm_positions], + [Y, Gather_Y], + initializers, +) opsets = [] onnxdomain = OperatorSetIdProto() @@ -96,6 +139,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "computation_reduction_transformer.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/e2e.py b/onnxruntime/test/testdata/transform/computation_reduction/e2e.py index 502afd524577e..9f7450bb6a6f3 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/e2e.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/e2e.py @@ -1,22 +1,27 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -vocab_size=256 #30258 +vocab_size = 256 # 30258 -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", vocab_size]) -Gather_Y = helper.make_tensor_value_info('gather_output', TensorProto.FLOAT, ["batch", 128]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", vocab_size]) +Gather_Y = helper.make_tensor_value_info("gather_output", TensorProto.FLOAT, ["batch", 128]) layer_norm1_weight_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm1_weight_initializer = numpy_helper.from_array(layer_norm1_weight_np_vals, "bert.encoder.layer.2.output.LayerNorm.weight") +layer_norm1_weight_initializer = numpy_helper.from_array( + layer_norm1_weight_np_vals, "bert.encoder.layer.2.output.LayerNorm.weight" +) layer_norm1_bias_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm1_bias_initializer = numpy_helper.from_array(layer_norm1_bias_np_vals, "bert.encoder.layer.2.output.LayerNorm.bias") +layer_norm1_bias_initializer = numpy_helper.from_array( + layer_norm1_bias_np_vals, "bert.encoder.layer.2.output.LayerNorm.bias" +) matmul1_np_vals = np.random.uniform(0.0, 1.0, (128, 128)).astype(np.float32).reshape((128, 128)) matmul1_initializer = numpy_helper.from_array(matmul1_np_vals, "matmul1_initializer") @@ -25,10 +30,14 @@ add1_initializer = numpy_helper.from_array(add1_np_vals, "add1_initializerr") layer_norm2_weight_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm2_weight_initializer = numpy_helper.from_array(layer_norm2_weight_np_vals, "cls.predictions.transform.LayerNorm.weight") +layer_norm2_weight_initializer = numpy_helper.from_array( + layer_norm2_weight_np_vals, "cls.predictions.transform.LayerNorm.weight" +) layer_norm2_bias_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm2_bias_initializer = numpy_helper.from_array(layer_norm2_bias_np_vals, "cls.predictions.transform.LayerNorm.bias") +layer_norm2_bias_initializer = numpy_helper.from_array( + layer_norm2_bias_np_vals, "cls.predictions.transform.LayerNorm.bias" +) matmul2_np_vals = np.random.uniform(0.0, 1.0, (128, vocab_size)).astype(np.float32).reshape((128, vocab_size)) matmul2_initializer = numpy_helper.from_array(matmul2_np_vals, "bert.embeddings.word_embeddings.weight_transposed") @@ -39,48 +48,82 @@ gather_indice_np_vals = np.asarray([0]).astype(np.int64).reshape(()) gather_indice_initializer = numpy_helper.from_array(gather_indice_np_vals, "gather_indice_initializer") -nodes=[] -layer_norm1 = helper.make_node('LayerNormalization', - ['input', layer_norm1_weight_initializer.name, layer_norm1_bias_initializer.name], - ['layer_norm1', 'saved_mean1', 'saved_inv_std_var1'], - name='layer_norm_1', epsilon=9.999999960041972e-13, axis=-1) +nodes = [] +layer_norm1 = helper.make_node( + "LayerNormalization", + ["input", layer_norm1_weight_initializer.name, layer_norm1_bias_initializer.name], + ["layer_norm1", "saved_mean1", "saved_inv_std_var1"], + name="layer_norm_1", + epsilon=9.999999960041972e-13, + axis=-1, +) nodes.append(layer_norm1) -gather1 = helper.make_node('Gather', ['layer_norm1', gather_indice_initializer.name], ['gather_output'], name="gather_output", axis=1) +gather1 = helper.make_node( + "Gather", + ["layer_norm1", gather_indice_initializer.name], + ["gather_output"], + name="gather_output", + axis=1, +) nodes.append(gather1) -matmul1 = helper.make_node('MatMul', ['layer_norm1', matmul1_initializer.name], ['matmul1'], name="matmul_1") +matmul1 = helper.make_node("MatMul", ["layer_norm1", matmul1_initializer.name], ["matmul1"], name="matmul_1") nodes.append(matmul1) -add1 = helper.make_node('Add', [add1_initializer.name, 'matmul1'], ['add1'], name="add_1") +add1 = helper.make_node("Add", [add1_initializer.name, "matmul1"], ["add1"], name="add_1") nodes.append(add1) -gelu1 = helper.make_node('Gelu', ['add1'], ['gelu1'], name='gelu_1', domain='com.microsoft') +gelu1 = helper.make_node("Gelu", ["add1"], ["gelu1"], name="gelu_1", domain="com.microsoft") nodes.append(gelu1) -layer_norm2 = helper.make_node('LayerNormalization', - ['gelu1', layer_norm2_weight_initializer.name, layer_norm2_bias_initializer.name], - ['layer_norm2', 'saved_mean2', 'saved_inv_std_var2'], - name='layer_norm_2', epsilon=9.999999960041972e-13, axis=-1) +layer_norm2 = helper.make_node( + "LayerNormalization", + ["gelu1", layer_norm2_weight_initializer.name, layer_norm2_bias_initializer.name], + ["layer_norm2", "saved_mean2", "saved_inv_std_var2"], + name="layer_norm_2", + epsilon=9.999999960041972e-13, + axis=-1, +) nodes.append(layer_norm2) -matmul2 = helper.make_node('MatMul', ['layer_norm2', matmul2_initializer.name], ['matmul2'], name="matmul_2") +matmul2 = helper.make_node("MatMul", ["layer_norm2", matmul2_initializer.name], ["matmul2"], name="matmul_2") nodes.append(matmul2) -add2 = helper.make_node('Add', ['matmul2', add2_initializer.name], ['add2'], name="add_2") +add2 = helper.make_node("Add", ["matmul2", add2_initializer.name], ["add2"], name="add_2") nodes.append(add2) -gathernd1 = helper.make_node('GatherND', ['add2', 'unsqueezed_masked_lm_positions'], ['gathernd1'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["add2", "unsqueezed_masked_lm_positions"], + ["gathernd1"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) -identity1 = helper.make_node('Identity', ['gathernd1'], ['output'], name="output") +identity1 = helper.make_node("Identity", ["gathernd1"], ["output"], name="output") nodes.append(identity1) -initializers=[layer_norm1_weight_initializer, layer_norm1_bias_initializer, matmul1_initializer, add1_initializer, - layer_norm2_weight_initializer, layer_norm2_bias_initializer, matmul2_initializer, add2_initializer, - gather_indice_initializer] +initializers = [ + layer_norm1_weight_initializer, + layer_norm1_bias_initializer, + matmul1_initializer, + add1_initializer, + layer_norm2_weight_initializer, + layer_norm2_bias_initializer, + matmul2_initializer, + add2_initializer, + gather_indice_initializer, +] # Create the graph (GraphProto) -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y, Gather_Y], initializers) +graph_def = helper.make_graph( + nodes, + "test-model", + [X, unsqueezed_masked_lm_positions], + [Y, Gather_Y], + initializers, +) opsets = [] onnxdomain = OperatorSetIdProto() @@ -96,6 +139,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "e2e.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py index f3e6190300cdb..0d32f081a97e6 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_add.py @@ -1,35 +1,54 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) -Y2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +Y2 = helper.make_tensor_value_info("output2", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) nodes = [] # case 1 bias_np_val = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) bias_initializer = numpy_helper.from_array(bias_np_val, "bias") -add1 = helper.make_node('Add', ['input', 'bias'], ['add_1'], name="add_1") +add1 = helper.make_node("Add", ["input", "bias"], ["add_1"], name="add_1") nodes.append(add1) -gathernd1 = helper.make_node('GatherND', ['add_1', 'unsqueezed_masked_lm_positions'], ['output'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["add_1", "unsqueezed_masked_lm_positions"], + ["output"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) # case 2 bias2_np_val = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) bias2_initializer = numpy_helper.from_array(bias2_np_val, "bias2") -add2 = helper.make_node('Add', ['bias2', 'input'], ['add_2'], name="add_2") +add2 = helper.make_node("Add", ["bias2", "input"], ["add_2"], name="add_2") nodes.append(add2) -gathernd2 = helper.make_node('GatherND', ['add_2', 'unsqueezed_masked_lm_positions'], ['output2'], name="gathernd_2", batch_dims=1) +gathernd2 = helper.make_node( + "GatherND", + ["add_2", "unsqueezed_masked_lm_positions"], + ["output2"], + name="gathernd_2", + batch_dims=1, +) nodes.append(gathernd2) -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y, Y2], [bias_initializer, bias2_initializer]) +graph_def = helper.make_graph( + nodes, + "test-model", + [X, unsqueezed_masked_lm_positions], + [Y, Y2], + [bias_initializer, bias2_initializer], +) opsets = [] onnxdomain = OperatorSetIdProto() @@ -45,6 +64,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "gathernd_add.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py index e08b017cad434..c5814dd4d81f2 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_div.py @@ -1,35 +1,54 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) -Y2 = helper.make_tensor_value_info('output2', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +Y2 = helper.make_tensor_value_info("output2", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) nodes = [] # case 1 divisor_np_val = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) divisor_initializer = numpy_helper.from_array(divisor_np_val, "divisor") -div1 = helper.make_node('Div', ['input', 'divisor'], ['div_1'], name="div_1") +div1 = helper.make_node("Div", ["input", "divisor"], ["div_1"], name="div_1") nodes.append(div1) -gathernd1 = helper.make_node('GatherND', ['div_1', 'unsqueezed_masked_lm_positions'], ['output'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["div_1", "unsqueezed_masked_lm_positions"], + ["output"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) # case 2 divisor2_np_val = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) divisor2_initializer = numpy_helper.from_array(divisor2_np_val, "divisor2") -div2 = helper.make_node('Div', ['divisor2', 'input'], ['div_2'], name="div_2") +div2 = helper.make_node("Div", ["divisor2", "input"], ["div_2"], name="div_2") nodes.append(div2) -gathernd2 = helper.make_node('GatherND', ['div_2', 'unsqueezed_masked_lm_positions'], ['output2'], name="gathernd_2", batch_dims=1) +gathernd2 = helper.make_node( + "GatherND", + ["div_2", "unsqueezed_masked_lm_positions"], + ["output2"], + name="gathernd_2", + batch_dims=1, +) nodes.append(gathernd2) -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y, Y2], [divisor_initializer, divisor2_initializer]) +graph_def = helper.make_graph( + nodes, + "test-model", + [X, unsqueezed_masked_lm_positions], + [Y, Y2], + [divisor_initializer, divisor2_initializer], +) opsets = [] onnxdomain = OperatorSetIdProto() @@ -45,6 +64,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "gathernd_div.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py index c286bec6e0dda..ea97a62886844 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_gelu.py @@ -1,23 +1,30 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np - -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper + +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) nodes = [] # case 1 -gelu1 = helper.make_node('Gelu', ['input'], ['gelu_1'], name="gelu_1", domain='com.microsoft') +gelu1 = helper.make_node("Gelu", ["input"], ["gelu_1"], name="gelu_1", domain="com.microsoft") nodes.append(gelu1) -gathernd1 = helper.make_node('GatherND', ['gelu_1', 'unsqueezed_masked_lm_positions'], ['output'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["gelu_1", "unsqueezed_masked_lm_positions"], + ["output"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y]) +graph_def = helper.make_graph(nodes, "test-model", [X, unsqueezed_masked_lm_positions], [Y]) opsets = [] onnxdomain = OperatorSetIdProto() @@ -33,6 +40,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "gathernd_gelu.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py index 8a645011474e9..eb63c769025fe 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_layernormalization.py @@ -1,31 +1,52 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) layer_norm1_weight_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm1_weight_initializer = numpy_helper.from_array(layer_norm1_weight_np_vals, "bert.encoder.layer.2.output.LayerNorm.weight") +layer_norm1_weight_initializer = numpy_helper.from_array( + layer_norm1_weight_np_vals, "bert.encoder.layer.2.output.LayerNorm.weight" +) layer_norm1_bias_np_vals = np.random.uniform(0.0, 1.0, (128)).astype(np.float32).reshape((128)) -layer_norm1_bias_initializer = numpy_helper.from_array(layer_norm1_bias_np_vals, "bert.encoder.layer.2.output.LayerNorm.bias") +layer_norm1_bias_initializer = numpy_helper.from_array( + layer_norm1_bias_np_vals, "bert.encoder.layer.2.output.LayerNorm.bias" +) -nodes=[] -layer_norm1 = helper.make_node('LayerNormalization', - ['input', layer_norm1_weight_initializer.name, layer_norm1_bias_initializer.name], - ['layer_norm1', 'saved_mean1', 'saved_inv_std_var1'], - name='layer_norm_1', epsilon=9.999999960041972e-13, axis=-1) +nodes = [] +layer_norm1 = helper.make_node( + "LayerNormalization", + ["input", layer_norm1_weight_initializer.name, layer_norm1_bias_initializer.name], + ["layer_norm1", "saved_mean1", "saved_inv_std_var1"], + name="layer_norm_1", + epsilon=9.999999960041972e-13, + axis=-1, +) nodes.append(layer_norm1) -gathernd1 = helper.make_node('GatherND', ['layer_norm1', 'unsqueezed_masked_lm_positions'], ['output'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["layer_norm1", "unsqueezed_masked_lm_positions"], + ["output"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y], [layer_norm1_weight_initializer, layer_norm1_bias_initializer]) +graph_def = helper.make_graph( + nodes, + "test-model", + [X, unsqueezed_masked_lm_positions], + [Y], + [layer_norm1_weight_initializer, layer_norm1_bias_initializer], +) opsets = [] onnxdomain = OperatorSetIdProto() @@ -41,6 +62,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "gathernd_layernormalization.onnx") diff --git a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py index 776871ec1111f..2b7ea7127dd0f 100755 --- a/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py +++ b/onnxruntime/test/testdata/transform/computation_reduction/gathernd_matmul.py @@ -1,27 +1,34 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 128]) -unsqueezed_masked_lm_positions = helper.make_tensor_value_info('unsqueezed_masked_lm_positions', - TensorProto.INT64, ["batch", "dynamic_prediction_count", 1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 128]) +unsqueezed_masked_lm_positions = helper.make_tensor_value_info( + "unsqueezed_masked_lm_positions", + TensorProto.INT64, + ["batch", "dynamic_prediction_count", 1], +) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "dynamic_prediction_count", 128]) matmul1_np_vals = np.random.uniform(0.0, 1.0, (128, 128)).astype(np.float32).reshape((128, 128)) matmul1_initializer = numpy_helper.from_array(matmul1_np_vals, "matmul1_initializer") -nodes=[] -matmul1 = helper.make_node('MatMul', ['input', matmul1_initializer.name], ['matmul1'], name="matmul_1") +nodes = [] +matmul1 = helper.make_node("MatMul", ["input", matmul1_initializer.name], ["matmul1"], name="matmul_1") nodes.append(matmul1) -gathernd1 = helper.make_node('GatherND', ['matmul1', 'unsqueezed_masked_lm_positions'], ['output'], name="gathernd_1", batch_dims=1) +gathernd1 = helper.make_node( + "GatherND", + ["matmul1", "unsqueezed_masked_lm_positions"], + ["output"], + name="gathernd_1", + batch_dims=1, +) nodes.append(gathernd1) -initializers=[matmul1_initializer] +initializers = [matmul1_initializer] -graph_def = helper.make_graph(nodes, 'test-model', [X, unsqueezed_masked_lm_positions], [Y], initializers) +graph_def = helper.make_graph(nodes, "test-model", [X, unsqueezed_masked_lm_positions], [Y], initializers) opsets = [] onnxdomain = OperatorSetIdProto() @@ -37,6 +44,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "gathernd_matmul.onnx") diff --git a/onnxruntime/test/testdata/transform/concat_graph_gen.py b/onnxruntime/test/testdata/transform/concat_graph_gen.py index 599eeec221e13..85817154e9014 100644 --- a/onnxruntime/test/testdata/transform/concat_graph_gen.py +++ b/onnxruntime/test/testdata/transform/concat_graph_gen.py @@ -1,55 +1,87 @@ -import onnx -from onnx import helper -from onnx import TensorProto import numpy as np +import onnx +from onnx import TensorProto, helper + def GenerateModel(model_name): - nodes = [ - helper.make_node("Gather", ["embed_weights","input_1"], ["gather_out"], "gather"), - - helper.make_node("Add", ["gather_out", "add_q_weight"], ["add_q_out"], "add_q"), - helper.make_node("Add", ["gather_out", "add_k_weight"], ["add_k_out"], "add_k"), - helper.make_node("Add", ["gather_out", "add_v_weight"], ["add_v_out"], "add_v"), - - helper.make_node("Concat", ["add_q_out", "add_k_out", "add_v_out"], - ["concat_out"], "concat", axis=0), - - helper.make_node("Add", ["add_qkv_weight", "concat_out"], ["add_out"], "add"), - helper.make_node("ReduceSum",["add_out"],["predictions"],"reduce_sum_1", axes=[0], keepdims=1), - ] - - embed_weights = np.random.uniform(-1,1,8000).tolist() - - add_q_weight = [-0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875, - -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375] - - add_k_weight = [0.0246734619140625, 0.011993408203125, 0.0178375244140625, 0.00998687744140625, - 0.0255126953125, 0.076416015625, -0.040771484375, 0.0107879638671875] - - add_v_weight = [-0.005893707275390625, -0.00916290283203125, 0.04541015625, 0.0159454345703125, - -0.0029163360595703125, -0.03472900390625, 0.0535888671875, 0.0091094970703125] - - initializers = [ # initializers - helper.make_tensor('embed_weights', TensorProto.FLOAT, [1000, 8], embed_weights), - helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], add_q_weight), - helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], add_k_weight), - helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], add_v_weight), - helper.make_tensor('add_qkv_weight', TensorProto.FLOAT, [1], [1.0]), - ] - - graph = helper.make_graph( - nodes, - "ConcatThreeInputs", #name - [ # inputs - helper.make_tensor_value_info('input_1', TensorProto.INT64, ['batch', 'seq_len']) - ], - [ # outputs - helper.make_tensor_value_info('predictions', TensorProto.FLOAT, [1,1,8]), - ], - initializers) - - model = helper.make_model(graph) - onnx.save(model, model_name) - -GenerateModel('concat_trainable.onnx') + nodes = [ + helper.make_node("Gather", ["embed_weights", "input_1"], ["gather_out"], "gather"), + helper.make_node("Add", ["gather_out", "add_q_weight"], ["add_q_out"], "add_q"), + helper.make_node("Add", ["gather_out", "add_k_weight"], ["add_k_out"], "add_k"), + helper.make_node("Add", ["gather_out", "add_v_weight"], ["add_v_out"], "add_v"), + helper.make_node( + "Concat", + ["add_q_out", "add_k_out", "add_v_out"], + ["concat_out"], + "concat", + axis=0, + ), + helper.make_node("Add", ["add_qkv_weight", "concat_out"], ["add_out"], "add"), + helper.make_node( + "ReduceSum", + ["add_out"], + ["predictions"], + "reduce_sum_1", + axes=[0], + keepdims=1, + ), + ] + + embed_weights = np.random.uniform(-1, 1, 8000).tolist() + + add_q_weight = [ + -0.23681640625, + -0.16552734375, + 0.2191162109375, + -0.1756591796875, + -0.03460693359375, + -0.05316162109375, + -0.336181640625, + -0.253662109375, + ] + + add_k_weight = [ + 0.0246734619140625, + 0.011993408203125, + 0.0178375244140625, + 0.00998687744140625, + 0.0255126953125, + 0.076416015625, + -0.040771484375, + 0.0107879638671875, + ] + + add_v_weight = [ + -0.005893707275390625, + -0.00916290283203125, + 0.04541015625, + 0.0159454345703125, + -0.0029163360595703125, + -0.03472900390625, + 0.0535888671875, + 0.0091094970703125, + ] + + initializers = [ # initializers + helper.make_tensor("embed_weights", TensorProto.FLOAT, [1000, 8], embed_weights), + helper.make_tensor("add_q_weight", TensorProto.FLOAT, [8], add_q_weight), + helper.make_tensor("add_k_weight", TensorProto.FLOAT, [8], add_k_weight), + helper.make_tensor("add_v_weight", TensorProto.FLOAT, [8], add_v_weight), + helper.make_tensor("add_qkv_weight", TensorProto.FLOAT, [1], [1.0]), + ] + + graph = helper.make_graph( + nodes, + "ConcatThreeInputs", # name + [helper.make_tensor_value_info("input_1", TensorProto.INT64, ["batch", "seq_len"])], # inputs + [ # outputs + helper.make_tensor_value_info("predictions", TensorProto.FLOAT, [1, 1, 8]), + ], + initializers, + ) + + model = helper.make_model(graph) + onnx.save(model, model_name) + +GenerateModel("concat_trainable.onnx") diff --git a/onnxruntime/test/testdata/transform/concat_slice_elimination.py b/onnxruntime/test/testdata/transform/concat_slice_elimination.py index c09c8bfddc651..88a1236922a19 100644 --- a/onnxruntime/test/testdata/transform/concat_slice_elimination.py +++ b/onnxruntime/test/testdata/transform/concat_slice_elimination.py @@ -1,117 +1,187 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper -import numpy as np import random +import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper + batch = 3 hidden_size = 4 attention_head = 2 hidden_per_attention = 2 -relative_attention_num_buckets=32 -input_len=8 -output_len=8 +relative_attention_num_buckets = 32 +input_len = 8 +output_len = 8 -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [batch, input_len, hidden_size]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [batch, output_len, hidden_size]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, input_len, hidden_size]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [batch, output_len, hidden_size]) q_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, 'encoder.layer.0.SelfAttention.q.weight') +q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, "encoder.layer.0.SelfAttention.q.weight") k_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, 'encoder.layer.0.SelfAttention.k.weight') +k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, "encoder.layer.0.SelfAttention.k.weight") v_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, 'encoder.layer.0.SelfAttention.v.weight') +v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, "encoder.layer.0.SelfAttention.v.weight") -q_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, 'encoder.layer.0.SelfAttention.q.bias') +q_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, "encoder.layer.0.SelfAttention.q.bias") -k_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, 'encoder.layer.0.SelfAttention.k.bias') +k_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, "encoder.layer.0.SelfAttention.k.bias") -v_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, 'encoder.layer.0.SelfAttention.v.bias') +v_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, "encoder.layer.0.SelfAttention.v.bias") -q_starts_initializer = numpy_helper.from_array(np.asarray([0], dtype=np.int64), 'q_starts') -k_starts_initializer = numpy_helper.from_array(np.asarray([hidden_size], dtype=np.int64), 'k_starts') -v_starts_initializer = numpy_helper.from_array(np.asarray([2*hidden_size], dtype=np.int64), 'v_starts') +q_starts_initializer = numpy_helper.from_array(np.asarray([0], dtype=np.int64), "q_starts") +k_starts_initializer = numpy_helper.from_array(np.asarray([hidden_size], dtype=np.int64), "k_starts") +v_starts_initializer = numpy_helper.from_array(np.asarray([2 * hidden_size], dtype=np.int64), "v_starts") -q_ends_initializer = numpy_helper.from_array(np.asarray([hidden_size], dtype=np.int64), 'q_ends') -k_ends_initializer = numpy_helper.from_array(np.asarray([2*hidden_size], dtype=np.int64), 'k_ends') -v_ends_initializer = numpy_helper.from_array(np.asarray([9223372036854775807], dtype=np.int64), 'v_ends') +q_ends_initializer = numpy_helper.from_array(np.asarray([hidden_size], dtype=np.int64), "q_ends") +k_ends_initializer = numpy_helper.from_array(np.asarray([2 * hidden_size], dtype=np.int64), "k_ends") +v_ends_initializer = numpy_helper.from_array(np.asarray([9223372036854775807], dtype=np.int64), "v_ends") -slice_axes_initializer = numpy_helper.from_array(np.asarray([0], dtype=np.int64), 'slice_axes') -slice_steps_initializer = numpy_helper.from_array(np.asarray([1], dtype=np.int64), 'slice_steps') +slice_axes_initializer = numpy_helper.from_array(np.asarray([0], dtype=np.int64), "slice_axes") +slice_steps_initializer = numpy_helper.from_array(np.asarray([1], dtype=np.int64), "slice_steps") -transpose_q = helper.make_node('Transpose', [q_weight_initializer.name], ['transpose_q'], name='transpose_q', perm=[1,0]) -transpose_k = helper.make_node('Transpose', [k_weight_initializer.name], ['transpose_k'], name='transpose_k', perm=[1,0]) -transpose_v = helper.make_node('Transpose', [v_weight_initializer.name], ['transpose_v'], name='transpose_v', perm=[1,0]) +transpose_q = helper.make_node( + "Transpose", + [q_weight_initializer.name], + ["transpose_q"], + name="transpose_q", + perm=[1, 0], +) +transpose_k = helper.make_node( + "Transpose", + [k_weight_initializer.name], + ["transpose_k"], + name="transpose_k", + perm=[1, 0], +) +transpose_v = helper.make_node( + "Transpose", + [v_weight_initializer.name], + ["transpose_v"], + name="transpose_v", + perm=[1, 0], +) -matmul_q = helper.make_node('MatMul', ['input', 'transpose_q'], ['matmul_q'], name='matmul_q') -matmul_k = helper.make_node('MatMul', ['input', 'transpose_k'], ['matmul_k'], name='matmul_k') -matmul_v = helper.make_node('MatMul', ['input', 'transpose_v'], ['matmul_v'], name='matmul_v') +matmul_q = helper.make_node("MatMul", ["input", "transpose_q"], ["matmul_q"], name="matmul_q") +matmul_k = helper.make_node("MatMul", ["input", "transpose_k"], ["matmul_k"], name="matmul_k") +matmul_v = helper.make_node("MatMul", ["input", "transpose_v"], ["matmul_v"], name="matmul_v") -concat_bias = helper.make_node("Concat", - [q_bias_initializer.name,k_bias_initializer.name,v_bias_initializer.name], - ['concat_bias'], - axis=0, - name='concat_bias') +concat_bias = helper.make_node( + "Concat", + [q_bias_initializer.name, k_bias_initializer.name, v_bias_initializer.name], + ["concat_bias"], + axis=0, + name="concat_bias", +) -slice_q = helper.make_node('Slice', - ['concat_bias', q_starts_initializer.name, q_ends_initializer.name, - slice_axes_initializer.name,slice_steps_initializer.name], - ['slice_q'], name = 'slice_q' ) +slice_q = helper.make_node( + "Slice", + [ + "concat_bias", + q_starts_initializer.name, + q_ends_initializer.name, + slice_axes_initializer.name, + slice_steps_initializer.name, + ], + ["slice_q"], + name="slice_q", +) -slice_k = helper.make_node('Slice', - ['concat_bias', k_starts_initializer.name, k_ends_initializer.name, - slice_axes_initializer.name,slice_steps_initializer.name], - ['slice_k'], name = 'slice_k' ) +slice_k = helper.make_node( + "Slice", + [ + "concat_bias", + k_starts_initializer.name, + k_ends_initializer.name, + slice_axes_initializer.name, + slice_steps_initializer.name, + ], + ["slice_k"], + name="slice_k", +) -slice_v = helper.make_node('Slice', - ['concat_bias', v_starts_initializer.name, v_ends_initializer.name, - slice_axes_initializer.name,slice_steps_initializer.name], - ['slice_v'], name = 'slice_v' ) +slice_v = helper.make_node( + "Slice", + [ + "concat_bias", + v_starts_initializer.name, + v_ends_initializer.name, + slice_axes_initializer.name, + slice_steps_initializer.name, + ], + ["slice_v"], + name="slice_v", +) -add_q = helper.make_node('Add', ['matmul_q', 'slice_q'], ['add_q'], name='add_q') -add_k = helper.make_node('Add', ['matmul_k', 'slice_k'], ['add_k'], name='add_k') -add_v = helper.make_node('Add', ['matmul_v', 'slice_v'], ['add_v'], name='add_v') +add_q = helper.make_node("Add", ["matmul_q", "slice_q"], ["add_q"], name="add_q") +add_k = helper.make_node("Add", ["matmul_k", "slice_k"], ["add_k"], name="add_k") +add_v = helper.make_node("Add", ["matmul_v", "slice_v"], ["add_v"], name="add_v") -add_1 = helper.make_node('Add', ['add_q','add_k'],['add_1'], name='add_1') -add_2 = helper.make_node('Add', ['add_1','add_v'],['add_2'], name='add_2') -identity = helper.make_node('Identity', ['add_2'], ['output'], name='identity') +add_1 = helper.make_node("Add", ["add_q", "add_k"], ["add_1"], name="add_1") +add_2 = helper.make_node("Add", ["add_1", "add_v"], ["add_2"], name="add_2") +identity = helper.make_node("Identity", ["add_2"], ["output"], name="identity") # Create the graph (GraphProto) graph_def = helper.make_graph( - [transpose_q,transpose_k,transpose_v,matmul_q,matmul_k,matmul_v,add_q,add_k,add_v, - concat_bias,slice_q, slice_k, slice_v, add_1, add_2, identity], - 'concat-slice-test-model', + [ + transpose_q, + transpose_k, + transpose_v, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + concat_bias, + slice_q, + slice_k, + slice_v, + add_1, + add_2, + identity, + ], + "concat-slice-test-model", [X], [Y], - [q_weight_initializer,k_weight_initializer,v_weight_initializer,q_bias_initializer,k_bias_initializer, - v_bias_initializer,q_starts_initializer,k_starts_initializer,v_starts_initializer,q_ends_initializer, - k_ends_initializer,v_ends_initializer,slice_axes_initializer,slice_steps_initializer] + [ + q_weight_initializer, + k_weight_initializer, + v_weight_initializer, + q_bias_initializer, + k_bias_initializer, + v_bias_initializer, + q_starts_initializer, + k_starts_initializer, + v_starts_initializer, + q_ends_initializer, + k_ends_initializer, + v_ends_initializer, + slice_axes_initializer, + slice_steps_initializer, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'concat_slice_basic_test.onnx') - - +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "concat_slice_basic_test.onnx") diff --git a/onnxruntime/test/testdata/transform/cse/generate.py b/onnxruntime/test/testdata/transform/cse/generate.py index 7a676f13df780..1cd1b54b09a53 100644 --- a/onnxruntime/test/testdata/transform/cse/generate.py +++ b/onnxruntime/test/testdata/transform/cse/generate.py @@ -1,233 +1,362 @@ -import onnx -from onnx import helper, shape_inference -from onnx import AttributeProto, TensorProto, GraphProto import os +import onnx +from onnx import AttributeProto, GraphProto, TensorProto, helper, shape_inference + _this_dir = os.path.abspath(os.path.dirname(__file__)) + def _onnx_export(graph_def, relative_path, verbose=False): - model = helper.make_model(graph_def, producer_name='makalini', opset_imports=[helper.make_operatorsetid("", 11)]) - onnx.checker.check_model(model) - inferred_model = shape_inference.infer_shapes(model) - onnx.checker.check_model(inferred_model) - model_path = os.path.join(_this_dir, relative_path) - os.makedirs(os.path.dirname(model_path), exist_ok=True) - onnx.save_model(model, model_path) - if verbose: - print() - print(inferred_model) - import onnxruntime as rt - rt.InferenceSession(model_path) + model = helper.make_model( + graph_def, + producer_name="makalini", + opset_imports=[helper.make_operatorsetid("", 11)], + ) + onnx.checker.check_model(model) + inferred_model = shape_inference.infer_shapes(model) + onnx.checker.check_model(inferred_model) + model_path = os.path.join(_this_dir, relative_path) + os.makedirs(os.path.dirname(model_path), exist_ok=True) + onnx.save_model(model, model_path) + if verbose: + print() + print(inferred_model) + import onnxruntime as rt + + rt.InferenceSession(model_path) + def cse1(): - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "MatMul", inputs = ['w', 'x'], outputs = ['MatMul1'], name = 'matmul_1'), - helper.make_node(op_type = "MatMul", inputs = ['w', 'x'], outputs = ['MatMul2'], name = 'matmul_2'), - helper.make_node(op_type = "Add", inputs = ['MatMul1', 'b'], outputs = ['Add1'], name = 'add_1'), - helper.make_node(op_type = "Add", inputs = ['MatMul2', 'b'], outputs = ['Add2'], name = 'add_2'), - helper.make_node(op_type = "Relu", inputs = ['Add1'], outputs = ['Relu1'], name = 'relu_1'), - helper.make_node(op_type = "Relu", inputs = ['Add2'], outputs = ['Relu2'], name = 'relu_2'), - helper.make_node(op_type = "Add", inputs = ['Relu1', 'Relu2'], outputs = ['Result'], name = 'result') - ], - name = 'cse1', - inputs = [ - helper.make_tensor_value_info("x", TensorProto.FLOAT, [5]) - ], - outputs = [ - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [2]) - ], - initializer = [ - helper.make_tensor('w', TensorProto.FLOAT, [2, 5], list(range(2*5))), - helper.make_tensor('b', TensorProto.FLOAT, [2], list(range(2))), - ] - ) - _onnx_export(graph_def, 'cse1.onnx') + graph_def = helper.make_graph( + nodes=[ + helper.make_node( + op_type="MatMul", + inputs=["w", "x"], + outputs=["MatMul1"], + name="matmul_1", + ), + helper.make_node( + op_type="MatMul", + inputs=["w", "x"], + outputs=["MatMul2"], + name="matmul_2", + ), + helper.make_node(op_type="Add", inputs=["MatMul1", "b"], outputs=["Add1"], name="add_1"), + helper.make_node(op_type="Add", inputs=["MatMul2", "b"], outputs=["Add2"], name="add_2"), + helper.make_node(op_type="Relu", inputs=["Add1"], outputs=["Relu1"], name="relu_1"), + helper.make_node(op_type="Relu", inputs=["Add2"], outputs=["Relu2"], name="relu_2"), + helper.make_node( + op_type="Add", + inputs=["Relu1", "Relu2"], + outputs=["Result"], + name="result", + ), + ], + name="cse1", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])], + outputs=[helper.make_tensor_value_info("Result", TensorProto.FLOAT, [2])], + initializer=[ + helper.make_tensor("w", TensorProto.FLOAT, [2, 5], list(range(2 * 5))), + helper.make_tensor("b", TensorProto.FLOAT, [2], list(range(2))), + ], + ) + _onnx_export(graph_def, "cse1.onnx") + def cse_graph_output(): - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "Add", inputs = ['x', 'b'], outputs = ['res1'], name = 'add_1'), - helper.make_node(op_type = "Add", inputs = ['x', 'b'], outputs = ['res2'], name = 'add_2'), - ], - name = 'cse_graph_output', - inputs = [ - helper.make_tensor_value_info("x", TensorProto.FLOAT, [5]) - ], - outputs = [ - helper.make_tensor_value_info('res1', TensorProto.FLOAT, [5]), - helper.make_tensor_value_info('res2', TensorProto.FLOAT, [5]) - ], - initializer = [ - helper.make_tensor('b', TensorProto.FLOAT, [5], list(range(5))), - ] - ) - - _onnx_export(graph_def, 'cse_graph_output.onnx') + graph_def = helper.make_graph( + nodes=[ + helper.make_node(op_type="Add", inputs=["x", "b"], outputs=["res1"], name="add_1"), + helper.make_node(op_type="Add", inputs=["x", "b"], outputs=["res2"], name="add_2"), + ], + name="cse_graph_output", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [5])], + outputs=[ + helper.make_tensor_value_info("res1", TensorProto.FLOAT, [5]), + helper.make_tensor_value_info("res2", TensorProto.FLOAT, [5]), + ], + initializer=[ + helper.make_tensor("b", TensorProto.FLOAT, [5], list(range(5))), + ], + ) + + _onnx_export(graph_def, "cse_graph_output.onnx") + def cse_optional_args(): - n = 5 - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "Clip", inputs = ['x'], outputs = ['Clipped0'], name = 'clip_0'), - helper.make_node(op_type = "Clip", inputs = ['x', ''], outputs = ['Clipped1'], name = 'clip_1'), - helper.make_node(op_type = "Clip", inputs = ['x', '', ''], outputs = ['Clipped2'], name = 'clip_2'), - helper.make_node(op_type = "Clip", inputs = ['x', '', 'c'], outputs = ['Clipped3'], name = 'clip_3'), - helper.make_node(op_type = "Clip", inputs = ['x', 'c', ''], outputs = ['Clipped4'], name = 'clip_4'), - helper.make_node(op_type = "Sum", inputs = ['Clipped0', 'Clipped1', 'Clipped2', 'Clipped3', 'Clipped4'], outputs = ['Result'], name = 'sum_1') - ], - name = 'cse_optional_args', - inputs = [ - helper.make_tensor_value_info("x", TensorProto.FLOAT, [n]) - ], - outputs = [ - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [n]) - ], - initializer = [ - helper.make_tensor('c', TensorProto.FLOAT, [], [0.5]) - ] - ) - _onnx_export(graph_def, 'cse_optional_args.onnx') + n = 5 + graph_def = helper.make_graph( + nodes=[ + helper.make_node(op_type="Clip", inputs=["x"], outputs=["Clipped0"], name="clip_0"), + helper.make_node(op_type="Clip", inputs=["x", ""], outputs=["Clipped1"], name="clip_1"), + helper.make_node( + op_type="Clip", + inputs=["x", "", ""], + outputs=["Clipped2"], + name="clip_2", + ), + helper.make_node( + op_type="Clip", + inputs=["x", "", "c"], + outputs=["Clipped3"], + name="clip_3", + ), + helper.make_node( + op_type="Clip", + inputs=["x", "c", ""], + outputs=["Clipped4"], + name="clip_4", + ), + helper.make_node( + op_type="Sum", + inputs=["Clipped0", "Clipped1", "Clipped2", "Clipped3", "Clipped4"], + outputs=["Result"], + name="sum_1", + ), + ], + name="cse_optional_args", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [n])], + outputs=[helper.make_tensor_value_info("Result", TensorProto.FLOAT, [n])], + initializer=[helper.make_tensor("c", TensorProto.FLOAT, [], [0.5])], + ) + _onnx_export(graph_def, "cse_optional_args.onnx") + def cse_subgraph(): - if_true_graph = helper.make_graph( - nodes = [ - helper.make_node(op_type = "Sum", inputs = ['x', 'x'], outputs = ['Result1'], name = 'iftrue_res_1'), - helper.make_node(op_type = "Sum", inputs = ['x', 'x'], outputs = ['Result2'], name = 'iftrue_res_2'), - helper.make_node(op_type = "Mul", inputs = ['x', 'x'], outputs = ['Intermediate1'], name = 'iftrue_intermediate_1'), - helper.make_node(op_type = "Mul", inputs = ['x', 'x'], outputs = ['Intermediate2'], name = 'iftrue_intermediate_2'), - helper.make_node(op_type = "Sum", inputs = ['Intermediate1', 'Intermediate2'], outputs = ['Result3'], name = 'iftrue_res_3'), - ], - name = 'if_true_graph', - inputs = [ - ], - outputs = [ - helper.make_tensor_value_info('Result1', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('Result2', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('Result3', TensorProto.FLOAT, [2]), - ], - initializer = [ - ] - ) - - if_false_graph = helper.make_graph( - nodes = [ - helper.make_node(op_type = "Mul", inputs = ['x', 'x'], outputs = ['Result1'], name = 'iffalse_res_1'), - helper.make_node(op_type = "Mul", inputs = ['x', 'x'], outputs = ['Result2'], name = 'iffalse_res_2'), - helper.make_node(op_type = "Sum", inputs = ['x', 'x'], outputs = ['Intermediate1'], name = 'iffalse_intermediate_1'), - helper.make_node(op_type = "Sum", inputs = ['x', 'x'], outputs = ['Intermediate2'], name = 'iffalse_intermediate_2'), - helper.make_node(op_type = "Mul", inputs = ['Intermediate1', 'Intermediate2'], outputs = ['Result3'], name = 'iffalse_res_3'), - ], - name = 'if_false_graph', - inputs = [ - ], - outputs = [ - helper.make_tensor_value_info('Result1', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('Result2', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('Result3', TensorProto.FLOAT, [2]), - ], - initializer = [ - ] - ) - - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "If", inputs = ['b'], outputs = ['Result1', 'Result2', 'Result3'], name = 'if_0', then_branch=if_true_graph, else_branch=if_false_graph), - ], - name = 'cse_subgraph', - inputs = [ - helper.make_tensor_value_info("b", TensorProto.BOOL, [1]), - helper.make_tensor_value_info("x", TensorProto.FLOAT, [2]) - ], - outputs = [ - helper.make_tensor_value_info('Result1', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('Result2', TensorProto.FLOAT, [2]), - helper.make_tensor_value_info('Result3', TensorProto.FLOAT, [2]), - ], - initializer = [ - ] - ) - _onnx_export(graph_def, 'cse_subgraph.onnx') + if_true_graph = helper.make_graph( + nodes=[ + helper.make_node( + op_type="Sum", + inputs=["x", "x"], + outputs=["Result1"], + name="iftrue_res_1", + ), + helper.make_node( + op_type="Sum", + inputs=["x", "x"], + outputs=["Result2"], + name="iftrue_res_2", + ), + helper.make_node( + op_type="Mul", + inputs=["x", "x"], + outputs=["Intermediate1"], + name="iftrue_intermediate_1", + ), + helper.make_node( + op_type="Mul", + inputs=["x", "x"], + outputs=["Intermediate2"], + name="iftrue_intermediate_2", + ), + helper.make_node( + op_type="Sum", + inputs=["Intermediate1", "Intermediate2"], + outputs=["Result3"], + name="iftrue_res_3", + ), + ], + name="if_true_graph", + inputs=[], + outputs=[ + helper.make_tensor_value_info("Result1", TensorProto.FLOAT, [2]), + helper.make_tensor_value_info("Result2", TensorProto.FLOAT, [2]), + helper.make_tensor_value_info("Result3", TensorProto.FLOAT, [2]), + ], + initializer=[], + ) + + if_false_graph = helper.make_graph( + nodes=[ + helper.make_node( + op_type="Mul", + inputs=["x", "x"], + outputs=["Result1"], + name="iffalse_res_1", + ), + helper.make_node( + op_type="Mul", + inputs=["x", "x"], + outputs=["Result2"], + name="iffalse_res_2", + ), + helper.make_node( + op_type="Sum", + inputs=["x", "x"], + outputs=["Intermediate1"], + name="iffalse_intermediate_1", + ), + helper.make_node( + op_type="Sum", + inputs=["x", "x"], + outputs=["Intermediate2"], + name="iffalse_intermediate_2", + ), + helper.make_node( + op_type="Mul", + inputs=["Intermediate1", "Intermediate2"], + outputs=["Result3"], + name="iffalse_res_3", + ), + ], + name="if_false_graph", + inputs=[], + outputs=[ + helper.make_tensor_value_info("Result1", TensorProto.FLOAT, [2]), + helper.make_tensor_value_info("Result2", TensorProto.FLOAT, [2]), + helper.make_tensor_value_info("Result3", TensorProto.FLOAT, [2]), + ], + initializer=[], + ) + + graph_def = helper.make_graph( + nodes=[ + helper.make_node( + op_type="If", + inputs=["b"], + outputs=["Result1", "Result2", "Result3"], + name="if_0", + then_branch=if_true_graph, + else_branch=if_false_graph, + ), + ], + name="cse_subgraph", + inputs=[ + helper.make_tensor_value_info("b", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("x", TensorProto.FLOAT, [2]), + ], + outputs=[ + helper.make_tensor_value_info("Result1", TensorProto.FLOAT, [2]), + helper.make_tensor_value_info("Result2", TensorProto.FLOAT, [2]), + helper.make_tensor_value_info("Result3", TensorProto.FLOAT, [2]), + ], + initializer=[], + ) + _onnx_export(graph_def, "cse_subgraph.onnx") + def cse_random(): - n = 5 - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "RandomUniform", inputs = [], outputs = ['Random1'], name = 'random_uniform_1', shape=[n]), - helper.make_node(op_type = "RandomUniform", inputs = [], outputs = ['Random2'], name = 'random_uniform_2', shape=[n]), - helper.make_node(op_type = "RandomUniform", inputs = [], outputs = ['Random3'], name = 'random_uniform_3', shape=[n], seed=1.0), - helper.make_node(op_type = "RandomUniform", inputs = [], outputs = ['Random4'], name = 'random_uniform_4', shape=[n], seed=1.0), - helper.make_node(op_type = "Sum", inputs = ['x', 'Random1', 'Random2', 'Random3', 'Random4'], outputs = ['Result'], name = 'sum_1') - ], - name = 'cse_random', - inputs = [ - helper.make_tensor_value_info("x", TensorProto.FLOAT, [n]) - ], - outputs = [ - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [n]) - ], - initializer = [ - ] - ) - _onnx_export(graph_def, 'cse_random.onnx') + n = 5 + graph_def = helper.make_graph( + nodes=[ + helper.make_node( + op_type="RandomUniform", + inputs=[], + outputs=["Random1"], + name="random_uniform_1", + shape=[n], + ), + helper.make_node( + op_type="RandomUniform", + inputs=[], + outputs=["Random2"], + name="random_uniform_2", + shape=[n], + ), + helper.make_node( + op_type="RandomUniform", + inputs=[], + outputs=["Random3"], + name="random_uniform_3", + shape=[n], + seed=1.0, + ), + helper.make_node( + op_type="RandomUniform", + inputs=[], + outputs=["Random4"], + name="random_uniform_4", + shape=[n], + seed=1.0, + ), + helper.make_node( + op_type="Sum", + inputs=["x", "Random1", "Random2", "Random3", "Random4"], + outputs=["Result"], + name="sum_1", + ), + ], + name="cse_random", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [n])], + outputs=[helper.make_tensor_value_info("Result", TensorProto.FLOAT, [n])], + initializer=[], + ) + _onnx_export(graph_def, "cse_random.onnx") + def cse_merge_constants(): - n = 3 - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "Add", inputs = ['c', 'c'], outputs = ['Add1'], name = 'add_1'), - helper.make_node(op_type = "Add", inputs = ['c', 'c'], outputs = ['Add2'], name = 'add_2'), - helper.make_node(op_type = "Add", inputs = ['Add1', 'x'], outputs = ['Add3'], name = 'add_3'), - helper.make_node(op_type = "Add", inputs = ['Add2', 'x'], outputs = ['Add4'], name = 'add_4'), - helper.make_node(op_type = "Add", inputs = ['Add3', 'Add4'], outputs = ['Result'], name = 'add_5'), - ], - name = 'cse_merge_constants', - inputs = [ - helper.make_tensor_value_info("x", TensorProto.FLOAT, [n]) - ], - outputs = [ - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [n]) - ], - initializer = [ - helper.make_tensor('c', TensorProto.FLOAT, [n], list(range(n))) - ] - ) - _onnx_export(graph_def, 'cse_merge_constants.onnx') + n = 3 + graph_def = helper.make_graph( + nodes=[ + helper.make_node(op_type="Add", inputs=["c", "c"], outputs=["Add1"], name="add_1"), + helper.make_node(op_type="Add", inputs=["c", "c"], outputs=["Add2"], name="add_2"), + helper.make_node(op_type="Add", inputs=["Add1", "x"], outputs=["Add3"], name="add_3"), + helper.make_node(op_type="Add", inputs=["Add2", "x"], outputs=["Add4"], name="add_4"), + helper.make_node(op_type="Add", inputs=["Add3", "Add4"], outputs=["Result"], name="add_5"), + ], + name="cse_merge_constants", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [n])], + outputs=[helper.make_tensor_value_info("Result", TensorProto.FLOAT, [n])], + initializer=[helper.make_tensor("c", TensorProto.FLOAT, [n], list(range(n)))], + ) + _onnx_export(graph_def, "cse_merge_constants.onnx") + def cse_only_one_graph_output(): - graph_def = helper.make_graph( - nodes = [ - helper.make_node(op_type = "Split", inputs = ['x'], outputs = ['Split1Output1', 'Split1Output2'], name = 'split_1'), - helper.make_node(op_type = "Split", inputs = ['x'], outputs = ['Split2Output1', 'Split2Output2'], name = 'split_2'), - helper.make_node(op_type = "ReduceSum", inputs = ['Split1Output1'], outputs = ['ReduceSum1'], name = 'reducesum_1'), - helper.make_node(op_type = "ReduceSum", inputs = ['Split2Output1'], outputs = ['ReduceSum2'], name = 'reducesum_2'), - helper.make_node(op_type = "Add", inputs = ['ReduceSum1', 'ReduceSum2'], outputs = ['Add'], name = 'add_1'), - ], - name = 'cse_only_one_graph_output', - inputs = [ - helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 4]) - ], - outputs = [ - helper.make_tensor_value_info('Split1Output2', TensorProto.FLOAT, [2, 4]), - helper.make_tensor_value_info('Split2Output2', TensorProto.FLOAT, [2, 4]), - helper.make_tensor_value_info('Add', TensorProto.FLOAT, [1, 1]), - ], - initializer = [ - ] - ) - _onnx_export(graph_def, 'cse_only_one_graph_output.onnx') + graph_def = helper.make_graph( + nodes=[ + helper.make_node( + op_type="Split", + inputs=["x"], + outputs=["Split1Output1", "Split1Output2"], + name="split_1", + ), + helper.make_node( + op_type="Split", + inputs=["x"], + outputs=["Split2Output1", "Split2Output2"], + name="split_2", + ), + helper.make_node( + op_type="ReduceSum", + inputs=["Split1Output1"], + outputs=["ReduceSum1"], + name="reducesum_1", + ), + helper.make_node( + op_type="ReduceSum", + inputs=["Split2Output1"], + outputs=["ReduceSum2"], + name="reducesum_2", + ), + helper.make_node( + op_type="Add", + inputs=["ReduceSum1", "ReduceSum2"], + outputs=["Add"], + name="add_1", + ), + ], + name="cse_only_one_graph_output", + inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4, 4])], + outputs=[ + helper.make_tensor_value_info("Split1Output2", TensorProto.FLOAT, [2, 4]), + helper.make_tensor_value_info("Split2Output2", TensorProto.FLOAT, [2, 4]), + helper.make_tensor_value_info("Add", TensorProto.FLOAT, [1, 1]), + ], + initializer=[], + ) + _onnx_export(graph_def, "cse_only_one_graph_output.onnx") def generate_all(): - cse1() - cse_graph_output() - cse_optional_args() - cse_subgraph() - cse_random() - cse_merge_constants() - cse_only_one_graph_output() - -if __name__ == '__main__': - generate_all() + cse1() + cse_graph_output() + cse_optional_args() + cse_subgraph() + cse_random() + cse_merge_constants() + cse_only_one_graph_output() +if __name__ == "__main__": + generate_all() diff --git a/onnxruntime/test/testdata/transform/dropout_zeroratio_elimination.py b/onnxruntime/test/testdata/transform/dropout_zeroratio_elimination.py index 3c3d22a235f42..9518b78e61832 100644 --- a/onnxruntime/test/testdata/transform/dropout_zeroratio_elimination.py +++ b/onnxruntime/test/testdata/transform/dropout_zeroratio_elimination.py @@ -1,60 +1,60 @@ import onnx -from onnx import helper -from onnx import TensorProto, OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper # inputs/outputs -X = helper.make_tensor_value_info('X', TensorProto.FLOAT, [2, 1]) -O1 = helper.make_tensor_value_info('O1', TensorProto.FLOAT, [2, 1]) -O2 = helper.make_tensor_value_info('O2', TensorProto.FLOAT, [2, 1]) -O3 = helper.make_tensor_value_info('O3', TensorProto.FLOAT, [2, 1]) -O4 = helper.make_tensor_value_info('O4', TensorProto.FLOAT, [2, 1]) -O5 = helper.make_tensor_value_info('O5', TensorProto.FLOAT, [2, 1]) +X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 1]) +O1 = helper.make_tensor_value_info("O1", TensorProto.FLOAT, [2, 1]) +O2 = helper.make_tensor_value_info("O2", TensorProto.FLOAT, [2, 1]) +O3 = helper.make_tensor_value_info("O3", TensorProto.FLOAT, [2, 1]) +O4 = helper.make_tensor_value_info("O4", TensorProto.FLOAT, [2, 1]) +O5 = helper.make_tensor_value_info("O5", TensorProto.FLOAT, [2, 1]) -X2 = helper.make_tensor_value_info('X2', TensorProto.FLOAT, []) +X2 = helper.make_tensor_value_info("X2", TensorProto.FLOAT, []) # initializers -zeroratio_float = helper.make_tensor('ratio_zero_float', TensorProto.FLOAT, [], [0.0]) -zeroratio_double = helper.make_tensor('ratio_zero_double', TensorProto.DOUBLE, [], [0.0]) -zeroratio_float16 = helper.make_tensor('ratio_zero_float16', TensorProto.FLOAT16, [], [0]) -nonzeroratio = helper.make_tensor('ratio_nonzero', TensorProto.FLOAT, [], [0.1]) -training_mode = helper.make_tensor('training_mode', TensorProto.BOOL, [], [1]) +zeroratio_float = helper.make_tensor("ratio_zero_float", TensorProto.FLOAT, [], [0.0]) +zeroratio_double = helper.make_tensor("ratio_zero_double", TensorProto.DOUBLE, [], [0.0]) +zeroratio_float16 = helper.make_tensor("ratio_zero_float16", TensorProto.FLOAT16, [], [0]) +nonzeroratio = helper.make_tensor("ratio_nonzero", TensorProto.FLOAT, [], [0.1]) +training_mode = helper.make_tensor("training_mode", TensorProto.BOOL, [], [1]) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -I1 = helper.make_node('Identity', ['X'], ['I1_out'], name='I1') +I1 = helper.make_node("Identity", ["X"], ["I1_out"], name="I1") D1 = helper.make_node("Dropout", ["I1_out", "ratio_zero_float", "training_mode"], ["D1_out"], "D1") -I2 = helper.make_node('Identity', ['D1_out'], ['O1'], name='I2') +I2 = helper.make_node("Identity", ["D1_out"], ["O1"], name="I2") -I3 = helper.make_node('Identity', ['X'], ['I3_out'], name='I3') +I3 = helper.make_node("Identity", ["X"], ["I3_out"], name="I3") D2 = helper.make_node("Dropout", ["I3_out", "ratio_nonzero", "training_mode"], ["D2_out"], "D2") -I4 = helper.make_node('Identity', ['D2_out'], ['O2'], name='I4') +I4 = helper.make_node("Identity", ["D2_out"], ["O2"], name="I4") -I5 = helper.make_node('Identity', ['X'], ['I5_out'], name='I5') +I5 = helper.make_node("Identity", ["X"], ["I5_out"], name="I5") D3 = helper.make_node("Dropout", ["I5_out", "X2", "training_mode"], ["D3_out"], "D3") -I6 = helper.make_node('Identity', ['D3_out'], ['O3'], name='I6') +I6 = helper.make_node("Identity", ["D3_out"], ["O3"], name="I6") -I7 = helper.make_node('Identity', ['X'], ['I7_out'], name='I7') +I7 = helper.make_node("Identity", ["X"], ["I7_out"], name="I7") D4 = helper.make_node("Dropout", ["I7_out", "ratio_zero_double", "training_mode"], ["D4_out"], "D4") -I8 = helper.make_node('Identity', ['D4_out'], ['O4'], name='I8') +I8 = helper.make_node("Identity", ["D4_out"], ["O4"], name="I8") -I9 = helper.make_node('Identity', ['X'], ['I9_out'], name='I9') +I9 = helper.make_node("Identity", ["X"], ["I9_out"], name="I9") D5 = helper.make_node("Dropout", ["I9_out", "ratio_zero_float16", "training_mode"], ["D5_out"], "D5") -I10 = helper.make_node('Identity', ['D5_out'], ['O5'], name='I10') +I10 = helper.make_node("Identity", ["D5_out"], ["O5"], name="I10") graph = helper.make_graph( [I1, D1, I2, I3, D2, I4, I5, D3, I6, I7, D4, I8, I9, D5, I10], - "Dropout_Elimination", #name + "Dropout_Elimination", # name [X, X2], [O1, O2, O3, O4, O5], - [zeroratio_float, zeroratio_double, zeroratio_float16, nonzeroratio, training_mode]) + [zeroratio_float, zeroratio_double, zeroratio_float16, nonzeroratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'dropout_ratio.onnx') \ No newline at end of file +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "dropout_ratio.onnx") diff --git a/onnxruntime/test/testdata/transform/expand_elimination.py b/onnxruntime/test/testdata/transform/expand_elimination.py index debe0a9908766..da1530876348e 100644 --- a/onnxruntime/test/testdata/transform/expand_elimination.py +++ b/onnxruntime/test/testdata/transform/expand_elimination.py @@ -1,56 +1,74 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper + +X1 = helper.make_tensor_value_info("input1", TensorProto.FLOAT, [2, 1]) +X2 = helper.make_tensor_value_info("input2", TensorProto.FLOAT, ["dynamic", 4]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 2, 4]) + +shape_constant1 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name="shape_constant1") +shape_constant2 = numpy_helper.from_array(np.array([1, 1], dtype=np.int64), name="shape_constant2") +shape_constant3 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name="shape_constant3") +shape_constant4 = numpy_helper.from_array(np.array([1, 1, 1], dtype=np.int64), name="shape_constant4") +shape_constant5 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name="shape_constant5") +shape_constant6 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name="shape_constant6") -X1 = helper.make_tensor_value_info('input1', TensorProto.FLOAT, [2, 1]) -X2 = helper.make_tensor_value_info('input2', TensorProto.FLOAT, ['dynamic', 4]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 2, 4]) - -shape_constant1 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name='shape_constant1') -shape_constant2 = numpy_helper.from_array(np.array([1, 1], dtype=np.int64), name='shape_constant2') -shape_constant3 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name='shape_constant3') -shape_constant4 = numpy_helper.from_array(np.array([1, 1, 1], dtype=np.int64), name='shape_constant4') -shape_constant5 = numpy_helper.from_array(np.array([1, 4], dtype=np.int64), name='shape_constant5') -shape_constant6 = numpy_helper.from_array(np.array([2, 1], dtype=np.int64), name='shape_constant6') - -identity1 = helper.make_node('Identity', ['input1'], ['identity1'], name='identity1') -expand1 = helper.make_node('Expand', ['identity1', shape_constant1.name], ['expand1'], name='expand1') -expand2 = helper.make_node('Expand', ['identity1', shape_constant2.name], ['expand2'], name='expand2') -mul1 = helper.make_node('Mul', ['expand1', 'expand2'], ['mul1'], name='mul1') # (2, 4) -expand3 = helper.make_node('Expand', ['mul1', shape_constant3.name], ['expand3'], name='expand3') -expand4 = helper.make_node('Expand', ['identity1', shape_constant4.name], ['expand4'], name='expand4') -mul2 = helper.make_node('Mul', ['expand3', 'expand4'], ['mul2'], name='mul2') # (1, 2, 4) -identity2 = helper.make_node('Identity', ['input2'], ['identity2'], name='identity2') -expand5 = helper.make_node('Expand', ['identity2', shape_constant5.name], ['expand5'], name='expand5') -expand6 = helper.make_node('Expand', ['identity2', shape_constant6.name], ['expand6'], name='expand6') -mul3 = helper.make_node('Mul', ['expand5', 'expand6'], ['mul3'], name='mul3') # (dynamic=2, 4) -mul4 = helper.make_node('Mul', ['mul2', 'mul3'], ['output'], name='mul4') +identity1 = helper.make_node("Identity", ["input1"], ["identity1"], name="identity1") +expand1 = helper.make_node("Expand", ["identity1", shape_constant1.name], ["expand1"], name="expand1") +expand2 = helper.make_node("Expand", ["identity1", shape_constant2.name], ["expand2"], name="expand2") +mul1 = helper.make_node("Mul", ["expand1", "expand2"], ["mul1"], name="mul1") # (2, 4) +expand3 = helper.make_node("Expand", ["mul1", shape_constant3.name], ["expand3"], name="expand3") +expand4 = helper.make_node("Expand", ["identity1", shape_constant4.name], ["expand4"], name="expand4") +mul2 = helper.make_node("Mul", ["expand3", "expand4"], ["mul2"], name="mul2") # (1, 2, 4) +identity2 = helper.make_node("Identity", ["input2"], ["identity2"], name="identity2") +expand5 = helper.make_node("Expand", ["identity2", shape_constant5.name], ["expand5"], name="expand5") +expand6 = helper.make_node("Expand", ["identity2", shape_constant6.name], ["expand6"], name="expand6") +mul3 = helper.make_node("Mul", ["expand5", "expand6"], ["mul3"], name="mul3") # (dynamic=2, 4) +mul4 = helper.make_node("Mul", ["mul2", "mul3"], ["output"], name="mul4") # Create the graph (GraphProto) graph_def = helper.make_graph( - [identity1, expand1, expand2, mul1, expand3, expand4, mul2, identity2, expand5, expand6, mul3, mul4], - 'expand_elimination_model', + [ + identity1, + expand1, + expand2, + mul1, + expand3, + expand4, + mul2, + identity2, + expand5, + expand6, + mul3, + mul4, + ], + "expand_elimination_model", [X1, X2], [Y], - [shape_constant1, shape_constant2, shape_constant3, shape_constant4, shape_constant5, shape_constant6] + [ + shape_constant1, + shape_constant2, + shape_constant3, + shape_constant4, + shape_constant5, + shape_constant6, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'expand_elimination.onnx') +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "expand_elimination.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/attention_gen.py b/onnxruntime/test/testdata/transform/fusion/attention_gen.py index f229a665561b0..cd1569ae5cd2a 100644 --- a/onnxruntime/test/testdata/transform/fusion/attention_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/attention_gen.py @@ -1,290 +1,635 @@ import sys -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum +import onnx +from onnx import TensorProto, helper + matmul_q_weights = [ - -0.10791015625, -0.04193115234375, 0.09051513671875, 0.025787353515625, - -0.11572265625, -0.126953125, -0.043304443359375, -0.02984619140625, - 0.033538818359375, -0.05755615234375, -0.04986572265625, -0.01558685302734375, - -0.0352783203125, 0.03546142578125, 0.05218505859375, 0.005565643310546875, - -0.05950927734375, 0.0172119140625, 0.06646728515625, 0.046630859375, - 0.031524658203125, 0.048614501953125, -0.11102294921875, -0.018463134765625, - -0.0352783203125, 0.037200927734375, 0.082763671875, 0.1260986328125, - -0.1087646484375, 0.00566864013671875, -0.027191162109375, -0.0027103424072265625, - -0.1256103515625, -0.0245361328125, 0.04437255859375, -0.05267333984375, - -0.0606689453125, 0.009735107421875, 0.01100921630859375, 0.045928955078125, - -0.036834716796875, 0.005405426025390625, 0.04571533203125, 0.11767578125, - 0.0286102294921875, -0.01071929931640625, -0.006378173828125, 0.0213470458984375, - -0.1434326171875, -0.0975341796875, 0.031402587890625, 0.02880859375, - 0.048004150390625, -0.028289794921875, 0.018157958984375, 0.061981201171875, - -0.126953125, -0.03350830078125, 0.1297607421875, -0.0093841552734375, - -0.0258026123046875, -0.000560760498046875, 0.1123046875, -0.0560302734375 + -0.10791015625, + -0.04193115234375, + 0.09051513671875, + 0.025787353515625, + -0.11572265625, + -0.126953125, + -0.043304443359375, + -0.02984619140625, + 0.033538818359375, + -0.05755615234375, + -0.04986572265625, + -0.01558685302734375, + -0.0352783203125, + 0.03546142578125, + 0.05218505859375, + 0.005565643310546875, + -0.05950927734375, + 0.0172119140625, + 0.06646728515625, + 0.046630859375, + 0.031524658203125, + 0.048614501953125, + -0.11102294921875, + -0.018463134765625, + -0.0352783203125, + 0.037200927734375, + 0.082763671875, + 0.1260986328125, + -0.1087646484375, + 0.00566864013671875, + -0.027191162109375, + -0.0027103424072265625, + -0.1256103515625, + -0.0245361328125, + 0.04437255859375, + -0.05267333984375, + -0.0606689453125, + 0.009735107421875, + 0.01100921630859375, + 0.045928955078125, + -0.036834716796875, + 0.005405426025390625, + 0.04571533203125, + 0.11767578125, + 0.0286102294921875, + -0.01071929931640625, + -0.006378173828125, + 0.0213470458984375, + -0.1434326171875, + -0.0975341796875, + 0.031402587890625, + 0.02880859375, + 0.048004150390625, + -0.028289794921875, + 0.018157958984375, + 0.061981201171875, + -0.126953125, + -0.03350830078125, + 0.1297607421875, + -0.0093841552734375, + -0.0258026123046875, + -0.000560760498046875, + 0.1123046875, + -0.0560302734375, ] matmul_k_weights = [ - 0.022125244140625, -0.017730712890625, -0.03265380859375, -0.05108642578125, - 0.0423583984375, 0.112060546875, 0.080810546875, 0.09375, - -0.043182373046875, -0.05010986328125, -0.063720703125, -0.00824737548828125, - 0.1492919921875, 0.048431396484375, -0.0482177734375, -0.1123046875, - -0.00719451904296875, -0.0229949951171875, -0.03424072265625, 0.0152435302734375, - 0.023468017578125, 0.0301513671875, -0.04656982421875, -0.043701171875, - 0.040313720703125, 0.00644683837890625, -0.0186614990234375, 0.0261383056640625, - 0.09063720703125, -0.078369140625, -0.05841064453125, -0.0743408203125, - 0.040130615234375, -0.0782470703125, 0.03729248046875, -0.07537841796875, - -0.0006098747253417969, 0.0285186767578125, -0.0518798828125, -0.01404571533203125, - -0.08001708984375, 0.015960693359375, -0.0357666015625, -0.048065185546875, - 0.01461029052734375, 0.06365966796875, 0.10125732421875, -0.00481414794921875, - 0.056182861328125, 0.072998046875, -0.06591796875, -0.035064697265625, - -0.1356201171875, -0.055877685546875, 0.06793212890625, -0.1292724609375, - 0.054901123046875, -0.0021762847900390625, 0.059783935546875, -0.035430908203125, - 0.0528564453125, 0.035125732421875, -0.0186767578125, -0.062286376953125 + 0.022125244140625, + -0.017730712890625, + -0.03265380859375, + -0.05108642578125, + 0.0423583984375, + 0.112060546875, + 0.080810546875, + 0.09375, + -0.043182373046875, + -0.05010986328125, + -0.063720703125, + -0.00824737548828125, + 0.1492919921875, + 0.048431396484375, + -0.0482177734375, + -0.1123046875, + -0.00719451904296875, + -0.0229949951171875, + -0.03424072265625, + 0.0152435302734375, + 0.023468017578125, + 0.0301513671875, + -0.04656982421875, + -0.043701171875, + 0.040313720703125, + 0.00644683837890625, + -0.0186614990234375, + 0.0261383056640625, + 0.09063720703125, + -0.078369140625, + -0.05841064453125, + -0.0743408203125, + 0.040130615234375, + -0.0782470703125, + 0.03729248046875, + -0.07537841796875, + -0.0006098747253417969, + 0.0285186767578125, + -0.0518798828125, + -0.01404571533203125, + -0.08001708984375, + 0.015960693359375, + -0.0357666015625, + -0.048065185546875, + 0.01461029052734375, + 0.06365966796875, + 0.10125732421875, + -0.00481414794921875, + 0.056182861328125, + 0.072998046875, + -0.06591796875, + -0.035064697265625, + -0.1356201171875, + -0.055877685546875, + 0.06793212890625, + -0.1292724609375, + 0.054901123046875, + -0.0021762847900390625, + 0.059783935546875, + -0.035430908203125, + 0.0528564453125, + 0.035125732421875, + -0.0186767578125, + -0.062286376953125, ] matmul_v_weights = [ - -0.03643798828125, 0.02862548828125, 0.039764404296875, 0.06097412109375, - -0.002288818359375, -0.10797119140625, -0.01171875, 0.041717529296875, - 0.032196044921875, 0.0135650634765625, 0.020233154296875, -0.05084228515625, - -0.011260986328125, -0.1241455078125, -0.0101165771484375, -0.00490570068359375, - -0.01361083984375, -0.01454925537109375, -0.000637054443359375, -0.01534271240234375, - -0.0438232421875, 0.034332275390625, 0.011962890625, -0.0139617919921875, - 0.03363037109375, 0.0265350341796875, 0.039947509765625, -0.0268707275390625, - 0.03900146484375, 0.08172607421875, 0.015625, 0.010986328125, - 0.0240325927734375, -0.029022216796875, 0.01403045654296875, 0.0135650634765625, - -0.0174102783203125, 0.07305908203125, -0.0231170654296875, 0.011444091796875, - 0.006130218505859375, 0.06268310546875, -0.05902099609375, -0.0109100341796875, - 0.0185089111328125, 0.0161590576171875, 0.0185546875, 0.032440185546875, - 0.0011491775512695312, 0.01153564453125, 0.005832672119140625, -0.0538330078125, - -0.008056640625, 0.01096343994140625, 0.037811279296875, 0.05902099609375, - 0.0394287109375, 0.00004678964614868164, -0.03778076171875, 0.004573822021484375, - -0.0237274169921875, -0.0124969482421875, -0.045013427734375, -0.04217529296875 - ] + -0.03643798828125, + 0.02862548828125, + 0.039764404296875, + 0.06097412109375, + -0.002288818359375, + -0.10797119140625, + -0.01171875, + 0.041717529296875, + 0.032196044921875, + 0.0135650634765625, + 0.020233154296875, + -0.05084228515625, + -0.011260986328125, + -0.1241455078125, + -0.0101165771484375, + -0.00490570068359375, + -0.01361083984375, + -0.01454925537109375, + -0.000637054443359375, + -0.01534271240234375, + -0.0438232421875, + 0.034332275390625, + 0.011962890625, + -0.0139617919921875, + 0.03363037109375, + 0.0265350341796875, + 0.039947509765625, + -0.0268707275390625, + 0.03900146484375, + 0.08172607421875, + 0.015625, + 0.010986328125, + 0.0240325927734375, + -0.029022216796875, + 0.01403045654296875, + 0.0135650634765625, + -0.0174102783203125, + 0.07305908203125, + -0.0231170654296875, + 0.011444091796875, + 0.006130218505859375, + 0.06268310546875, + -0.05902099609375, + -0.0109100341796875, + 0.0185089111328125, + 0.0161590576171875, + 0.0185546875, + 0.032440185546875, + 0.0011491775512695312, + 0.01153564453125, + 0.005832672119140625, + -0.0538330078125, + -0.008056640625, + 0.01096343994140625, + 0.037811279296875, + 0.05902099609375, + 0.0394287109375, + 0.00004678964614868164, + -0.03778076171875, + 0.004573822021484375, + -0.0237274169921875, + -0.0124969482421875, + -0.045013427734375, + -0.04217529296875, +] matmul_qkv_weights = [ - -0.04888916015625, 0.0143280029296875, 0.066650390625,-0.0343017578125, - -0.0010356903076171875, -0.00048232078552246094, 0.07470703125, -0.04736328125, - 0.01454925537109375, -0.0086669921875, -0.051971435546875, -0.0201568603515625, - 0.040435791015625, -0.019256591796875, 0.0205078125, 0.0111541748046875, - 0.0071868896484375, -0.0298309326171875, -0.0306549072265625, -0.0225372314453125, - -0.04193115234375, 0.07073974609375, -0.048065185546875, 0.0198822021484375, - -0.035552978515625, -0.022796630859375, 0.03839111328125, 0.007099151611328125, - -0.0080108642578125, -0.0017957687377929688, 0.0266265869140625,-0.028289794921875, - 0.0032901763916015625, 0.0208740234375, -0.01529693603515625, -0.046600341796875, - -0.034637451171875, 0.011322021484375, -0.026458740234375, 0.04656982421875, - -0.0091705322265625, 0.017913818359375, -0.019256591796875, -0.001216888427734375, - -0.08245849609375, -0.023162841796875, -0.04132080078125, -0.03363037109375, - 0.0029315948486328125, 0.03173828125, -0.004024505615234375, 0.04534912109375, - -0.0036163330078125, -0.03912353515625, -0.00800323486328125, 0.058197021484375, - 0.05572509765625, 0.01165771484375, 0.06756591796875, 0.05816650390625, - -0.0654296875, -0.0241851806640625, 0.0205535888671875, -0.031707763671875 + -0.04888916015625, + 0.0143280029296875, + 0.066650390625, + -0.0343017578125, + -0.0010356903076171875, + -0.00048232078552246094, + 0.07470703125, + -0.04736328125, + 0.01454925537109375, + -0.0086669921875, + -0.051971435546875, + -0.0201568603515625, + 0.040435791015625, + -0.019256591796875, + 0.0205078125, + 0.0111541748046875, + 0.0071868896484375, + -0.0298309326171875, + -0.0306549072265625, + -0.0225372314453125, + -0.04193115234375, + 0.07073974609375, + -0.048065185546875, + 0.0198822021484375, + -0.035552978515625, + -0.022796630859375, + 0.03839111328125, + 0.007099151611328125, + -0.0080108642578125, + -0.0017957687377929688, + 0.0266265869140625, + -0.028289794921875, + 0.0032901763916015625, + 0.0208740234375, + -0.01529693603515625, + -0.046600341796875, + -0.034637451171875, + 0.011322021484375, + -0.026458740234375, + 0.04656982421875, + -0.0091705322265625, + 0.017913818359375, + -0.019256591796875, + -0.001216888427734375, + -0.08245849609375, + -0.023162841796875, + -0.04132080078125, + -0.03363037109375, + 0.0029315948486328125, + 0.03173828125, + -0.004024505615234375, + 0.04534912109375, + -0.0036163330078125, + -0.03912353515625, + -0.00800323486328125, + 0.058197021484375, + 0.05572509765625, + 0.01165771484375, + 0.06756591796875, + 0.05816650390625, + -0.0654296875, + -0.0241851806640625, + 0.0205535888671875, + -0.031707763671875, +] + +add_q_weight = [ + -0.23681640625, + -0.16552734375, + 0.2191162109375, + -0.1756591796875, + -0.03460693359375, + -0.05316162109375, + -0.336181640625, + -0.253662109375, ] -add_q_weight = [-0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875, - -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375] +add_k_weight = [ + 0.0246734619140625, + 0.011993408203125, + 0.0178375244140625, + 0.00998687744140625, + 0.0255126953125, + 0.076416015625, + -0.040771484375, + 0.0107879638671875, +] -add_k_weight = [0.0246734619140625, 0.011993408203125, 0.0178375244140625, 0.00998687744140625, - 0.0255126953125, 0.076416015625, -0.040771484375, 0.0107879638671875] +add_v_weight = [ + -0.005893707275390625, + -0.00916290283203125, + 0.04541015625, + 0.0159454345703125, + -0.0029163360595703125, + -0.03472900390625, + 0.0535888671875, + 0.0091094970703125, +] -add_v_weight = [-0.005893707275390625, -0.00916290283203125, 0.04541015625, 0.0159454345703125, - -0.0029163360595703125, -0.03472900390625, 0.0535888671875, 0.0091094970703125] +add_qkv_weight = [ + -0.1146240234375, + -0.06768798828125, + -0.10040283203125, + -0.07012939453125, + -0.08624267578125, + 0.1507568359375, + -0.06634521484375, + -0.0194549560546875, +] -add_qkv_weight = [-0.1146240234375, -0.06768798828125, -0.10040283203125, -0.07012939453125, - -0.08624267578125, 0.1507568359375, -0.06634521484375, -0.0194549560546875] def GenerateModel(model_name): nodes = [ # Attention subgraph - helper.make_node("LayerNormalization", ["input_1", "layer_norm_weight", "layer_norm_bias"], - ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999747378752), - + helper.make_node( + "LayerNormalization", + ["input_1", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), # q nodes helper.make_node("MatMul", ["layernorm_out", "matmul_q_weight"], ["matmul_q_out"], "matmul_q"), helper.make_node("Add", ["matmul_q_out", "add_q_weight"], ["add_q_out"], "add_q"), helper.make_node("Reshape", ["add_q_out", "reshape_weight_1"], ["reshape_q_out"], "reshape_q"), - helper.make_node("Transpose", ["reshape_q_out"], ["transpose_q_out"], "transpose_q", - perm=[0,2,1,3]), - + helper.make_node( + "Transpose", + ["reshape_q_out"], + ["transpose_q_out"], + "transpose_q", + perm=[0, 2, 1, 3], + ), # k nodes helper.make_node("MatMul", ["layernorm_out", "matmul_k_weight"], ["matmul_k_out"], "matmul_k"), helper.make_node("Add", ["matmul_k_out", "add_k_weight"], ["add_k_out"], "add_k"), helper.make_node("Reshape", ["add_k_out", "reshape_weight_1"], ["reshape_k_out"], "reshape_k"), - helper.make_node("Transpose", ["reshape_k_out"], ["transpose_k_out"], "transpose_k", - perm=[0,2,3,1]), - + helper.make_node( + "Transpose", + ["reshape_k_out"], + ["transpose_k_out"], + "transpose_k", + perm=[0, 2, 3, 1], + ), # mask nodes - helper.make_node("Constant", [], ["mask_input"], "constant", - value=helper.make_tensor('mask', TensorProto.FLOAT, - [1, 3], [0.0, 0.0, 0.0])), + helper.make_node( + "Constant", + [], + ["mask_input"], + "constant", + value=helper.make_tensor("mask", TensorProto.FLOAT, [1, 3], [0.0, 0.0, 0.0]), + ), helper.make_node("Unsqueeze", ["mask_input"], ["unsqueeze0_out"], "unsqueeze0", axes=[1]), helper.make_node("Unsqueeze", ["unsqueeze0_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[2]), helper.make_node("Sub", ["sub_weight", "unsqueeze1_out"], ["sub_out"], "sub"), helper.make_node("Mul", ["sub_out", "mul_weight"], ["mul_mask_out"], "mul_mask"), - # qk nodes - helper.make_node("MatMul", ["transpose_q_out", "transpose_k_out"], ["matmul_qk_out"], "matmul_qk"), + helper.make_node( + "MatMul", + ["transpose_q_out", "transpose_k_out"], + ["matmul_qk_out"], + "matmul_qk", + ), helper.make_node("Div", ["matmul_qk_out", "div_weight"], ["div_qk_out"], "div_qk"), helper.make_node("Add", ["div_qk_out", "mul_mask_out"], ["add_qk_out"], "add_qk"), helper.make_node("Softmax", ["add_qk_out"], ["softmax_qk_out"], "softmax_qk", axis=3), - # v nodes helper.make_node("MatMul", ["layernorm_out", "matmul_v_weight"], ["matmul_v_out"], "matmul_v"), helper.make_node("Add", ["matmul_v_out", "add_v_weight"], ["add_v_out"], "add_v"), helper.make_node("Reshape", ["add_v_out", "reshape_weight_1"], ["reshape_v_out"], "reshape_v"), - helper.make_node("Transpose", ["reshape_v_out"], ["transpose_v_out"], "transpose_v", - perm=[0,2,1,3]), - + helper.make_node( + "Transpose", + ["reshape_v_out"], + ["transpose_v_out"], + "transpose_v", + perm=[0, 2, 1, 3], + ), # qkv nodes - helper.make_node("MatMul", ["softmax_qk_out", "transpose_v_out"], ["matmul_qkv_1_out"], "matmul_qkv_1"), - helper.make_node("Transpose", ["matmul_qkv_1_out"], ["transpose_qkv_out"], "transpose_qkv", - perm=[0,2,1,3] + helper.make_node( + "MatMul", + ["softmax_qk_out", "transpose_v_out"], + ["matmul_qkv_1_out"], + "matmul_qkv_1", + ), + helper.make_node( + "Transpose", + ["matmul_qkv_1_out"], + ["transpose_qkv_out"], + "transpose_qkv", + perm=[0, 2, 1, 3], + ), + helper.make_node( + "Reshape", + ["transpose_qkv_out", "reshape_weight_2"], + ["reshape_qkv_out"], + "reshape_qkv", + ), + helper.make_node( + "MatMul", + ["reshape_qkv_out", "matmul_qkv_weight"], + ["matmul_qkv_2_out"], + "matmul_qkv_2", ), - helper.make_node("Reshape", ["transpose_qkv_out", "reshape_weight_2"], ["reshape_qkv_out"], "reshape_qkv"), - helper.make_node("MatMul", ["reshape_qkv_out", "matmul_qkv_weight"], ["matmul_qkv_2_out"], "matmul_qkv_2"), helper.make_node("Add", ["matmul_qkv_2_out", "add_qkv_weight"], ["add_qkv_out"], "add_qkv"), - helper.make_node("Add", ["add_qkv_out", "layernorm_out"], ["output"], "add"), ] initializers = [ # initializers - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [8], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [8], [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('matmul_q_weight', TensorProto.FLOAT, [8, 8], matmul_q_weights), - helper.make_tensor('matmul_k_weight', TensorProto.FLOAT, [8, 8], matmul_k_weights), - helper.make_tensor('matmul_v_weight', TensorProto.FLOAT, [8, 8], matmul_v_weights), - helper.make_tensor('matmul_qkv_weight', TensorProto.FLOAT, [8, 8], matmul_qkv_weights), - helper.make_tensor('div_weight', TensorProto.FLOAT, [1], [2]), - helper.make_tensor('sub_weight', TensorProto.FLOAT, [1], [1.0]), - helper.make_tensor('mul_weight', TensorProto.FLOAT, [1], [-10000]), - helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], add_q_weight), - helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], add_k_weight), - helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], add_v_weight), - helper.make_tensor('add_qkv_weight', TensorProto.FLOAT, [8], add_qkv_weight), - helper.make_tensor('reshape_weight_1', TensorProto.INT64, [4], [0, 0, 2, 4]), - helper.make_tensor('reshape_weight_2', TensorProto.INT64, [3], [0, 0, 8]), + helper.make_tensor( + "layer_norm_weight", + TensorProto.FLOAT, + [8], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor( + "layer_norm_bias", + TensorProto.FLOAT, + [8], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4], + ), + helper.make_tensor("matmul_q_weight", TensorProto.FLOAT, [8, 8], matmul_q_weights), + helper.make_tensor("matmul_k_weight", TensorProto.FLOAT, [8, 8], matmul_k_weights), + helper.make_tensor("matmul_v_weight", TensorProto.FLOAT, [8, 8], matmul_v_weights), + helper.make_tensor("matmul_qkv_weight", TensorProto.FLOAT, [8, 8], matmul_qkv_weights), + helper.make_tensor("div_weight", TensorProto.FLOAT, [1], [2]), + helper.make_tensor("sub_weight", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("mul_weight", TensorProto.FLOAT, [1], [-10000]), + helper.make_tensor("add_q_weight", TensorProto.FLOAT, [8], add_q_weight), + helper.make_tensor("add_k_weight", TensorProto.FLOAT, [8], add_k_weight), + helper.make_tensor("add_v_weight", TensorProto.FLOAT, [8], add_v_weight), + helper.make_tensor("add_qkv_weight", TensorProto.FLOAT, [8], add_qkv_weight), + helper.make_tensor("reshape_weight_1", TensorProto.INT64, [4], [0, 0, 2, 4]), + helper.make_tensor("reshape_weight_2", TensorProto.INT64, [3], [0, 0, 8]), ] graph = helper.make_graph( nodes, - "AttentionFusionOneInput", #name - [ # inputs - helper.make_tensor_value_info('input_1', TensorProto.FLOAT, [1, 3, 8]) - ], + "AttentionFusionOneInput", # name + [helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [1, 3, 8])], # inputs [ # outputs - helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 3, 8]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 3, 8]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) + def GenerateModel2(model_name): nodes = [ # Attention subgraph - helper.make_node("LayerNormalization", ["input_1", "layer_norm_weight", "layer_norm_bias"], - ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999960041972), + helper.make_node( + "LayerNormalization", + ["input_1", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999960041972, + ), # shape path helper.make_node("Shape", ["layernorm_out"], ["shape0_out"], "shape0"), helper.make_node("Gather", ["shape0_out", "indices_0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Shape", ["layernorm_out"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape1_out", "indices_1"], ["gather1_out"], "gather1", axis=0), - # v nodes helper.make_node("MatMul", ["layernorm_out", "matmul_v_weight"], ["matmul_v_out"], "matmul_v"), helper.make_node("Add", ["matmul_v_out", "add_v_weight"], ["add_v_out"], "add_v"), helper.make_node("Reshape", ["add_v_out", "reshape_weight_1"], ["reshape_v_out"], "reshape_v"), - helper.make_node("Transpose", ["reshape_v_out"], ["transpose_v_out"], "transpose_v", - perm=[0,2,1,3]), - + helper.make_node( + "Transpose", + ["reshape_v_out"], + ["transpose_v_out"], + "transpose_v", + perm=[0, 2, 1, 3], + ), # q nodes helper.make_node("MatMul", ["layernorm_out", "matmul_q_weight"], ["matmul_q_out"], "matmul_q"), helper.make_node("Add", ["matmul_q_out", "add_q_weight"], ["add_q_out"], "add_q"), helper.make_node("Reshape", ["add_q_out", "reshape_weight_1"], ["reshape_q_out"], "reshape_q"), - helper.make_node("Transpose", ["reshape_q_out"], ["transpose_q_out"], "transpose_q", - perm=[0,2,1,3]), + helper.make_node( + "Transpose", + ["reshape_q_out"], + ["transpose_q_out"], + "transpose_q", + perm=[0, 2, 1, 3], + ), helper.make_node("Div", ["transpose_q_out", "div_weight"], ["div_q_out"], "div_q"), - # k nodes helper.make_node("MatMul", ["layernorm_out", "matmul_k_weight"], ["matmul_k_out"], "matmul_k"), helper.make_node("Add", ["matmul_k_out", "add_k_weight"], ["add_k_out"], "add_k"), helper.make_node("Reshape", ["add_k_out", "reshape_weight_1"], ["reshape_k_out"], "reshape_k"), - helper.make_node("Transpose", ["reshape_k_out"], ["transpose_k_out"], "transpose_k", - perm=[0,2,3,1]), - + helper.make_node( + "Transpose", + ["reshape_k_out"], + ["transpose_k_out"], + "transpose_k", + perm=[0, 2, 3, 1], + ), # path x - helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze_x_0_out"], "unsqueeze_x_0", axes=[0]), - helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze_x_1_out"], "unsqueeze_x_1", axes=[0]), - helper.make_node("Concat", ["unsqueeze_x_0_out", "dim_1", "dim_1", "unsqueeze_x_1_out"], ["concat3_out"], "concat3", axis=0), - helper.make_node("Concat", ["unsqueeze_x_0_out", "dim_-1", "dim_8"], ["concat4_out"], "concat4", axis=0), - + helper.make_node( + "Unsqueeze", + ["gather0_out"], + ["unsqueeze_x_0_out"], + "unsqueeze_x_0", + axes=[0], + ), + helper.make_node( + "Unsqueeze", + ["gather1_out"], + ["unsqueeze_x_1_out"], + "unsqueeze_x_1", + axes=[0], + ), + helper.make_node( + "Concat", + ["unsqueeze_x_0_out", "dim_1", "dim_1", "unsqueeze_x_1_out"], + ["concat3_out"], + "concat3", + axis=0, + ), + helper.make_node( + "Concat", + ["unsqueeze_x_0_out", "dim_-1", "dim_8"], + ["concat4_out"], + "concat4", + axis=0, + ), # mask nodes - helper.make_node("Constant", [], ["mask_input"], "constant", - value=helper.make_tensor('mask', TensorProto.FLOAT, - [1, 3], [1.0, 1.0, 1.0])), + helper.make_node( + "Constant", + [], + ["mask_input"], + "constant", + value=helper.make_tensor("mask", TensorProto.FLOAT, [1, 3], [1.0, 1.0, 1.0]), + ), helper.make_node("Equal", ["mask_input", "equal_weight"], ["equal_out"], "equal"), - - #qkx paths + # qkx paths helper.make_node("MatMul", ["div_q_out", "transpose_k_out"], ["matmul_qk_out"], "matmul_qk"), helper.make_node("Reshape", ["equal_out", "concat3_out"], ["reshape_x_out"], "reshape_x"), helper.make_node("Shape", ["matmul_qk_out"], ["shape_x_out"], "shape_x"), helper.make_node("Expand", ["reshape_x_out", "shape_x_out"], ["expand_out"], "expand"), - helper.make_node("Where", ["expand_out", "where_weight", "matmul_qk_out"], ["where_out"], "where"), #bugbug + helper.make_node( + "Where", + ["expand_out", "where_weight", "matmul_qk_out"], + ["where_out"], + "where", + ), # bugbug helper.make_node("Softmax", ["where_out"], ["softmax_qk_out"], "softmax_qk", axis=3), - # qkv nodes - helper.make_node("MatMul", ["softmax_qk_out", "transpose_v_out"], ["matmul_qkv_1_out"], "matmul_qkv_1"), - helper.make_node("Transpose", ["matmul_qkv_1_out"], ["transpose_qkv_out"], "transpose_qkv", - perm=[0,2,1,3] + helper.make_node( + "MatMul", + ["softmax_qk_out", "transpose_v_out"], + ["matmul_qkv_1_out"], + "matmul_qkv_1", + ), + helper.make_node( + "Transpose", + ["matmul_qkv_1_out"], + ["transpose_qkv_out"], + "transpose_qkv", + perm=[0, 2, 1, 3], + ), + helper.make_node( + "Reshape", + ["transpose_qkv_out", "concat4_out"], + ["reshape_qkv_out"], + "reshape_qkv", + ), + helper.make_node( + "MatMul", + ["reshape_qkv_out", "matmul_qkv_weight"], + ["matmul_qkv_2_out"], + "matmul_qkv_2", ), - helper.make_node("Reshape", ["transpose_qkv_out", "concat4_out"], ["reshape_qkv_out"], "reshape_qkv"), - helper.make_node("MatMul", ["reshape_qkv_out", "matmul_qkv_weight"], ["matmul_qkv_2_out"], "matmul_qkv_2"), helper.make_node("Add", ["matmul_qkv_2_out", "add_qkv_weight"], ["add_qkv_out"], "add_qkv"), - helper.make_node("Add", ["add_qkv_out", "layernorm_out"], ["output"], "add"), ] initializers = [ # initializers - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [8], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [8], [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('matmul_q_weight', TensorProto.FLOAT, [8, 8], matmul_q_weights), - helper.make_tensor('matmul_k_weight', TensorProto.FLOAT, [8, 8], matmul_k_weights), - helper.make_tensor('matmul_v_weight', TensorProto.FLOAT, [8, 8], matmul_v_weights), - helper.make_tensor('matmul_qkv_weight', TensorProto.FLOAT, [8, 8], matmul_qkv_weights), - helper.make_tensor('div_weight', TensorProto.FLOAT, [1], [2]), - helper.make_tensor('add_q_weight', TensorProto.FLOAT, [8], add_q_weight), - helper.make_tensor('add_k_weight', TensorProto.FLOAT, [8], add_k_weight), - helper.make_tensor('add_v_weight', TensorProto.FLOAT, [8], add_v_weight), - helper.make_tensor('add_qkv_weight', TensorProto.FLOAT, [8], add_qkv_weight), - helper.make_tensor('equal_weight', TensorProto.FLOAT, [], [0.0]), - helper.make_tensor('where_weight', TensorProto.FLOAT, [], [sys.float_info.min]), - helper.make_tensor('reshape_weight_1', TensorProto.INT64, [4], [0, -1, 2, 4]), - helper.make_tensor('reshape_weight_2', TensorProto.INT64, [4], [0, 1, 1, -1]), - helper.make_tensor('reshape_weight_3', TensorProto.INT64, [3], [0, -1, 8]), - helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), - helper.make_tensor('dim_-1', TensorProto.INT64, [1], [-1]), - helper.make_tensor('dim_1', TensorProto.INT64, [1], [1]), - helper.make_tensor('dim_8', TensorProto.INT64, [1], [8]), + helper.make_tensor( + "layer_norm_weight", + TensorProto.FLOAT, + [8], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor( + "layer_norm_bias", + TensorProto.FLOAT, + [8], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4], + ), + helper.make_tensor("matmul_q_weight", TensorProto.FLOAT, [8, 8], matmul_q_weights), + helper.make_tensor("matmul_k_weight", TensorProto.FLOAT, [8, 8], matmul_k_weights), + helper.make_tensor("matmul_v_weight", TensorProto.FLOAT, [8, 8], matmul_v_weights), + helper.make_tensor("matmul_qkv_weight", TensorProto.FLOAT, [8, 8], matmul_qkv_weights), + helper.make_tensor("div_weight", TensorProto.FLOAT, [1], [2]), + helper.make_tensor("add_q_weight", TensorProto.FLOAT, [8], add_q_weight), + helper.make_tensor("add_k_weight", TensorProto.FLOAT, [8], add_k_weight), + helper.make_tensor("add_v_weight", TensorProto.FLOAT, [8], add_v_weight), + helper.make_tensor("add_qkv_weight", TensorProto.FLOAT, [8], add_qkv_weight), + helper.make_tensor("equal_weight", TensorProto.FLOAT, [], [0.0]), + helper.make_tensor("where_weight", TensorProto.FLOAT, [], [sys.float_info.min]), + helper.make_tensor("reshape_weight_1", TensorProto.INT64, [4], [0, -1, 2, 4]), + helper.make_tensor("reshape_weight_2", TensorProto.INT64, [4], [0, 1, 1, -1]), + helper.make_tensor("reshape_weight_3", TensorProto.INT64, [3], [0, -1, 8]), + helper.make_tensor("indices_0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices_1", TensorProto.INT64, [], [1]), + helper.make_tensor("dim_-1", TensorProto.INT64, [1], [-1]), + helper.make_tensor("dim_1", TensorProto.INT64, [1], [1]), + helper.make_tensor("dim_8", TensorProto.INT64, [1], [8]), ] graph = helper.make_graph( nodes, - "AttentionFusion_DistilBert", #name - [ # inputs - helper.make_tensor_value_info('input_1', TensorProto.FLOAT, [1, 1, 8]) - ], + "AttentionFusion_DistilBert", # name + [helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [1, 1, 8])], # inputs [ # outputs - helper.make_tensor_value_info('output', TensorProto.FLOAT, [1, 1, 8]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 1, 8]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) -GenerateModel('attention_mask_no_cast.onnx') -GenerateModel2('attention_distilbert.onnx') +GenerateModel("attention_mask_no_cast.onnx") +GenerateModel2("attention_distilbert.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py index aa010fb4090a3..f62f3210f19a5 100644 --- a/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/bias_dropout_residual_gen.py @@ -1,26 +1,25 @@ import onnx -from onnx import helper -from onnx import TensorProto, OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper # inputs/outputs -A = helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) -B = helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]) -R = helper.make_tensor_value_info('R', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) -C = helper.make_tensor_value_info('C', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) -mask = helper.make_tensor_value_info('mask', TensorProto.BOOL, ['unk_1', 'unk_2', 3072]) +A = helper.make_tensor_value_info("A", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]) +B = helper.make_tensor_value_info("B", TensorProto.FLOAT, [3072]) +R = helper.make_tensor_value_info("R", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]) +C = helper.make_tensor_value_info("C", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]) +mask = helper.make_tensor_value_info("mask", TensorProto.BOOL, ["unk_1", "unk_2", 3072]) # initializers -ratio = helper.make_tensor('ratio_const', TensorProto.FLOAT, [], [0.8]) -training_mode = helper.make_tensor('training_mode', TensorProto.BOOL, [], [1]) +ratio = helper.make_tensor("ratio_const", TensorProto.FLOAT, [], [0.8]) +training_mode = helper.make_tensor("training_mode", TensorProto.BOOL, [], [1]) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) bias = helper.make_node("Add", ["A", "B"], ["add0_out"], "add0") @@ -28,13 +27,14 @@ graph = helper.make_graph( [bias, dropout_12], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_fusion1.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_fusion1.onnx") # Create the model (ModelProto) bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") @@ -42,193 +42,255 @@ graph = helper.make_graph( [bias, dropout_12], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_fusion2.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_fusion2.onnx") # Create the model (ModelProto) bias = helper.make_node("Add", ["A", "B"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["dropout_out", "R"], ["C"], "add1") graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B, R], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion1.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_fusion1.onnx") # Create the model (ModelProto) bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B, R], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion2.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_fusion2.onnx") # Create the model (ModelProto) -R_mismatch = helper.make_tensor_value_info('R', TensorProto.FLOAT, [3072]) +R_mismatch = helper.make_tensor_value_info("R", TensorProto.FLOAT, [3072]) bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B, R_mismatch], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion_mismatch.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_fusion_mismatch.onnx") # If the Dropout output 0 is also a graph output, the residual Add shouldn't be fused. # Create the model (ModelProto) bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") -D = helper.make_tensor_value_info('dropout_out', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +D = helper.make_tensor_value_info("dropout_out", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]) graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B, R], [C, D], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion_multiple_consumers1.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_fusion_multiple_consumers1.onnx") # If the Dropout has multiple consumers of output 0, the residual Add shouldn't be fused. # Create the model (ModelProto) -D = helper.make_tensor_value_info('D', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +D = helper.make_tensor_value_info("D", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]) bias = helper.make_node("Add", ["B", "A"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["R", "dropout_out"], ["C"], "add1") identity = helper.make_node("Identity", ["dropout_out"], ["D"], "identity") graph = helper.make_graph( [bias, dropout_12, residual, identity], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, B, R], [C, D], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion_multiple_consumers2.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_fusion_multiple_consumers2.onnx") # Create the model (ModelProto) -A2 = helper.make_tensor_value_info('A2', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]) +A2 = helper.make_tensor_value_info("A2", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]) bias = helper.make_node("Add", ["A", "A2"], ["add0_out"], "add0") dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["C", "mask"], "dropout0") graph = helper.make_graph( [bias, dropout_12], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, A2], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_same_shape_fusion.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_same_shape_fusion.onnx") # Create the model (ModelProto) bias = helper.make_node("Add", ["A", "A2"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["dropout_out", "R"], ["C"], "add1") graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A, A2, R], [C], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_same_shape_fusion.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_same_shape_fusion.onnx") # Create the model (ModelProto) -A_unk = helper.make_tensor_value_info('A_unk', TensorProto.FLOAT, ['unk_1', 'unk_2', 'unk_3']) -B_unk = helper.make_tensor_value_info('B_unk', TensorProto.FLOAT, ['unk_3']) -C_unk = helper.make_tensor_value_info('C_unk', TensorProto.FLOAT, ['unk_1', 'unk_2', 'unk_3']) +A_unk = helper.make_tensor_value_info("A_unk", TensorProto.FLOAT, ["unk_1", "unk_2", "unk_3"]) +B_unk = helper.make_tensor_value_info("B_unk", TensorProto.FLOAT, ["unk_3"]) +C_unk = helper.make_tensor_value_info("C_unk", TensorProto.FLOAT, ["unk_1", "unk_2", "unk_3"]) bias = helper.make_node("Add", ["A_unk", "B_unk"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["C_unk", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["C_unk", "mask"], + "dropout0", +) graph = helper.make_graph( [bias, dropout_12], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A_unk, B_unk], [C_unk], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_fusion_dim_is_param.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_fusion_dim_is_param.onnx") # Create the model (ModelProto) -R_unk = helper.make_tensor_value_info('R_unk', TensorProto.FLOAT, ['unk_1', 'unk_2', 'unk_3']) +R_unk = helper.make_tensor_value_info("R_unk", TensorProto.FLOAT, ["unk_1", "unk_2", "unk_3"]) bias = helper.make_node("Add", ["A_unk", "B_unk"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["dropout_out", "R_unk"], ["C_unk"], "add1") graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A_unk, B_unk, R_unk], [C_unk], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_fusion_dim_is_param.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_fusion_dim_is_param.onnx") # Create the model (ModelProto) -A_unk2 = helper.make_tensor_value_info('A_unk2', TensorProto.FLOAT, ['unk_1', 'unk_2', 'unk_3']) +A_unk2 = helper.make_tensor_value_info("A_unk2", TensorProto.FLOAT, ["unk_1", "unk_2", "unk_3"]) bias = helper.make_node("Add", ["A_unk", "A_unk2"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["C_unk", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["C_unk", "mask"], + "dropout0", +) graph = helper.make_graph( [bias, dropout_12], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A_unk, A_unk2], [C_unk], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_same_shape_fusion_dim_is_param.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_same_shape_fusion_dim_is_param.onnx") # Create the model (ModelProto) bias = helper.make_node("Add", ["A_unk", "A_unk2"], ["add0_out"], "add0") -dropout_12 = helper.make_node("Dropout", ["add0_out", "ratio_const", "training_mode"], ["dropout_out", "mask"], "dropout0") +dropout_12 = helper.make_node( + "Dropout", + ["add0_out", "ratio_const", "training_mode"], + ["dropout_out", "mask"], + "dropout0", +) residual = helper.make_node("Add", ["dropout_out", "R_unk"], ["C_unk"], "add1") graph = helper.make_graph( [bias, dropout_12, residual], - "Bias_Dropout_Fusion", #name + "Bias_Dropout_Fusion", # name [A_unk, A_unk2, R_unk], [C_unk], - [ratio, training_mode]) + [ratio, training_mode], +) -model = helper.make_model(graph, producer_name='onnx-example', **kwargs) -onnx.save(model, 'bias_dropout_residual_same_shape_fusion_dim_is_param.onnx') +model = helper.make_model(graph, producer_name="onnx-example", **kwargs) +onnx.save(model, "bias_dropout_residual_same_shape_fusion_dim_is_param.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py index 7b51369be3473..96a52c73f40e6 100644 --- a/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/bias_gelu_gen.py @@ -1,12 +1,10 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( [ # nodes # Add node before Gelu helper.make_node("Add", ["A", "B"], ["add0_out"], "add0"), - # Gelu subgraph helper.make_node("Div", ["add0_out", "div_const"], ["div_out"], "div"), helper.make_node("Mul", ["add0_out", "mul_const"], ["mul_out"], "mul0"), @@ -14,19 +12,20 @@ helper.make_node("Add", ["erf_out", "add_const"], ["add1_out"], "add1"), helper.make_node("Mul", ["mul_out", "add1_out"], ["C"], "mul1"), ], - "Gelu_Add_Fusion", #name + "Gelu_Add_Fusion", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3072]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT, ['unk_3', 'unk_4', 3072]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["unk_3", "unk_4", 3072]), ], [ # initializers - helper.make_tensor('div_const', TensorProto.FLOAT, [], [1.4142135381698608]), - helper.make_tensor('mul_const', TensorProto.FLOAT, [], [0.5]), - helper.make_tensor('add_const', TensorProto.FLOAT, [], [1]), - ]) + helper.make_tensor("div_const", TensorProto.FLOAT, [], [1.4142135381698608]), + helper.make_tensor("mul_const", TensorProto.FLOAT, [], [0.5]), + helper.make_tensor("add_const", TensorProto.FLOAT, [], [1]), + ], +) model = helper.make_model(graph) -onnx.save(model, r'bias_gelu_fusion.onnx') +onnx.save(model, r"bias_gelu_fusion.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py index 9d78e622f9417..8f6e37fcb95c1 100644 --- a/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/bias_gelu_matmul_gen.py @@ -1,36 +1,34 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( [ # nodes # Add node before Gelu helper.make_node("Add", ["A", "B"], ["add0_out"], "add0"), - # Gelu subgraph helper.make_node("Div", ["add0_out", "div_const"], ["div_out"], "div"), helper.make_node("Mul", ["add0_out", "mul_const"], ["mul_out"], "mul0"), helper.make_node("Erf", ["div_out"], ["erf_out"], "erf"), helper.make_node("Add", ["erf_out", "add_const"], ["add1_out"], "add1"), helper.make_node("Mul", ["mul_out", "add1_out"], ["C"], "mul1"), - # MatMul node after Gelu for recompute helper.make_node("MatMul", ["X", "C"], ["D"], "matmul"), ], - "Gelu_Add_Fusion_Recompute", #name + "Gelu_Add_Fusion_Recompute", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['unk_1', 'unk_2', 3072]), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [3072]), - helper.make_tensor_value_info('X', TensorProto.FLOAT, ['unk_5', 'unk_6', 3072]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["unk_1", "unk_2", 3072]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [3072]), + helper.make_tensor_value_info("X", TensorProto.FLOAT, ["unk_5", "unk_6", 3072]), ], [ # outputs - helper.make_tensor_value_info('D', TensorProto.FLOAT, ['unk_3', 'unk_4', 'unk_5']), + helper.make_tensor_value_info("D", TensorProto.FLOAT, ["unk_3", "unk_4", "unk_5"]), ], [ # initializers - helper.make_tensor('div_const', TensorProto.FLOAT, [], [1.4142135381698608]), - helper.make_tensor('mul_const', TensorProto.FLOAT, [], [0.5]), - helper.make_tensor('add_const', TensorProto.FLOAT, [], [1]), - ]) + helper.make_tensor("div_const", TensorProto.FLOAT, [], [1.4142135381698608]), + helper.make_tensor("mul_const", TensorProto.FLOAT, [], [0.5]), + helper.make_tensor("add_const", TensorProto.FLOAT, [], [1]), + ], +) model = helper.make_model(graph) -onnx.save(model, r'bias_gelu_fusion_recompute.onnx') +onnx.save(model, r"bias_gelu_fusion_recompute.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/bias_softmax_gen.py b/onnxruntime/test/testdata/transform/fusion/bias_softmax_gen.py index a7954dace01d1..26002476c6999 100644 --- a/onnxruntime/test/testdata/transform/fusion/bias_softmax_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/bias_softmax_gen.py @@ -1,101 +1,195 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper add = helper.make_node("Add", ["input", "bias"], ["add_out"], "add") reverseadd = helper.make_node("Add", ["bias", "input"], ["add_out"], "add") -softmax1 = helper.make_node("Softmax", ["add_out"], ["output"], "softmax", axis=1) -softmax3 = helper.make_node("Softmax", ["add_out"], ["output"], "softmax", axis=3) -softmax6 = helper.make_node("Softmax", ["add_out"], ["output"], "softmax", axis=6) +softmax1 = helper.make_node("Softmax", ["add_out"], ["output"], "softmax", axis=1) +softmax3 = helper.make_node("Softmax", ["add_out"], ["output"], "softmax", axis=3) +softmax6 = helper.make_node("Softmax", ["add_out"], ["output"], "softmax", axis=6) onnx.save( helper.make_model( - helper.make_graph( - [add, softmax1], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_1', 'd_2']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['d_1', 'd_2']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_1', 'd_2']), - ], - [])), r'bias_softmax_fusion_simple.onnx') + helper.make_graph( + [add, softmax1], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info("input", TensorProto.FLOAT, ["d_1", "d_2"]), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["d_1", "d_2"]), + ], + [ + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["d_1", "d_2"]), + ], + [], + ) + ), + r"bias_softmax_fusion_simple.onnx", +) onnx.save( helper.make_model( - helper.make_graph( - [add, softmax6], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 1, 1, 1, 'd_6', 'd_7', 'd_8']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - ], - [])), r'bias_softmax_fusion_middleones.onnx') + helper.make_graph( + [add, softmax6], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + helper.make_tensor_value_info( + "bias", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", 1, 1, 1, "d_6", "d_7", "d_8"], + ), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + ], + [], + ) + ), + r"bias_softmax_fusion_middleones.onnx", +) onnx.save( helper.make_model( - helper.make_graph( - [reverseadd, softmax6], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 1, 1, 1, 'd_6', 'd_7', 'd_8']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - ], - [])), r'bias_softmax_fusion_middleones_reversed.onnx') + helper.make_graph( + [reverseadd, softmax6], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + helper.make_tensor_value_info( + "bias", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", 1, 1, 1, "d_6", "d_7", "d_8"], + ), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + ], + [], + ) + ), + r"bias_softmax_fusion_middleones_reversed.onnx", +) # should NOT fuse onnx.save( helper.make_model( - helper.make_graph( - [add, softmax3], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 1, 1, 1, 'd_6', 'd_7', 'd_8']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - ], - [])), r'bias_softmax_fusion_middleones_badaxis.onnx') + helper.make_graph( + [add, softmax3], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + helper.make_tensor_value_info( + "bias", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", 1, 1, 1, "d_6", "d_7", "d_8"], + ), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + ], + [], + ) + ), + r"bias_softmax_fusion_middleones_badaxis.onnx", +) onnx.save( helper.make_model( - helper.make_graph( - [add, softmax6], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 1, 1, 1, 1, 1, 'd_6', 'd_7', 'd_8']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - ], - [])), r'bias_softmax_fusion_allleadingones.onnx') + helper.make_graph( + [add, softmax6], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, [1, 1, 1, 1, 1, 1, "d_6", "d_7", "d_8"]), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + ], + [], + ) + ), + r"bias_softmax_fusion_allleadingones.onnx", +) onnx.save( helper.make_model( - helper.make_graph( - [add, softmax6], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, [1, 1, 'd_6', 'd_7', 'd_8']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - ], - [])), r'bias_softmax_fusion_someleadingones.onnx') + helper.make_graph( + [add, softmax6], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, [1, 1, "d_6", "d_7", "d_8"]), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + ], + [], + ) + ), + r"bias_softmax_fusion_someleadingones.onnx", +) onnx.save( helper.make_model( - helper.make_graph( - [add, softmax6], "Add_Softmax_Fusion", - [ - helper.make_tensor_value_info('input', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - helper.make_tensor_value_info('bias', TensorProto.FLOAT, ['d_6', 'd_7', 'd_8']), - ], - [ - helper.make_tensor_value_info('output', TensorProto.FLOAT, ['d_0', 'd_1', 'd_2', 'd_3', 'd_4', 'd_5', 'd_6', 'd_7', 'd_8']), - ], - [])), r'bias_softmax_fusion_noleadingones.onnx') \ No newline at end of file + helper.make_graph( + [add, softmax6], + "Add_Softmax_Fusion", + [ + helper.make_tensor_value_info( + "input", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + helper.make_tensor_value_info("bias", TensorProto.FLOAT, ["d_6", "d_7", "d_8"]), + ], + [ + helper.make_tensor_value_info( + "output", + TensorProto.FLOAT, + ["d_0", "d_1", "d_2", "d_3", "d_4", "d_5", "d_6", "d_7", "d_8"], + ), + ], + [], + ) + ), + r"bias_softmax_fusion_noleadingones.onnx", +) diff --git a/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py b/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py index ab1b9d00b439c..6cc5cdeb79f4a 100644 --- a/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py +++ b/onnxruntime/test/testdata/transform/fusion/constant_folding_with_shape_to_initializer.py @@ -1,80 +1,110 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [2, 4, 8]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [2, 4, 16]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 4, 8]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 4, 16]) matmul_weight_vals = (0.01 * np.arange(2 * 4 * 4, dtype=np.float32)).reshape((2, 4, 4)) -matmul_weight_initializer = numpy_helper.from_array(matmul_weight_vals, 'matmul_weight') -gather_constant_zero = numpy_helper.from_array(np.int64(0), name='gather_constant_zero') -gather_constant_one = numpy_helper.from_array(np.int64(1), name='gather_constant_one') -div_constant_two = numpy_helper.from_array(np.int64(2), name='div_constant_two') -unsqueeze_constant_16 = numpy_helper.from_array(np.int64(16), name='unsqueeze_constant_16') +matmul_weight_initializer = numpy_helper.from_array(matmul_weight_vals, "matmul_weight") +gather_constant_zero = numpy_helper.from_array(np.int64(0), name="gather_constant_zero") +gather_constant_one = numpy_helper.from_array(np.int64(1), name="gather_constant_one") +div_constant_two = numpy_helper.from_array(np.int64(2), name="div_constant_two") +unsqueeze_constant_16 = numpy_helper.from_array(np.int64(16), name="unsqueeze_constant_16") -shape1 = helper.make_node('Shape', ['input'], ['shape1'], name='shape1') -constant_of_shape = helper.make_node('ConstantOfShape', ['shape1'], ['constant_of_shape'], name='constant_of_shape') -transpose = helper.make_node('Transpose', ['constant_of_shape'], ['transpose'], name='transpose', perm=[0,2,1]) -matmul1 = helper.make_node('MatMul', ['transpose', matmul_weight_initializer.name], ['matmul1'], name='matmul1') -matmul2 = helper.make_node('MatMul', ['matmul1', 'input'], ['matmul2'], name='matmul2') -shape2 = helper.make_node('Shape', ['matmul2'], ['shape2'], name='shape2') -gather1 = helper.make_node('Gather', ['shape2', gather_constant_zero.name], ['gather1'], name='gather1', axis=0) -gather2 = helper.make_node('Gather', ['shape2', gather_constant_one.name], ['gather2'], name='gather2', axis=0) -div = helper.make_node('Div', ['gather2', div_constant_two.name], ['div'], name='div') -unsqueeze1 = helper.make_node('Unsqueeze', ['gather1'], ['unsqueeze1'], name='unsqueeze1', axes=[0]) -unsqueeze2 = helper.make_node('Unsqueeze', ['div'], ['unsqueeze2'], name='unsqueeze2', axes=[0]) -unsqueeze3 = helper.make_node('Unsqueeze', [unsqueeze_constant_16.name], ['unsqueeze3'], name='unsqueeze3', axes=[0]) -concat = helper.make_node('Concat', ['unsqueeze1', 'unsqueeze2', 'unsqueeze3'], ['concat'], name='concat', axis=0) -reshape = helper.make_node('Reshape', ['matmul2', 'concat'], ['output'], name='reshape') +shape1 = helper.make_node("Shape", ["input"], ["shape1"], name="shape1") +constant_of_shape = helper.make_node("ConstantOfShape", ["shape1"], ["constant_of_shape"], name="constant_of_shape") +transpose = helper.make_node("Transpose", ["constant_of_shape"], ["transpose"], name="transpose", perm=[0, 2, 1]) +matmul1 = helper.make_node("MatMul", ["transpose", matmul_weight_initializer.name], ["matmul1"], name="matmul1") +matmul2 = helper.make_node("MatMul", ["matmul1", "input"], ["matmul2"], name="matmul2") +shape2 = helper.make_node("Shape", ["matmul2"], ["shape2"], name="shape2") +gather1 = helper.make_node("Gather", ["shape2", gather_constant_zero.name], ["gather1"], name="gather1", axis=0) +gather2 = helper.make_node("Gather", ["shape2", gather_constant_one.name], ["gather2"], name="gather2", axis=0) +div = helper.make_node("Div", ["gather2", div_constant_two.name], ["div"], name="div") +unsqueeze1 = helper.make_node("Unsqueeze", ["gather1"], ["unsqueeze1"], name="unsqueeze1", axes=[0]) +unsqueeze2 = helper.make_node("Unsqueeze", ["div"], ["unsqueeze2"], name="unsqueeze2", axes=[0]) +unsqueeze3 = helper.make_node( + "Unsqueeze", + [unsqueeze_constant_16.name], + ["unsqueeze3"], + name="unsqueeze3", + axes=[0], +) +concat = helper.make_node( + "Concat", + ["unsqueeze1", "unsqueeze2", "unsqueeze3"], + ["concat"], + name="concat", + axis=0, +) +reshape = helper.make_node("Reshape", ["matmul2", "concat"], ["output"], name="reshape") # Create the graph (GraphProto) graph_def = helper.make_graph( - [shape1, constant_of_shape, transpose, matmul1, matmul2, shape2, gather1, gather2, div, unsqueeze1, unsqueeze2, unsqueeze3, concat, reshape], - 'constant_folding_with_shape_to_initializer_model', + [ + shape1, + constant_of_shape, + transpose, + matmul1, + matmul2, + shape2, + gather1, + gather2, + div, + unsqueeze1, + unsqueeze2, + unsqueeze3, + concat, + reshape, + ], + "constant_folding_with_shape_to_initializer_model", [X], [Y], - [matmul_weight_initializer, gather_constant_zero, gather_constant_one, div_constant_two, unsqueeze_constant_16] + [ + matmul_weight_initializer, + gather_constant_zero, + gather_constant_one, + div_constant_two, + unsqueeze_constant_16, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'constant_folding_with_shape_to_initializer.onnx') - +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "constant_folding_with_shape_to_initializer.onnx") -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [1]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [1]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [1]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [1]) -squeeze = helper.make_node('Squeeze', ['input'], ['squeeze'], name='squeeze', axes=[0]) -shape = helper.make_node('Shape', ['squeeze'], ['shape'], name='shape') -constant_of_shape = helper.make_node('ConstantOfShape', ['shape'], ['constant_of_shape'], name='constant_of_shape') -add = helper.make_node('Add', ['squeeze', 'constant_of_shape'], ['add'], name='add') -unsqueeze = helper.make_node('Unsqueeze', ['add'], ['output'], name='unsqueeze', axes=[0]) +squeeze = helper.make_node("Squeeze", ["input"], ["squeeze"], name="squeeze", axes=[0]) +shape = helper.make_node("Shape", ["squeeze"], ["shape"], name="shape") +constant_of_shape = helper.make_node("ConstantOfShape", ["shape"], ["constant_of_shape"], name="constant_of_shape") +add = helper.make_node("Add", ["squeeze", "constant_of_shape"], ["add"], name="add") +unsqueeze = helper.make_node("Unsqueeze", ["add"], ["output"], name="unsqueeze", axes=[0]) # Create the graph (GraphProto) graph_def = helper.make_graph( [squeeze, shape, constant_of_shape, add, unsqueeze], - 'constant_folding_with_scalar_shape_to_initializer_model', + "constant_folding_with_scalar_shape_to_initializer_model", [X], - [Y] + [Y], ) # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'constant_folding_with_scalar_shape_to_initializer.onnx') +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "constant_folding_with_scalar_shape_to_initializer.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/create_conv_clip11.py b/onnxruntime/test/testdata/transform/fusion/create_conv_clip11.py index ebd2849515de2..a9934e4625730 100644 --- a/onnxruntime/test/testdata/transform/fusion/create_conv_clip11.py +++ b/onnxruntime/test/testdata/transform/fusion/create_conv_clip11.py @@ -1,17 +1,14 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( [ # nodes # fusable, min/max from constant inputs. helper.make_node("Conv", ["X0", "W"], ["conv0_out"], "Conv0"), helper.make_node("Clip", ["conv0_out", "const_min", "const_max"], ["clip0_out"], "Clip0"), - # mutable input. no fusion. helper.make_node("Conv", ["X1", "W"], ["conv1_out"], "Conv1"), helper.make_node("Clip", ["conv1_out", "mutable_min", "const_max"], ["clip1_out"], "Clip1"), - # fusable. default min/max. helper.make_node("Conv", ["X2", "W"], ["conv2_out"], "Conv2"), helper.make_node("Clip", ["conv2_out"], ["clip2_out"], "Clip2"), @@ -19,21 +16,22 @@ "ConvClipFusion", # name [ # inputs # each Conv has a distinct X input so that the common subexpression elimination does not combine them - helper.make_tensor_value_info('X0', TensorProto.FLOAT, [1, 1, 7]), - helper.make_tensor_value_info('X1', TensorProto.FLOAT, [1, 1, 7]), - helper.make_tensor_value_info('X2', TensorProto.FLOAT, [1, 1, 7]), - helper.make_tensor_value_info('W', TensorProto.FLOAT, [1, 1, 1]), - helper.make_tensor_value_info('mutable_min', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("X0", TensorProto.FLOAT, [1, 1, 7]), + helper.make_tensor_value_info("X1", TensorProto.FLOAT, [1, 1, 7]), + helper.make_tensor_value_info("X2", TensorProto.FLOAT, [1, 1, 7]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 1]), + helper.make_tensor_value_info("mutable_min", TensorProto.FLOAT, [1]), ], [ # outputs - helper.make_tensor_value_info('clip0_out', TensorProto.FLOAT, None), - helper.make_tensor_value_info('clip1_out', TensorProto.FLOAT, None), - helper.make_tensor_value_info('clip2_out', TensorProto.FLOAT, None), + helper.make_tensor_value_info("clip0_out", TensorProto.FLOAT, None), + helper.make_tensor_value_info("clip1_out", TensorProto.FLOAT, None), + helper.make_tensor_value_info("clip2_out", TensorProto.FLOAT, None), ], [ # initializers - helper.make_tensor('const_min', TensorProto.FLOAT, [1], [-1.0]), - helper.make_tensor('const_max', TensorProto.FLOAT, [1], [10.0]) - ]) + helper.make_tensor("const_min", TensorProto.FLOAT, [1], [-1.0]), + helper.make_tensor("const_max", TensorProto.FLOAT, [1], [10.0]), + ], +) -model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid('', 11)]) -onnx.save(model, r'conv_clip11.onnx') +model = helper.make_model(graph, opset_imports=[helper.make_operatorsetid("", 11)]) +onnx.save(model, r"conv_clip11.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/create_conv_hardsigmoid.py b/onnxruntime/test/testdata/transform/fusion/create_conv_hardsigmoid.py index 5e7c05b50fa13..71328e34b7e49 100644 --- a/onnxruntime/test/testdata/transform/fusion/create_conv_hardsigmoid.py +++ b/onnxruntime/test/testdata/transform/fusion/create_conv_hardsigmoid.py @@ -1,22 +1,21 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( - [ # nodes + [ # nodes # fusable, const_min_negative should be replaced helper.make_node("Conv", ["X", "W"], ["conv0_out"], "Conv0"), helper.make_node("HardSigmoid", ["conv0_out"], ["hardsigmoid0_out"], "HardSigmoid0"), ], - "ConvClipFusion", #name + "ConvClipFusion", # name [ # inputs - helper.make_tensor_value_info('X', TensorProto.FLOAT, [1, 1, 10, 10]), - helper.make_tensor_value_info('W', TensorProto.FLOAT, [1, 1, 3, 3]), + helper.make_tensor_value_info("X", TensorProto.FLOAT, [1, 1, 10, 10]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 3, 3]), ], [ # outputs - helper.make_tensor_value_info('hardsigmoid0_out', TensorProto.FLOAT, None), + helper.make_tensor_value_info("hardsigmoid0_out", TensorProto.FLOAT, None), ], ) model = helper.make_model(graph) -onnx.save(model, r'conv_hardsigmoid.onnx') +onnx.save(model, r"conv_hardsigmoid.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/create_conv_with_padding_relu.py b/onnxruntime/test/testdata/transform/fusion/create_conv_with_padding_relu.py index e403fe50d0c52..ee9892567ae87 100644 --- a/onnxruntime/test/testdata/transform/fusion/create_conv_with_padding_relu.py +++ b/onnxruntime/test/testdata/transform/fusion/create_conv_with_padding_relu.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # copied and adapted from here: # https://github.com/onnx/onnx/blob/c940fa3fea84948e46603cab2f86467291443beb/docs/Operators.md?plain=1#L3494-L3502 @@ -8,20 +7,18 @@ graph = helper.make_graph( [ # nodes # Convolution with padding - helper.make_node('Conv', ['x', 'W'], ['y'], - kernel_shape=[3, 3], - pads=[1, 1, 1, 1]), - helper.make_node('Relu', ['y'], ['relu_out']), + helper.make_node("Conv", ["x", "W"], ["y"], kernel_shape=[3, 3], pads=[1, 1, 1, 1]), + helper.make_node("Relu", ["y"], ["relu_out"]), ], "ConvWithPaddingReluFusion", - [ # inputs - helper.make_tensor_value_info('x', TensorProto.FLOAT, [1, 1, 5, 5]), - helper.make_tensor_value_info('W', TensorProto.FLOAT, [1, 1, 3, 3]), + [ # inputs + helper.make_tensor_value_info("x", TensorProto.FLOAT, [1, 1, 5, 5]), + helper.make_tensor_value_info("W", TensorProto.FLOAT, [1, 1, 3, 3]), + ], + [ # outputs + helper.make_tensor_value_info("relu_out", TensorProto.FLOAT, [1, 1, 5, 5]), ], - [ # outputs - helper.make_tensor_value_info('relu_out', TensorProto.FLOAT, [1, 1, 5, 5]), - ] ) model = helper.make_model(graph) -onnx.save(model, r'conv_with_padding_relu.onnx') +onnx.save(model, r"conv_with_padding_relu.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/div_mul.py b/onnxruntime/test/testdata/transform/fusion/div_mul.py index db05d5b84a22b..7263a986d40ca 100644 --- a/onnxruntime/test/testdata/transform/fusion/div_mul.py +++ b/onnxruntime/test/testdata/transform/fusion/div_mul.py @@ -1,21 +1,22 @@ -import onnx -from onnx import helper -from onnx import TensorProto, OperatorSetIdProto from enum import Enum +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets + def GenerateModel(model_name): nodes = [ # subgraph @@ -40,30 +41,32 @@ def GenerateModel(model_name): ] inputs = [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('B', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('C', TensorProto.FLOAT16, ['M', 'K']), - helper.make_tensor_value_info('D', TensorProto.INT64, ['M', 'K']), - ] + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT16, ["M", "K"]), + helper.make_tensor_value_info("D", TensorProto.INT64, ["M", "K"]), + ] initializers = [ - helper.make_tensor('float_1', TensorProto.FLOAT, [1], [1.0]), - helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [15360]), # 15360 is the fp16 representation of 1.f - helper.make_tensor('int64_1', TensorProto.INT64, [1], [1]), - ] + helper.make_tensor("float_1", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("float16_1", TensorProto.FLOAT16, [1], [15360]), # 15360 is the fp16 representation of 1.f + helper.make_tensor("int64_1", TensorProto.INT64, [1], [1]), + ] graph = helper.make_graph( nodes, - "DivMul", #name + "DivMul", # name inputs, [ # outputs - helper.make_tensor_value_info('Y', TensorProto.INT64, ['M', 'K']), - helper.make_tensor_value_info('div_5', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info("Y", TensorProto.INT64, ["M", "K"]), + helper.make_tensor_value_info("div_5", TensorProto.FLOAT, ["M", "K"]), ], - initializers) + initializers, + ) model = helper.make_model(graph, **kwargs) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('div_mul.onnx') \ No newline at end of file + GenerateModel("div_mul.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py b/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py index cec03dcbdf94f..6eff2e01ec8bf 100644 --- a/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py +++ b/onnxruntime/test/testdata/transform/fusion/dynamic_quantize_matmul.py @@ -1,17 +1,24 @@ -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum -def GenerateModel(model_name, b_has_zp = True, has_bias = False, bias_ND = False): +import onnx +from onnx import TensorProto, helper + + +def GenerateModel(model_name, b_has_zp=True, has_bias=False, bias_ND=False): mul_output = "Mul_output" if has_bias else "output" - nodes = [ # construct graph - helper.make_node("DynamicQuantizeLinear", ["input"], ["a_quantized", "a_scale", "a_zp"], "DynamicQuantizeLinear"), + nodes = [ # construct graph + helper.make_node( + "DynamicQuantizeLinear", + ["input"], + ["a_quantized", "a_scale", "a_zp"], + "DynamicQuantizeLinear", + ), helper.make_node( "MatMulInteger", - ["a_quantized", "b_quantized", "a_zp", "b_zp"] if b_has_zp else ["a_quantized", "b_quantized", "a_zp"], + ["a_quantized", "b_quantized", "a_zp", "b_zp"] if b_has_zp else ["a_quantized", "b_quantized", "a_zp"], ["matmul_output_int32"], - "MatMulInteger"), + "MatMulInteger", + ), helper.make_node("Mul", ["a_scale", "b_scale"], ["multiplier"], "mul_right"), helper.make_node("Cast", ["matmul_output_int32"], ["matmul_output_float"], "cast", to=1), helper.make_node("Mul", ["matmul_output_float", "multiplier"], [mul_output], "mul_bottom"), @@ -24,37 +31,48 @@ def GenerateModel(model_name, b_has_zp = True, has_bias = False, bias_ND = False if has_bias: if bias_ND: - initializers.extend([ # initializers - helper.make_tensor('bias', TensorProto.FLOAT, [3, 3], [3.0, 4.0, 6.0, 3.0, 4.0, 6.0, 3.0, 4.0, 5.0]), - ]) + initializers.extend( + [ # initializers + helper.make_tensor( + "bias", + TensorProto.FLOAT, + [3, 3], + [3.0, 4.0, 6.0, 3.0, 4.0, 6.0, 3.0, 4.0, 5.0], + ), + ] + ) else: - initializers.extend([ # initializers - helper.make_tensor('bias', TensorProto.FLOAT, [3], [3.0, 4.0, 5.0]), - ]) + initializers.extend( + [ # initializers + helper.make_tensor("bias", TensorProto.FLOAT, [3], [3.0, 4.0, 5.0]), + ] + ) inputs = [ # inputs - helper.make_tensor_value_info('input', TensorProto.FLOAT, [3, 2]), - helper.make_tensor_value_info('b_quantized', TensorProto.UINT8, [2,3]), - helper.make_tensor_value_info('b_scale', TensorProto.FLOAT, [1]), - ] + helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 2]), + helper.make_tensor_value_info("b_quantized", TensorProto.UINT8, [2, 3]), + helper.make_tensor_value_info("b_scale", TensorProto.FLOAT, [1]), + ] if b_has_zp: - inputs.extend([helper.make_tensor_value_info('b_zp', TensorProto.UINT8, [1])]) + inputs.extend([helper.make_tensor_value_info("b_zp", TensorProto.UINT8, [1])]) graph = helper.make_graph( nodes, - "DynamicQuantizeLinear_fusion", #name + "DynamicQuantizeLinear_fusion", # name inputs, [ # outputs - helper.make_tensor_value_info('output', TensorProto.FLOAT, [3, 3]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, 3]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('dynamic_quantize_matmul.onnx') - GenerateModel('dynamic_quantize_matmul_bias.onnx', True, True) - GenerateModel('dynamic_quantize_matmul_bias_b_no_zp.onnx', False, True) - GenerateModel('dynamic_quantize_matmul_bias_ND.onnx', False, True, True) \ No newline at end of file + GenerateModel("dynamic_quantize_matmul.onnx") + GenerateModel("dynamic_quantize_matmul_bias.onnx", True, True) + GenerateModel("dynamic_quantize_matmul_bias_b_no_zp.onnx", False, True) + GenerateModel("dynamic_quantize_matmul_bias_ND.onnx", False, True, True) diff --git a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py index 63254f0fd0960..cc1058c37e31f 100644 --- a/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/embed_layer_norm_gen.py @@ -1,61 +1,197 @@ -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum + +import onnx +from onnx import TensorProto, helper from packaging import version -if version.parse(onnx.__version__) == version.parse('1.8.0'): +if version.parse(onnx.__version__) == version.parse("1.8.0"): opset_version = 13 -elif version.parse(onnx.__version__) == version.parse('1.6.0'): +elif version.parse(onnx.__version__) == version.parse("1.6.0"): opset_version = 11 else: raise RuntimeError("Please pip install onnx==1.8.0 or 1.6.0 before running this script") -def GenerateNodes(model_name, has_cast, suffix=''): + +def GenerateNodes(model_name, has_cast, suffix=""): nodes = [ # LayerNorm subgraph helper.make_node("Shape", ["input_ids" + suffix], ["shape1_out" + suffix], "shape1" + suffix), - helper.make_node("Gather", ["shape1_out" + suffix, "indices_0"], ["gather0_out" + suffix], "gather0" + suffix), - helper.make_node("Unsqueeze", ["gather0_out" + suffix, "axes_0"], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix) if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["gather0_out" + suffix], ["unsqueeze0_out" + suffix], "unsqueeze0" + suffix, axes=[0]), + helper.make_node( + "Gather", + ["shape1_out" + suffix, "indices_0"], + ["gather0_out" + suffix], + "gather0" + suffix, + ), + helper.make_node( + "Unsqueeze", + ["gather0_out" + suffix, "axes_0"], + ["unsqueeze0_out" + suffix], + "unsqueeze0" + suffix, + ) + if opset_version == 13 + else helper.make_node( + "Unsqueeze", + ["gather0_out" + suffix], + ["unsqueeze0_out" + suffix], + "unsqueeze0" + suffix, + axes=[0], + ), helper.make_node("Shape", ["input_ids" + suffix], ["shape2_out" + suffix], "shape2" + suffix), - helper.make_node("Gather", ["shape2_out" + suffix, "indices_1"], ["gather1_out" + suffix], "gather1" + suffix), - helper.make_node("Unsqueeze", ["gather1_out" + suffix, "axes_0"], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix) if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["gather1_out" + suffix], ["unsqueeze1_out" + suffix], "unsqueeze1" + suffix, axes=[0]), - helper.make_node("Concat", ["unsqueeze0_out" + suffix, "unsqueeze1_out" + suffix], ["concat_out" + suffix], - "concat" + suffix, axis=0), - helper.make_node("Cast", ["gather1_out" + suffix], ["cast_out" + suffix], "cast" + suffix, to=7), - helper.make_node("Range", ["start_0", "cast_out" + suffix if has_cast else "gather1_out" + suffix, "delta_1"], - ["range_out" + suffix], "range" + suffix), - helper.make_node("Unsqueeze", ["range_out" + suffix, "axes_0"], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix) if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["range_out" + suffix], ["unsqueeze2_out" + suffix], "unsqueeze2" + suffix, axes=[0]), - helper.make_node("Expand", ["unsqueeze2_out" + suffix, "concat_out" + suffix], ["expand_out" + suffix], - "expand" + suffix), - helper.make_node("Gather", ["pos_embed", "expand_out" + suffix], ["pos_gather_out" + suffix], - "pos_gather" + suffix), - helper.make_node("Gather", ["word_embed", "input_ids" + suffix], ["word_gather_out" + suffix], - "word_gather" + suffix), - helper.make_node("Add", ["word_gather_out" + suffix, "pos_gather_out" + suffix], ["word_add_pos_out" + suffix], - "word_add_pos" + suffix), - helper.make_node("Gather", ["seg_embed", "segment_ids" + suffix], ["seg_gather_out" + suffix], - "seg_gather" + suffix), - helper.make_node("Add", ["word_add_pos_out" + suffix, "seg_gather_out" + suffix], ["add3_out" + suffix], - "add3" + suffix), - helper.make_node("LayerNormalization", ["add3_out" + suffix, "layer_norm_weight", "layer_norm_bias"], - ["layernorm_out" + suffix], - "layernorm" + suffix, - axis=-1, - epsion=0.000009999999747378752), - helper.make_node("Cast", ["input_mask" + suffix], ["mask_cast_out" + suffix], "mask_cast" + suffix, to=6), - helper.make_node("ReduceSum", ["mask_cast_out" + suffix, "axes_1"], ["mask_index_out" + suffix], "mask_index" + suffix, keepdims=0) if opset_version == 13 \ - else helper.make_node("ReduceSum", ["mask_cast_out" + suffix], ["mask_index_out" + suffix], "mask_index" + suffix, axes=[1], keepdims=0), - helper.make_node("Attention", ["layernorm_out" + suffix, "qkv_weights", "qkv_bias", "mask_index_out" + suffix], - ["att_out" + suffix], - "att" + suffix, - domain="com.microsoft", - num_heads=2), - helper.make_node("MatMul", ["att_out" + suffix, "matmul_weight"], ["matmul_out" + suffix], "matmul" + suffix), - helper.make_node("Add", ["matmul_out" + suffix, "add_bias"], ["add_out" + suffix], "add" + suffix), - helper.make_node("Add", ["add_out" + suffix, "layernorm_out" + suffix], ["add2_out" + suffix], "add2" + suffix) + helper.make_node( + "Gather", + ["shape2_out" + suffix, "indices_1"], + ["gather1_out" + suffix], + "gather1" + suffix, + ), + helper.make_node( + "Unsqueeze", + ["gather1_out" + suffix, "axes_0"], + ["unsqueeze1_out" + suffix], + "unsqueeze1" + suffix, + ) + if opset_version == 13 + else helper.make_node( + "Unsqueeze", + ["gather1_out" + suffix], + ["unsqueeze1_out" + suffix], + "unsqueeze1" + suffix, + axes=[0], + ), + helper.make_node( + "Concat", + ["unsqueeze0_out" + suffix, "unsqueeze1_out" + suffix], + ["concat_out" + suffix], + "concat" + suffix, + axis=0, + ), + helper.make_node( + "Cast", + ["gather1_out" + suffix], + ["cast_out" + suffix], + "cast" + suffix, + to=7, + ), + helper.make_node( + "Range", + [ + "start_0", + "cast_out" + suffix if has_cast else "gather1_out" + suffix, + "delta_1", + ], + ["range_out" + suffix], + "range" + suffix, + ), + helper.make_node( + "Unsqueeze", + ["range_out" + suffix, "axes_0"], + ["unsqueeze2_out" + suffix], + "unsqueeze2" + suffix, + ) + if opset_version == 13 + else helper.make_node( + "Unsqueeze", + ["range_out" + suffix], + ["unsqueeze2_out" + suffix], + "unsqueeze2" + suffix, + axes=[0], + ), + helper.make_node( + "Expand", + ["unsqueeze2_out" + suffix, "concat_out" + suffix], + ["expand_out" + suffix], + "expand" + suffix, + ), + helper.make_node( + "Gather", + ["pos_embed", "expand_out" + suffix], + ["pos_gather_out" + suffix], + "pos_gather" + suffix, + ), + helper.make_node( + "Gather", + ["word_embed", "input_ids" + suffix], + ["word_gather_out" + suffix], + "word_gather" + suffix, + ), + helper.make_node( + "Add", + ["word_gather_out" + suffix, "pos_gather_out" + suffix], + ["word_add_pos_out" + suffix], + "word_add_pos" + suffix, + ), + helper.make_node( + "Gather", + ["seg_embed", "segment_ids" + suffix], + ["seg_gather_out" + suffix], + "seg_gather" + suffix, + ), + helper.make_node( + "Add", + ["word_add_pos_out" + suffix, "seg_gather_out" + suffix], + ["add3_out" + suffix], + "add3" + suffix, + ), + helper.make_node( + "LayerNormalization", + ["add3_out" + suffix, "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out" + suffix], + "layernorm" + suffix, + axis=-1, + epsion=0.000009999999747378752, + ), + helper.make_node( + "Cast", + ["input_mask" + suffix], + ["mask_cast_out" + suffix], + "mask_cast" + suffix, + to=6, + ), + helper.make_node( + "ReduceSum", + ["mask_cast_out" + suffix, "axes_1"], + ["mask_index_out" + suffix], + "mask_index" + suffix, + keepdims=0, + ) + if opset_version == 13 + else helper.make_node( + "ReduceSum", + ["mask_cast_out" + suffix], + ["mask_index_out" + suffix], + "mask_index" + suffix, + axes=[1], + keepdims=0, + ), + helper.make_node( + "Attention", + [ + "layernorm_out" + suffix, + "qkv_weights", + "qkv_bias", + "mask_index_out" + suffix, + ], + ["att_out" + suffix], + "att" + suffix, + domain="com.microsoft", + num_heads=2, + ), + helper.make_node( + "MatMul", + ["att_out" + suffix, "matmul_weight"], + ["matmul_out" + suffix], + "matmul" + suffix, + ), + helper.make_node( + "Add", + ["matmul_out" + suffix, "add_bias"], + ["add_out" + suffix], + "add" + suffix, + ), + helper.make_node( + "Add", + ["add_out" + suffix, "layernorm_out" + suffix], + ["add2_out" + suffix], + "add2" + suffix, + ), ] if not has_cast: @@ -66,32 +202,88 @@ def GenerateNodes(model_name, has_cast, suffix=''): def GenerateInitializers(): # hidden_size=4, num_heads=2 initializers = [ # initializers - helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), - helper.make_tensor('start_0', TensorProto.INT64, [], [0]), - helper.make_tensor('delta_1', TensorProto.INT64, [], [1]), - helper.make_tensor('word_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('pos_embed', TensorProto.FLOAT, [4, 4], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12], - [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), - helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), + helper.make_tensor("indices_0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices_1", TensorProto.INT64, [], [1]), + helper.make_tensor("start_0", TensorProto.INT64, [], [0]), + helper.make_tensor("delta_1", TensorProto.INT64, [], [1]), + helper.make_tensor( + "word_embed", + TensorProto.FLOAT, + [2, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor( + "pos_embed", + TensorProto.FLOAT, + [4, 4], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + ), + helper.make_tensor( + "seg_embed", + TensorProto.FLOAT, + [2, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor("layer_norm_weight", TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor("layer_norm_bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("qkv_weights", TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12), + helper.make_tensor( + "qkv_bias", + TensorProto.FLOAT, + [12], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4], + ), + helper.make_tensor( + "matmul_weight", + TensorProto.FLOAT, + [4, 4], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + ), + helper.make_tensor("add_bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("axes_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("axes_1", TensorProto.INT64, [1], [1]), ] return initializers def GenerateMultipleEmbedModel(model_name): - nodes_1 = GenerateNodes(model_name, False, '_1') - nodes_2 = GenerateNodes(model_name, False, '_2') + nodes_1 = GenerateNodes(model_name, False, "_1") + nodes_2 = GenerateNodes(model_name, False, "_2") nodes = nodes_1 + nodes_2 nodes.append(helper.make_node("Add", ["add2_out_1", "add2_out_2"], ["add3_out"], "add3")) @@ -100,19 +292,20 @@ def GenerateMultipleEmbedModel(model_name): graph = helper.make_graph( nodes, - "EmbedLayerNorm_format3", #name + "EmbedLayerNorm_format3", # name [ # inputs - helper.make_tensor_value_info('input_ids_1', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('segment_ids_1', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('input_mask_1', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('input_ids_2', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('segment_ids_2', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('input_mask_2', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info("input_ids_1", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("segment_ids_1", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("input_mask_1", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("input_ids_2", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("segment_ids_2", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("input_mask_2", TensorProto.INT64, ["batch", 3]), ], [ # outputs - helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, ['batch', 3, 4]), + helper.make_tensor_value_info("add3_out", TensorProto.FLOAT, ["batch", 3, 4]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) @@ -126,16 +319,17 @@ def GenerateModel3(model_name, has_cast): graph = helper.make_graph( nodes, - "EmbedLayerNorm_format3", #name + "EmbedLayerNorm_format3", # name [ # inputs - helper.make_tensor_value_info('input_ids', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('segment_ids', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('input_mask', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("segment_ids", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("input_mask", TensorProto.INT64, ["batch", 3]), ], [ # outputs - helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, ['batch', 3, 4]), + helper.make_tensor_value_info("add2_out", TensorProto.FLOAT, ["batch", 3, 4]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) @@ -148,59 +342,169 @@ def GenerateModel5(model_name): sequence_length = 3 nodes = [ - helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0), - helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["word_add_pos_out"], "word_add_pos"), - helper.make_node("Gather", ["seg_embed", "segment_ids"], ["seg_gather_out"], "seg_gather", axis=0), + helper.make_node( + "Gather", + ["word_embed", "input_ids"], + ["word_gather_out"], + "word_gather", + axis=0, + ), + helper.make_node( + "Add", + ["word_gather_out", "pos_gather_out"], + ["word_add_pos_out"], + "word_add_pos", + ), + helper.make_node( + "Gather", + ["seg_embed", "segment_ids"], + ["seg_gather_out"], + "seg_gather", + axis=0, + ), helper.make_node("Add", ["word_add_pos_out", "seg_gather_out"], ["add3_out"], "add3"), - helper.make_node("LayerNormalization", ["add3_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999747378752), + helper.make_node( + "LayerNormalization", + ["add3_out", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), - helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \ - else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), - - helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], - "att", - domain="com.microsoft", - num_heads=attention_heads), + helper.make_node( + "ReduceSum", + ["mask_cast_out", "axes_1"], + ["mask_index_out"], + "mask_index", + keepdims=0, + ) + if opset_version == 13 + else helper.make_node( + "ReduceSum", + ["mask_cast_out"], + ["mask_index_out"], + "mask_index", + axes=[1], + keepdims=0, + ), + helper.make_node( + "Attention", + ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], + ["att_out"], + "att", + domain="com.microsoft", + num_heads=attention_heads, + ), helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"), helper.make_node("Add", ["matmul_out", "add_bias"], ["add_out"], "add"), - helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2") + helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2"), ] qkv_weights = [1.0] * hidden_size * (3 * hidden_size) initializers = [ # initializers - helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('pos_gather_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size], [ - 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 8.0, 7.0, 6.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, - 8.0, 7.0, 6.0 - ]), - helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size], - [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), + helper.make_tensor( + "word_embed", + TensorProto.FLOAT, + [2, hidden_size], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor( + "pos_gather_out", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 8.0, + 7.0, + 6.0, + 1.0, + 2.0, + 3.0, + 4.0, + 5.0, + 6.0, + 7.0, + 8.0, + 9.0, + 8.0, + 7.0, + 6.0, + ], + ), + helper.make_tensor( + "seg_embed", + TensorProto.FLOAT, + [2, hidden_size], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor("layer_norm_weight", TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor("layer_norm_bias", TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor( + "qkv_weights", + TensorProto.FLOAT, + [hidden_size, 3 * hidden_size], + qkv_weights, + ), + helper.make_tensor( + "qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4], + ), + helper.make_tensor( + "matmul_weight", + TensorProto.FLOAT, + [hidden_size, hidden_size], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + ), + helper.make_tensor("add_bias", TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("axes_1", TensorProto.INT64, [1], [1]), ] graph = helper.make_graph( nodes, - "EmbedLayerNorm_format5", #name + "EmbedLayerNorm_format5", # name [ # inputs - helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), - helper.make_tensor_value_info('segment_ids', TensorProto.INT64, [batch_size, sequence_length]), - helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_ids", TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("segment_ids", TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_mask", TensorProto.INT64, [batch_size, sequence_length]), ], [ # outputs - helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + helper.make_tensor_value_info( + "add2_out", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) @@ -210,77 +514,166 @@ def GenerateModel6(model_name): nodes = [ # LayerNorm subgraph helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape1_out", "indices_0"], ["gather0_out"], "gather0"), - helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Unsqueeze", ["gather0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") + if opset_version == 13 + else helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Shape", ["input_ids"], ["shape2_out"], "shape2"), helper.make_node("Gather", ["shape2_out", "indices_1"], ["gather1_out"], "gather1"), - helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out"], ["concat_out"], "concat", axis=0), + helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") + if opset_version == 13 + else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["concat_out", "reshape_init"], ["reshape_out"], "reshape"), helper.make_node("Equal", ["reshape_out", "equal_init"], ["equal_out"], "equal"), helper.make_node("Where", ["equal_out", "where_init", "reshape_out"], ["where_out"], "where"), helper.make_node("Range", ["start_0", "gather1_out", "delta_1"], ["range_out"], "range"), - helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + helper.make_node("Unsqueeze", ["range_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") + if opset_version == 13 + else helper.make_node("Unsqueeze", ["range_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), helper.make_node("Expand", ["unsqueeze2_out", "where_out"], ["expand_out"], "expand"), helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather"), helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather"), - helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["word_add_pos_out"], "word_add_pos"), + helper.make_node( + "Add", + ["word_gather_out", "pos_gather_out"], + ["word_add_pos_out"], + "word_add_pos", + ), helper.make_node("Gather", ["seg_embed", "segment_ids"], ["seg_gather_out"], "seg_gather"), helper.make_node("Add", ["word_add_pos_out", "seg_gather_out"], ["add3_out"], "add3"), - helper.make_node("LayerNormalization", ["add3_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999747378752), + helper.make_node( + "LayerNormalization", + ["add3_out", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), - helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \ - else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), - helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], - "att", - domain="com.microsoft", - num_heads=2), + helper.make_node( + "ReduceSum", + ["mask_cast_out", "axes_1"], + ["mask_index_out"], + "mask_index", + keepdims=0, + ) + if opset_version == 13 + else helper.make_node( + "ReduceSum", + ["mask_cast_out"], + ["mask_index_out"], + "mask_index", + axes=[1], + keepdims=0, + ), + helper.make_node( + "Attention", + ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], + ["att_out"], + "att", + domain="com.microsoft", + num_heads=2, + ), helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"), helper.make_node("Add", ["matmul_out", "add_bias"], ["add_out"], "add"), - helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2") + helper.make_node("Add", ["add_out", "layernorm_out"], ["add2_out"], "add2"), ] # hidden_size=4, num_heads=2, max_seq_length=3 initializers = [ # initializers - helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), - helper.make_tensor('start_0', TensorProto.INT64, [], [0]), - helper.make_tensor('delta_1', TensorProto.INT64, [], [1]), - helper.make_tensor('word_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('pos_embed', TensorProto.FLOAT, [4, 4], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('seg_embed', TensorProto.FLOAT, [2, 4], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [12], [0.1] * 12), - helper.make_tensor('matmul_weight', TensorProto.FLOAT, [4, 4], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('add_bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('reshape_init', TensorProto.INT64, [1], [-1]), - helper.make_tensor('equal_init', TensorProto.INT64, [2], [-1, -1]), - helper.make_tensor('where_init', TensorProto.INT64, [2], [1, 1]), - helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), - helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), + helper.make_tensor("indices_0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices_1", TensorProto.INT64, [], [1]), + helper.make_tensor("start_0", TensorProto.INT64, [], [0]), + helper.make_tensor("delta_1", TensorProto.INT64, [], [1]), + helper.make_tensor( + "word_embed", + TensorProto.FLOAT, + [2, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor( + "pos_embed", + TensorProto.FLOAT, + [4, 4], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + ), + helper.make_tensor( + "seg_embed", + TensorProto.FLOAT, + [2, 4], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor("layer_norm_weight", TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor("layer_norm_bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("qkv_weights", TensorProto.FLOAT, [4, 12], [0.1] * 4 * 12), + helper.make_tensor("qkv_bias", TensorProto.FLOAT, [12], [0.1] * 12), + helper.make_tensor( + "matmul_weight", + TensorProto.FLOAT, + [4, 4], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + ), + helper.make_tensor("add_bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("reshape_init", TensorProto.INT64, [1], [-1]), + helper.make_tensor("equal_init", TensorProto.INT64, [2], [-1, -1]), + helper.make_tensor("where_init", TensorProto.INT64, [2], [1, 1]), + helper.make_tensor("axes_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("axes_1", TensorProto.INT64, [1], [1]), ] graph = helper.make_graph( nodes, - "EmbedLayerNorm_format6", #name + "EmbedLayerNorm_format6", # name [ # inputs - helper.make_tensor_value_info('input_ids', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('segment_ids', TensorProto.INT64, ['batch', 3]), - helper.make_tensor_value_info('input_mask', TensorProto.INT64, ['batch', 3]), + helper.make_tensor_value_info("input_ids", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("segment_ids", TensorProto.INT64, ["batch", 3]), + helper.make_tensor_value_info("input_mask", TensorProto.INT64, ["batch", 3]), ], [ # outputs - helper.make_tensor_value_info('add2_out', TensorProto.FLOAT, ['batch', 3, 4]), + helper.make_tensor_value_info("add2_out", TensorProto.FLOAT, ["batch", 3, 4]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) @@ -290,22 +683,62 @@ def GenerateInitializers2(hidden_size): qkv_weights = [1.0] * hidden_size * (3 * hidden_size) initializers = [ # initializers - helper.make_tensor('word_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('pos_embed', TensorProto.FLOAT, [2, hidden_size], [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('indices_0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices_1', TensorProto.INT64, [], [1]), - helper.make_tensor('start', TensorProto.INT64, [], [0]), - helper.make_tensor('delta', TensorProto.INT64, [], [1]), - helper.make_tensor('layer_norm_weight', TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('layer_norm_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('qkv_weights', TensorProto.FLOAT, [hidden_size, 3 * hidden_size], qkv_weights), - helper.make_tensor('qkv_bias', TensorProto.FLOAT, [3 * hidden_size], - [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('matmul_weight', TensorProto.FLOAT, [hidden_size, hidden_size], - [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('add_bias', TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), - helper.make_tensor('axes_0', TensorProto.INT64, [1], [0]), - helper.make_tensor('axes_1', TensorProto.INT64, [1], [1]), + helper.make_tensor( + "word_embed", + TensorProto.FLOAT, + [2, hidden_size], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor( + "pos_embed", + TensorProto.FLOAT, + [2, hidden_size], + [1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0, 4.0], + ), + helper.make_tensor("indices_0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices_1", TensorProto.INT64, [], [1]), + helper.make_tensor("start", TensorProto.INT64, [], [0]), + helper.make_tensor("delta", TensorProto.INT64, [], [1]), + helper.make_tensor("layer_norm_weight", TensorProto.FLOAT, [hidden_size], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor("layer_norm_bias", TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor( + "qkv_weights", + TensorProto.FLOAT, + [hidden_size, 3 * hidden_size], + qkv_weights, + ), + helper.make_tensor( + "qkv_bias", + TensorProto.FLOAT, + [3 * hidden_size], + [0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4, 0.1, 0.2, 0.3, 0.4], + ), + helper.make_tensor( + "matmul_weight", + TensorProto.FLOAT, + [hidden_size, hidden_size], + [ + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + 1.0, + 2.0, + 3.0, + 4.0, + ], + ), + helper.make_tensor("add_bias", TensorProto.FLOAT, [hidden_size], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("axes_0", TensorProto.INT64, [1], [0]), + helper.make_tensor("axes_1", TensorProto.INT64, [1], [1]), ] return initializers @@ -313,30 +746,65 @@ def GenerateInitializers2(hidden_size): def GenerateNodes2(attention_heads): nodes = [ - helper.make_node("Gather", ["word_embed", "input_ids"], ["word_gather_out"], "word_gather", axis=0), + helper.make_node( + "Gather", + ["word_embed", "input_ids"], + ["word_gather_out"], + "word_gather", + axis=0, + ), helper.make_node("Shape", ["input_ids"], ["shape0_out"], "shape0"), helper.make_node("Gather", ["shape0_out", "indices_1"], ["gather0_out"], "gather0"), helper.make_node("Range", ["start", "gather0_out", "delta"], ["range0_out"], "range0"), - helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), + helper.make_node("Unsqueeze", ["range0_out", "axes_0"], ["unsqueeze0_out"], "unsqueeze0") + if opset_version == 13 + else helper.make_node("Unsqueeze", ["range0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Shape", ["input_ids"], ["shape1_out"], "shape1"), helper.make_node("Expand", ["unsqueeze0_out", "shape1_out"], ["expand_out"], "expand"), - helper.make_node("Gather", ["pos_embed", "expand_out"], ["pos_gather_out"], "pos_gather", axis=0), + helper.make_node( + "Gather", + ["pos_embed", "expand_out"], + ["pos_gather_out"], + "pos_gather", + axis=0, + ), helper.make_node("Add", ["word_gather_out", "pos_gather_out"], ["add1_out"], "add1"), - helper.make_node("LayerNormalization", ["add1_out", "layer_norm_weight", "layer_norm_bias"], ["layernorm_out"], - "layernorm", - axis=-1, - epsion=0.000009999999747378752), + helper.make_node( + "LayerNormalization", + ["add1_out", "layer_norm_weight", "layer_norm_bias"], + ["layernorm_out"], + "layernorm", + axis=-1, + epsion=0.000009999999747378752, + ), helper.make_node("Cast", ["input_mask"], ["mask_cast_out"], "mask_cast", to=6), - helper.make_node("ReduceSum", ["mask_cast_out", "axes_1"], ["mask_index_out"], "mask_index", keepdims=0) if opset_version == 13 \ - else helper.make_node("ReduceSum", ["mask_cast_out"], ["mask_index_out"], "mask_index", axes=[1], keepdims=0), - helper.make_node("Attention", ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], ["att_out"], - "att", - domain="com.microsoft", - num_heads=attention_heads), + helper.make_node( + "ReduceSum", + ["mask_cast_out", "axes_1"], + ["mask_index_out"], + "mask_index", + keepdims=0, + ) + if opset_version == 13 + else helper.make_node( + "ReduceSum", + ["mask_cast_out"], + ["mask_index_out"], + "mask_index", + axes=[1], + keepdims=0, + ), + helper.make_node( + "Attention", + ["layernorm_out", "qkv_weights", "qkv_bias", "mask_index_out"], + ["att_out"], + "att", + domain="com.microsoft", + num_heads=attention_heads, + ), helper.make_node("MatMul", ["att_out", "matmul_weight"], ["matmul_out"], "matmul"), helper.make_node("Add", ["matmul_out", "add_bias"], ["add2_out"], "add2"), - helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3") + helper.make_node("Add", ["add2_out", "layernorm_out"], ["add3_out"], "add3"), ] return nodes @@ -354,15 +822,20 @@ def GenerateModel7(model_name): graph = helper.make_graph( nodes, - "EmbedLayerNorm_format7", #name + "EmbedLayerNorm_format7", # name [ # inputs - helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), - helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_ids", TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_mask", TensorProto.INT64, [batch_size, sequence_length]), ], [ # outputs - helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + helper.make_tensor_value_info( + "add3_out", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) @@ -381,7 +854,7 @@ def GenerateModel8(model_name): new_nodes = [ helper.make_node("Shape", ["input_ids"], ["shape_out"], "shape"), helper.make_node("Gather", ["shape_out", "indices_1"], ["gather0_out"], "gather0"), - helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand") + helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"), ] nodes = nodes + new_nodes @@ -389,15 +862,20 @@ def GenerateModel8(model_name): graph = helper.make_graph( nodes, - "EmbedLayerNorm_format8", #name + "EmbedLayerNorm_format8", # name [ # inputs - helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), - helper.make_tensor_value_info('input_mask', TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_ids", TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_mask", TensorProto.INT64, [batch_size, sequence_length]), ], [ # outputs - helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + helper.make_tensor_value_info( + "add3_out", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) @@ -420,14 +898,26 @@ def GenerateModel9(model_name): helper.make_node("Expand", ["unsqueeze0_out", "shape_out"], ["expand_out"], "expand"), helper.make_node("Gather", ["shape_out", "indices_0"], ["gather1_out"], "gather1"), helper.make_node("Gather", ["shape_out", "indices_1"], ["gather2_out"], "gather2"), - helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") if opset_version == 13 \ - else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), - helper.make_node("Concat", ["unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), - helper.make_node('ConstantOfShape', ['concat_out'], ['constant_of_shape_out'], - "constant_of_shape", - value=helper.make_tensor('mask_shape', TensorProto.FLOAT, [1], [1.0])), + helper.make_node("Unsqueeze", ["gather1_out", "axes_0"], ["unsqueeze1_out"], "unsqueeze1") + if opset_version == 13 + else helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), + helper.make_node("Unsqueeze", ["gather2_out", "axes_0"], ["unsqueeze2_out"], "unsqueeze2") + if opset_version == 13 + else helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), + helper.make_node( + "Concat", + ["unsqueeze1_out", "unsqueeze2_out"], + ["concat_out"], + "concat", + axis=0, + ), + helper.make_node( + "ConstantOfShape", + ["concat_out"], + ["constant_of_shape_out"], + "constant_of_shape", + value=helper.make_tensor("mask_shape", TensorProto.FLOAT, [1], [1.0]), + ), helper.make_node("Cast", ["constant_of_shape_out"], ["mask_cast_out"], "mask_cast", to=6), ] nodes = nodes + new_nodes @@ -436,33 +926,39 @@ def GenerateModel9(model_name): graph = helper.make_graph( nodes, - "EmbedLayerNorm_format9", #name + "EmbedLayerNorm_format9", # name [ # inputs - helper.make_tensor_value_info('input_ids', TensorProto.INT64, [batch_size, sequence_length]), + helper.make_tensor_value_info("input_ids", TensorProto.INT64, [batch_size, sequence_length]), ], [ # outputs - helper.make_tensor_value_info('add3_out', TensorProto.FLOAT, [batch_size, sequence_length, hidden_size]), + helper.make_tensor_value_info( + "add3_out", + TensorProto.FLOAT, + [batch_size, sequence_length, hidden_size], + ), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) + if opset_version == 11: - GenerateModel3('embed_layer_norm_format3.onnx', True) - GenerateModel3('embed_layer_norm_format3_no_cast.onnx', False) - GenerateModel5('embed_layer_norm_format5.onnx') - GenerateModel6('embed_layer_norm_format6.onnx') - GenerateModel7('embed_layer_norm_format7.onnx') #distilbert - GenerateModel8('embed_layer_norm_format8.onnx') #distilbert & shape nodes integration with input mask - GenerateModel9('embed_layer_norm_format9.onnx') #distilbert & shape nodes integration without input mask - GenerateMultipleEmbedModel('embed_layer_norm_multiple.onnx') + GenerateModel3("embed_layer_norm_format3.onnx", True) + GenerateModel3("embed_layer_norm_format3_no_cast.onnx", False) + GenerateModel5("embed_layer_norm_format5.onnx") + GenerateModel6("embed_layer_norm_format6.onnx") + GenerateModel7("embed_layer_norm_format7.onnx") # distilbert + GenerateModel8("embed_layer_norm_format8.onnx") # distilbert & shape nodes integration with input mask + GenerateModel9("embed_layer_norm_format9.onnx") # distilbert & shape nodes integration without input mask + GenerateMultipleEmbedModel("embed_layer_norm_multiple.onnx") else: - GenerateModel3('embed_layer_norm_format3_opset13.onnx', True) - GenerateModel3('embed_layer_norm_format3_no_cast_opset13.onnx', False) - GenerateModel5('embed_layer_norm_format5_opset13.onnx') - GenerateModel6('embed_layer_norm_format6_opset13.onnx') - GenerateModel7('embed_layer_norm_format7_opset13.onnx') - GenerateModel8('embed_layer_norm_format8_opset13.onnx') - GenerateModel9('embed_layer_norm_format9_opset13.onnx') - GenerateMultipleEmbedModel('embed_layer_norm_multiple_opset13.onnx') \ No newline at end of file + GenerateModel3("embed_layer_norm_format3_opset13.onnx", True) + GenerateModel3("embed_layer_norm_format3_no_cast_opset13.onnx", False) + GenerateModel5("embed_layer_norm_format5_opset13.onnx") + GenerateModel6("embed_layer_norm_format6_opset13.onnx") + GenerateModel7("embed_layer_norm_format7_opset13.onnx") + GenerateModel8("embed_layer_norm_format8_opset13.onnx") + GenerateModel9("embed_layer_norm_format9_opset13.onnx") + GenerateMultipleEmbedModel("embed_layer_norm_multiple_opset13.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu.py index c7f964bbbeb55..aaaffa4ab398a 100644 --- a/onnxruntime/test/testdata/transform/fusion/fast_gelu.py +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu.py @@ -1,16 +1,14 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # Gelu formula: x * 0.5 * (1.0 + tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) has_bias = True # change it to True to generate fast_gelu_with_bias.onnx gelu_use_graph_input = True # change it to False to let Gelu don't have graph inputs as inputs. -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", 64]) bias_np_vals = (0.01 * np.arange(64)).astype(np.float32).reshape((64)) bias_initializer = numpy_helper.from_array(bias_np_vals, "input_bias") @@ -33,42 +31,42 @@ nodes = [] gelu_input = "input" if not gelu_use_graph_input: - leading_identity = helper.make_node('Identity', [gelu_input], ['identity_leading'], name="identity_leading") + leading_identity = helper.make_node("Identity", [gelu_input], ["identity_leading"], name="identity_leading") gelu_input = "identity_leading" nodes.append(leading_identity) mul_input_name = gelu_input if has_bias: - add0 = helper.make_node('Add', [gelu_input, bias_initializer.name], ['add0'], name="add0") + add0 = helper.make_node("Add", [gelu_input, bias_initializer.name], ["add0"], name="add0") mul_input_name = "add0" nodes.append(add0) -mul1 = helper.make_node('Mul', [mul_input_name, a_weight_initializer.name], ['mul1'], name="mul1") +mul1 = helper.make_node("Mul", [mul_input_name, a_weight_initializer.name], ["mul1"], name="mul1") nodes.append(mul1) -mul2 = helper.make_node('Mul', [mul_input_name, 'mul1'], ['mul2'], name="mul2") +mul2 = helper.make_node("Mul", [mul_input_name, "mul1"], ["mul2"], name="mul2") nodes.append(mul2) -add1 = helper.make_node('Add', ['mul2', a_bias_initializer.name], ['add1'], name="add1") +add1 = helper.make_node("Add", ["mul2", a_bias_initializer.name], ["add1"], name="add1") nodes.append(add1) -mul3 = helper.make_node('Mul', [mul_input_name, b_weight_initializer.name], ['mul3'], name="mul3") +mul3 = helper.make_node("Mul", [mul_input_name, b_weight_initializer.name], ["mul3"], name="mul3") nodes.append(mul3) -mul4 = helper.make_node('Mul', ['mul3', 'add1'], ['mul4'], name="mul4") +mul4 = helper.make_node("Mul", ["mul3", "add1"], ["mul4"], name="mul4") nodes.append(mul4) -tanh = helper.make_node('Tanh', ['mul4'], ['tanh'], name="tanh") +tanh = helper.make_node("Tanh", ["mul4"], ["tanh"], name="tanh") nodes.append(tanh) -add2 = helper.make_node('Add', ['tanh', b_bias_initializer.name], ['add2'], name="add2") +add2 = helper.make_node("Add", ["tanh", b_bias_initializer.name], ["add2"], name="add2") nodes.append(add2) -mul5 = helper.make_node('Mul', [mul_input_name, c_weight_initializer.name], ['mul5'], name="mul5") +mul5 = helper.make_node("Mul", [mul_input_name, c_weight_initializer.name], ["mul5"], name="mul5") nodes.append(mul5) -mul6 = helper.make_node('Mul', ['mul5', 'add2'], ['mul6'], name="mul6") -ending_identity = helper.make_node('Identity', ['mul6'], ['output'], name="identity_ending") +mul6 = helper.make_node("Mul", ["mul5", "add2"], ["mul6"], name="mul6") +ending_identity = helper.make_node("Identity", ["mul6"], ["output"], name="identity_ending") nodes.extend([mul6, ending_identity]) initializers = [] @@ -76,9 +74,16 @@ initializers = [bias_initializer] initializers.extend( - [a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer]) + [ + a_weight_initializer, + a_bias_initializer, + b_weight_initializer, + b_bias_initializer, + c_weight_initializer, + ] +) # Create the graph (GraphProto) -graph_def = helper.make_graph(nodes, 'test-model', [X], [Y], initializers) +graph_def = helper.make_graph(nodes, "test-model", [X], [Y], initializers) opsets = [] onnxdomain = OperatorSetIdProto() @@ -94,7 +99,7 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) file_name = "fast_gelu" if has_bias: diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py index 0f9490dbc457f..5ff752afa7e6a 100644 --- a/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu2.py @@ -1,15 +1,13 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))) has_bias = False # change it to True to generate fast_gelu_openai_with_bias.onnx gelu_use_graph_input = True # change it to False to let Gelu don't have graph inputs/outputs as inputs/outputs. -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", 64]) bias_np_vals = (0.01 * np.arange(64)).astype(np.float32).reshape((64)) bias_initializer = numpy_helper.from_array(bias_np_vals, "input_bias") @@ -32,39 +30,39 @@ nodes = [] gelu_input = "input" if not gelu_use_graph_input: - leading_identity = helper.make_node('Identity', [gelu_input], ['identity_leading'], name="identity_leading") + leading_identity = helper.make_node("Identity", [gelu_input], ["identity_leading"], name="identity_leading") gelu_input = "identity_leading" nodes.append(leading_identity) mul_input_name = gelu_input if has_bias: - add0 = helper.make_node('Add', [gelu_input, bias_initializer.name], ['add0'], name="add0") + add0 = helper.make_node("Add", [gelu_input, bias_initializer.name], ["add0"], name="add0") mul_input_name = "add0" nodes.append(add0) -pow1 = helper.make_node('Pow', [mul_input_name, pow_initializer.name], ['pow1'], name="pow1") +pow1 = helper.make_node("Pow", [mul_input_name, pow_initializer.name], ["pow1"], name="pow1") nodes.append(pow1) -mul1 = helper.make_node('Mul', ['pow1', a_weight_initializer.name], ['mul1'], name="mul1") +mul1 = helper.make_node("Mul", ["pow1", a_weight_initializer.name], ["mul1"], name="mul1") nodes.append(mul1) -add1 = helper.make_node('Add', [mul_input_name, "mul1"], ['add1'], name="add1") +add1 = helper.make_node("Add", [mul_input_name, "mul1"], ["add1"], name="add1") nodes.append(add1) -mul2 = helper.make_node('Mul', ['add1', b_weight_initializer.name], ['mul2'], name="mul2") +mul2 = helper.make_node("Mul", ["add1", b_weight_initializer.name], ["mul2"], name="mul2") nodes.append(mul2) -tanh = helper.make_node('Tanh', ['mul2'], ['tanh'], name="tanh") +tanh = helper.make_node("Tanh", ["mul2"], ["tanh"], name="tanh") nodes.append(tanh) -add2 = helper.make_node('Add', ['tanh', b_bias_initializer.name], ['add2'], name="add2") +add2 = helper.make_node("Add", ["tanh", b_bias_initializer.name], ["add2"], name="add2") nodes.append(add2) -mul5 = helper.make_node('Mul', [mul_input_name, c_weight_initializer.name], ['mul5'], name="mul5") +mul5 = helper.make_node("Mul", [mul_input_name, c_weight_initializer.name], ["mul5"], name="mul5") nodes.append(mul5) -mul6 = helper.make_node('Mul', ['mul5', 'add2'], ['mul6'], name="mul6") -ending_identity = helper.make_node('Identity', ['mul6'], ['output'], name="ending_identity") +mul6 = helper.make_node("Mul", ["mul5", "add2"], ["mul6"], name="mul6") +ending_identity = helper.make_node("Identity", ["mul6"], ["output"], name="ending_identity") nodes.extend([mul6, ending_identity]) initializers = [] @@ -72,9 +70,16 @@ initializers = [bias_initializer] initializers.extend( - [pow_initializer, a_weight_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer]) + [ + pow_initializer, + a_weight_initializer, + b_weight_initializer, + b_bias_initializer, + c_weight_initializer, + ] +) # Create the graph (GraphProto) -graph_def = helper.make_graph(nodes, 'test-model', [X], [Y], initializers) +graph_def = helper.make_graph(nodes, "test-model", [X], [Y], initializers) opsets = [] onnxdomain = OperatorSetIdProto() @@ -90,7 +95,7 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) file_name = "fast_gelu2" if has_bias: diff --git a/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py index 30cd332385998..5220751a3e364 100644 --- a/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py +++ b/onnxruntime/test/testdata/transform/fusion/fast_gelu3_with_casts.py @@ -1,13 +1,11 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper # Gelu formula: x * 0.5 * (1.0 + tanh((sqrt(2 / pi) * (x + 0.044715 * pow(x, 3))))) -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", 64]) pow_np_vals = np.asarray([3]).astype(np.float32).reshape(()) pow_initializer = numpy_helper.from_array(pow_np_vals, "pow_init") @@ -26,52 +24,59 @@ nodes = [] gelu_input = "input" -leading_identity = helper.make_node('Identity', [gelu_input], ['identity_leading'], name="identity_leading") +leading_identity = helper.make_node("Identity", [gelu_input], ["identity_leading"], name="identity_leading") gelu_input = "identity_leading" nodes.append(leading_identity) mul_input_name = gelu_input -cast1 = helper.make_node('Cast', [mul_input_name], ['cast1'], name='cast1', to=1) +cast1 = helper.make_node("Cast", [mul_input_name], ["cast1"], name="cast1", to=1) nodes.append(cast1) -pow1 = helper.make_node('Pow', ['cast1', pow_initializer.name], ['pow1'], name="pow1") +pow1 = helper.make_node("Pow", ["cast1", pow_initializer.name], ["pow1"], name="pow1") nodes.append(pow1) -mul1 = helper.make_node('Mul', ['pow1', a_weight_initializer.name], ['mul1'], name="mul1") +mul1 = helper.make_node("Mul", ["pow1", a_weight_initializer.name], ["mul1"], name="mul1") nodes.append(mul1) -cast2 = helper.make_node('Cast', [mul_input_name], ['cast2'], name='cast2', to=1) +cast2 = helper.make_node("Cast", [mul_input_name], ["cast2"], name="cast2", to=1) nodes.append(cast2) -add1 = helper.make_node('Add', ['mul1', 'cast2'], ['add1'], name="add1") +add1 = helper.make_node("Add", ["mul1", "cast2"], ["add1"], name="add1") nodes.append(add1) -mul2 = helper.make_node('Mul', ['add1', b_weight_initializer.name], ['mul2'], name="mul2") +mul2 = helper.make_node("Mul", ["add1", b_weight_initializer.name], ["mul2"], name="mul2") nodes.append(mul2) -tanh = helper.make_node('Tanh', ['mul2'], ['tanh'], name="tanh") +tanh = helper.make_node("Tanh", ["mul2"], ["tanh"], name="tanh") nodes.append(tanh) -add2 = helper.make_node('Add', ['tanh', b_bias_initializer.name], ['add2'], name="add2") +add2 = helper.make_node("Add", ["tanh", b_bias_initializer.name], ["add2"], name="add2") nodes.append(add2) -mul5 = helper.make_node('Mul', [mul_input_name, c_weight_initializer.name], ['mul5'], name="mul5") +mul5 = helper.make_node("Mul", [mul_input_name, c_weight_initializer.name], ["mul5"], name="mul5") nodes.append(mul5) -cast3 = helper.make_node('Cast', ['mul5'], ['cast3'], name='cast3', to=1) +cast3 = helper.make_node("Cast", ["mul5"], ["cast3"], name="cast3", to=1) nodes.append(cast3) -mul6 = helper.make_node('Mul', ['cast3', 'add2'], ['mul6'], name="mul6") -ending_identity = helper.make_node('Identity', ['mul6'], ['output'], name="ending_identity") +mul6 = helper.make_node("Mul", ["cast3", "add2"], ["mul6"], name="mul6") +ending_identity = helper.make_node("Identity", ["mul6"], ["output"], name="ending_identity") nodes.extend([mul6, ending_identity]) initializers = [] initializers.extend( - [pow_initializer, a_weight_initializer, b_weight_initializer, b_bias_initializer, c_weight_initializer]) + [ + pow_initializer, + a_weight_initializer, + b_weight_initializer, + b_bias_initializer, + c_weight_initializer, + ] +) # Create the graph (GraphProto) -graph_def = helper.make_graph(nodes, 'test-model', [X], [Y], initializers) +graph_def = helper.make_graph(nodes, "test-model", [X], [Y], initializers) opsets = [] onnxdomain = OperatorSetIdProto() @@ -87,6 +92,6 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) onnx.save(model_def, "fast_gelu3_with_casts.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/gelu_gen.py b/onnxruntime/test/testdata/transform/fusion/gelu_gen.py index b0e905133b37f..45f546a04635e 100644 --- a/onnxruntime/test/testdata/transform/fusion/gelu_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/gelu_gen.py @@ -1,8 +1,7 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper + """ Generate test model for Gelu subgraph pattern 2: +------------------------------------+ @@ -14,12 +13,12 @@ has_bias = True # change it to True to generate gelu_format2_*_with_bias.onnx gelu_use_graph_input = False # change it to False to let Gelu don't have graph inputs as inputs. -node_has_graph_output = True # change it to False to let Gelu don't have graph output +node_has_graph_output = True # change it to False to let Gelu don't have graph output switch_order = True # switch order of inputs for Mul and Add -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", 64]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", 64]) -Z = helper.make_tensor_value_info('div', TensorProto.FLOAT, ["batch", "seqlen", 64]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", 64]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", 64]) +Z = helper.make_tensor_value_info("div", TensorProto.FLOAT, ["batch", "seqlen", 64]) value = (0.01 * np.arange(64)).astype(np.float32).reshape((64)) bias_initializer = numpy_helper.from_array(value, "input_bias") @@ -36,37 +35,52 @@ nodes = [] gelu_input = "input" if not gelu_use_graph_input: - leading_identity = helper.make_node('Identity', [gelu_input], ['identity_leading'], name="identity_leading") + leading_identity = helper.make_node("Identity", [gelu_input], ["identity_leading"], name="identity_leading") gelu_input = "identity_leading" nodes.append(leading_identity) gelu_root = gelu_input if has_bias: add0 = helper.make_node( - 'Add', [gelu_input, bias_initializer.name] if switch_order else [bias_initializer.name, gelu_input], ['add0'], - name="add0_node") + "Add", + [gelu_input, bias_initializer.name] if switch_order else [bias_initializer.name, gelu_input], + ["add0"], + name="add0_node", + ) gelu_root = "add0" nodes.append(add0) -div = helper.make_node('Div', [gelu_root, initializer_sqrt_2.name], ['div'], name="div_node") +div = helper.make_node("Div", [gelu_root, initializer_sqrt_2.name], ["div"], name="div_node") nodes.append(div) -erf = helper.make_node('Erf', ['div'], ['erf'], name="erf_node") +erf = helper.make_node("Erf", ["div"], ["erf"], name="erf_node") nodes.append(erf) -add1 = helper.make_node('Add', ['erf', initializer_1.name] if switch_order else [initializer_1.name, 'erf'], ['add1'], - name="add1") +add1 = helper.make_node( + "Add", + ["erf", initializer_1.name] if switch_order else [initializer_1.name, "erf"], + ["add1"], + name="add1", +) nodes.append(add1) -mul = helper.make_node('Mul', [gelu_root, 'add1'] if switch_order else ['add1', gelu_root], ['mul'], name="mul_node") +mul = helper.make_node( + "Mul", + [gelu_root, "add1"] if switch_order else ["add1", gelu_root], + ["mul"], + name="mul_node", +) nodes.append(mul) -mul2 = helper.make_node('Mul', ['mul', initializer_0_5.name] if switch_order else [initializer_0_5.name, 'mul'], - ['mul2'], - name="mul2_node") +mul2 = helper.make_node( + "Mul", + ["mul", initializer_0_5.name] if switch_order else [initializer_0_5.name, "mul"], + ["mul2"], + name="mul2_node", +) nodes.append(mul2) -ending_identity = helper.make_node('Identity', ['mul2'], ['output'], name="identity_ending") +ending_identity = helper.make_node("Identity", ["mul2"], ["output"], name="identity_ending") nodes.append(ending_identity) initializers = [] @@ -76,7 +90,7 @@ initializers.extend([initializer_sqrt_2, initializer_1, initializer_0_5]) # Create the graph (GraphProto) -graph_def = helper.make_graph(nodes, 'gelu_pattern_2', [X], [Y, Z] if node_has_graph_output else [Y], initializers) +graph_def = helper.make_graph(nodes, "gelu_pattern_2", [X], [Y, Z] if node_has_graph_output else [Y], initializers) opsets = [] onnxdomain = OperatorSetIdProto() @@ -92,7 +106,7 @@ kwargs = {} kwargs["opset_imports"] = opsets -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) file_name = "gelu_format2_0" if switch_order else "gelu_format2_1" if has_bias: diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_sum_gen.py b/onnxruntime/test/testdata/transform/fusion/gemm_sum_gen.py index 7ae435e99a42b..59be81c557f31 100644 --- a/onnxruntime/test/testdata/transform/fusion/gemm_sum_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/gemm_sum_gen.py @@ -1,7 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -14,17 +12,14 @@ def save(model_path, nodes, inputs, outputs, initializers): - graph = helper.make_graph( - nodes, - "GemmSumTest", - inputs, outputs, initializers) + graph = helper.make_graph(nodes, "GemmSumTest", inputs, outputs, initializers) - model = helper.make_model( - graph, opset_imports=opsets, producer_name="onnxruntime-test") + model = helper.make_model(graph, opset_imports=opsets, producer_name="onnxruntime-test") print(model_path) onnx.save(model, model_path) + def gen_gemm_sum_basic(model_path): nodes = [ helper.make_node(op_type="Gemm", inputs=["A", "B"], outputs=["tp0"]), @@ -32,35 +27,41 @@ def gen_gemm_sum_basic(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['M', 'N']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["M", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_attributes(model_path): nodes = [ - helper.make_node(op_type="Gemm", inputs=["A", "B"], outputs=["tp0"], alpha=3.5, beta=6.25, transA=False, transB=True), + helper.make_node( + op_type="Gemm", + inputs=["A", "B"], + outputs=["tp0"], + alpha=3.5, + beta=6.25, + transA=False, + transB=True, + ), helper.make_node(op_type="Sum", inputs=["tp0", "C"], outputs=["output"]), ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['M', 'N']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["M", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_internal_nodes(model_path): nodes = [ helper.make_node(op_type="Identity", inputs=["A"], outputs=["tp0"]), @@ -72,17 +73,16 @@ def gen_gemm_sum_internal_nodes(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['M', 'N']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["M", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_no_fusion_c_used(model_path): nodes = [ helper.make_node(op_type="Gemm", inputs=["A", "B", "C"], outputs=["tp0"]), @@ -90,18 +90,17 @@ def gen_gemm_sum_no_fusion_c_used(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['M', 'N']), - helper.make_tensor_value_info("D", TensorProto.FLOAT, ['M', 'N']), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info("D", TensorProto.FLOAT, ["M", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_no_fusion_sum_multiple_inputs(model_path): nodes = [ helper.make_node(op_type="Gemm", inputs=["A", "B"], outputs=["tp0"]), @@ -109,18 +108,17 @@ def gen_gemm_sum_no_fusion_sum_multiple_inputs(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['M', 'N']), - helper.make_tensor_value_info("D", TensorProto.FLOAT, ['M', 'N']), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info("D", TensorProto.FLOAT, ["M", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_fusion_broadcast(model_path): nodes = [ helper.make_node(op_type="Gemm", inputs=["A", "B"], outputs=["tp0"]), @@ -128,17 +126,16 @@ def gen_gemm_sum_fusion_broadcast(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, 'N']), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_no_fusion_broadcast_failure(model_path): nodes = [ helper.make_node(op_type="Gemm", inputs=["A", "B"], outputs=["tp0"]), @@ -146,18 +143,17 @@ def gen_gemm_sum_no_fusion_broadcast_failure(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), # should work with multidirectional broadcast as second argument, but not unidirectional. - helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, 'M', 'N']), + helper.make_tensor_value_info("C", TensorProto.FLOAT, [1, "M", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, "M", "N"])] save(model_path, nodes, inputs, outputs, initializers=[]) + def gen_gemm_sum_no_fusion_original_gemm_output_used(model_path): nodes = [ helper.make_node(op_type="Gemm", inputs=["A", "B"], outputs=["tp0"]), @@ -165,18 +161,19 @@ def gen_gemm_sum_no_fusion_original_gemm_output_used(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['M', 'N']), + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["M", "N"]), ] outputs = [ - helper.make_tensor_value_info("tp0", TensorProto.FLOAT, ['M', 'N']), - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) + helper.make_tensor_value_info("tp0", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"]), ] save(model_path, nodes, inputs, outputs, initializers=[]) + gen_gemm_sum_basic("gemm_sum_basic.onnx") gen_gemm_sum_attributes("gemm_sum_attributes.onnx") gen_gemm_sum_internal_nodes("gemm_sum_internal_nodes.onnx") @@ -184,4 +181,4 @@ def gen_gemm_sum_no_fusion_original_gemm_output_used(model_path): gen_gemm_sum_no_fusion_sum_multiple_inputs("gemm_sum_no_fusion_sum_multiple_inputs.onnx") gen_gemm_sum_fusion_broadcast("gemm_sum_fusion_broadcast.onnx") gen_gemm_sum_no_fusion_broadcast_failure("gemm_sum_no_fusion_broadcast_failure.onnx") -gen_gemm_sum_no_fusion_original_gemm_output_used("gemm_sum_no_fusion_original_gemm_output_used.onnx") \ No newline at end of file +gen_gemm_sum_no_fusion_original_gemm_output_used("gemm_sum_no_fusion_original_gemm_output_used.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py index 9e86a4982363c..276330c2064a2 100644 --- a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen.py @@ -1,7 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -14,35 +12,31 @@ def save(model_path, nodes, inputs, outputs, initializers): - graph = helper.make_graph( - nodes, - "TransposeGemmTest", - inputs, outputs, initializers) + graph = helper.make_graph(nodes, "TransposeGemmTest", inputs, outputs, initializers) - model = helper.make_model( - graph, opset_imports=opsets, producer_name="onnxruntime-test") + model = helper.make_model(graph, opset_imports=opsets, producer_name="onnxruntime-test") onnx.save(model, model_path) + # (A')'B' = AB' def gen_gemm_2inputs_transposed(model_path): nodes = [ helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"), - helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), - helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1) + helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), + helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1), ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, []) + # (A'B)' = B'A def gen_gemm_output_transposed(model_path): nodes = [ @@ -51,36 +45,34 @@ def gen_gemm_output_transposed(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['K', 'M']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['K', 'N']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["K", "M"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["K", "N"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['N', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["N", "M"])] save(model_path, nodes, inputs, outputs, []) + # ((A')'B')' = BA' def gen_gemm_inputs_output_transposed(model_path): nodes = [ helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"), - helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), + helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), helper.make_node("Gemm", ["tp0", "tp1"], ["out"], "Gemm", alpha=3.0, transA=1), helper.make_node("Transpose", ["out"], ["output"], "TransposeOut"), ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['N', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["N", "M"])] save(model_path, nodes, inputs, outputs, []) + gen_gemm_2inputs_transposed("gemm_transpose_2inputs_transposed.onnx") gen_gemm_output_transposed("gemm_transpose_output_transposed.onnx") gen_gemm_inputs_output_transposed("gemm_transpose_inputs_output_transposed.onnx") @@ -94,14 +86,13 @@ def gen_gemm_inputs_output_transposed_2(model_path): ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['K', 'M']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["K", "M"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]), ] - outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['N', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["N", "M"])] save(model_path, nodes, inputs, outputs, []) + gen_gemm_inputs_output_transposed_2("gemm_transpose_inputs_output_transposed_2.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py index 0e1215376c335..7081b0a03d615 100644 --- a/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py +++ b/onnxruntime/test/testdata/transform/fusion/gemm_transpose_gen_2.py @@ -1,7 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -14,33 +12,30 @@ def save(model_path, nodes, inputs, outputs, initializers): - graph = helper.make_graph( - nodes, - "TransposeGemmTest", - inputs, outputs, initializers) + graph = helper.make_graph(nodes, "TransposeGemmTest", inputs, outputs, initializers) - model = helper.make_model( - graph, opset_imports=opsets, producer_name="onnxruntime-test") + model = helper.make_model(graph, opset_imports=opsets, producer_name="onnxruntime-test") onnx.save(model, model_path) + # (A')'B' = AB' def gemm_transpose_2outputs_from_transpose(model_path): nodes = [ helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"), - helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), + helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm", alpha=3.0, transA=1), helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt"), ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]), ] outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']), - helper.make_tensor_value_info("output2", TensorProto.FLOAT, ['K', 'M']) + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info("output2", TensorProto.FLOAT, ["K", "M"]), ] save(model_path, nodes, inputs, outputs, []) @@ -50,26 +45,26 @@ def gemm_transpose_2outputs_from_transpose(model_path): def gemm_transpose_2outputs_from_transpose_to_2gemms(model_path): nodes = [ helper.make_node("Transpose", ["A"], ["tp0"], "TransposeA"), - helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), + helper.make_node("Transpose", ["B"], ["tp1"], "TransposeB"), helper.make_node("Gemm", ["tp0", "tp1"], ["output"], "Gemm1", alpha=3.0, transA=1), helper.make_node("Gemm", ["tp1", "C"], ["output3"], "Gemm2", alpha=3.0, transA=1), helper.make_node("Identity", ["tp0"], ["output2"], "IdentityAt"), ] inputs = [ - helper.make_tensor_value_info("A", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info("B", TensorProto.FLOAT, ['N', 'K']), - helper.make_tensor_value_info("C", TensorProto.FLOAT, ['K', 'L']) + helper.make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, ["K", "L"]), ] outputs = [ - helper.make_tensor_value_info("output", TensorProto.FLOAT, ['M', 'N']), - helper.make_tensor_value_info("output2", TensorProto.FLOAT, ['K', 'M']), - helper.make_tensor_value_info("output3", TensorProto.FLOAT, ['N', 'L']) + helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info("output2", TensorProto.FLOAT, ["K", "M"]), + helper.make_tensor_value_info("output3", TensorProto.FLOAT, ["N", "L"]), ] save(model_path, nodes, inputs, outputs, []) + gemm_transpose_2outputs_from_transpose("gemm_transpose_2outputs_from_transpose.onnx") gemm_transpose_2outputs_from_transpose_to_2gemms("gemm_transpose_2outputs_from_transpose_to_2gemms.onnx") - diff --git a/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py b/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py index ba1b488b54955..447b873f01c6e 100644 --- a/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py +++ b/onnxruntime/test/testdata/transform/fusion/isinf_reducesum.py @@ -1,53 +1,51 @@ -import onnx -from onnx import helper -from onnx import TensorProto, OperatorSetIdProto from enum import Enum +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets + def GenerateModel(model_name): nodes = [ # subgraph helper.make_node("Cast", ["A"], ["cast1"], "cast_1", to=11), - helper.make_node("IsInf", ["cast1"], ["IsInf_out"], "is_inf"), - helper.make_node("Cast", ["IsInf_out"], ["cast2"], "cast_2", to=7), - helper.make_node("ReduceSum", ["cast2"], ["reduced"], "reduction", keepdims=0), - - helper.make_node("Greater", ["reduced", "one"], ["Y"], "output") + helper.make_node("Greater", ["reduced", "one"], ["Y"], "output"), ] inputs = [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT16, ['M', 'K']), - ] + helper.make_tensor_value_info("A", TensorProto.FLOAT16, ["M", "K"]), + ] - initializers = [ - helper.make_tensor('one', TensorProto.INT64, [1], [1])] + initializers = [helper.make_tensor("one", TensorProto.INT64, [1], [1])] graph = helper.make_graph( nodes, - "IsInfReduceSum", #name + "IsInfReduceSum", # name inputs, [ # outputs - helper.make_tensor_value_info('Y', TensorProto.BOOL, [1]), + helper.make_tensor_value_info("Y", TensorProto.BOOL, [1]), ], - initializers) + initializers, + ) model = helper.make_model(graph, **kwargs) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('isinf_reducesum.onnx') \ No newline at end of file + GenerateModel("isinf_reducesum.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py index 53e06bacfebbd..eb184fef5e59d 100644 --- a/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_t5_gen.py @@ -1,9 +1,8 @@ -import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto from enum import Enum +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + def GenerateModel(model_name, has_casts=False): nodes = [ # SimplifiedLayerNorm subgraph @@ -24,10 +23,14 @@ def GenerateModel(model_name, has_casts=False): ) initializers = [ # initializers - helper.make_tensor('pow_in_2', TensorProto.FLOAT, [], [2]), - helper.make_tensor('const_e12', TensorProto.FLOAT, [], [1e-12]), - helper.make_tensor('gamma', TensorProto.FLOAT16 if has_casts else - TensorProto.FLOAT, [4], [1, 2, 3, 4]), + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_e12", TensorProto.FLOAT, [], [1e-12]), + helper.make_tensor( + "gamma", + TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT, + [4], + [1, 2, 3, 4], + ), ] input_type = TensorProto.FLOAT16 if has_casts else TensorProto.FLOAT @@ -35,14 +38,15 @@ def GenerateModel(model_name, has_casts=False): graph = helper.make_graph( nodes, - "SimplifiedLayerNorm", #name + "SimplifiedLayerNorm", # name [ # inputs - helper.make_tensor_value_info('A', input_type, [16, 32, 4]), + helper.make_tensor_value_info("A", input_type, [16, 32, 4]), ], [ # outputs - helper.make_tensor_value_info('C', output_type, [16, 32, 4]), + helper.make_tensor_value_info("C", output_type, [16, 32, 4]), ], - initializers) + initializers, + ) onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -57,5 +61,5 @@ def GenerateModel(model_name, has_casts=False): onnx.save(model, model_name) -GenerateModel('layer_norm_t5.onnx') -GenerateModel('simplified_layer_norm_with_casts.onnx', True) +GenerateModel("layer_norm_t5.onnx") +GenerateModel("simplified_layer_norm_with_casts.onnx", True) diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py index b8142946265f9..091d38d9e6797 100644 --- a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_2.py @@ -1,10 +1,9 @@ -import onnx -import numpy as np -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto from enum import Enum +import numpy as np +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + def GenerateModel(model_name): nodes = [ # LayerNormWithCast2 subgraph @@ -21,23 +20,24 @@ def GenerateModel(model_name): ] initializers = [ # initializers - helper.make_tensor('pow_in_2', TensorProto.FLOAT, [], [2]), - helper.make_tensor('const_0', TensorProto.FLOAT16, [], [0]), - helper.make_tensor('gamma', TensorProto.FLOAT16, [4], [1, 2, 3, 4]), - helper.make_tensor('beta', TensorProto.FLOAT16, [4], [1, 2, 3, 4]), + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_0", TensorProto.FLOAT16, [], [0]), + helper.make_tensor("gamma", TensorProto.FLOAT16, [4], [1, 2, 3, 4]), + helper.make_tensor("beta", TensorProto.FLOAT16, [4], [1, 2, 3, 4]), ] graph = helper.make_graph( nodes, - "LayerNormWithCast2", #name + "LayerNormWithCast2", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT16, [16, 32, 4]), + helper.make_tensor_value_info("A", TensorProto.FLOAT16, [16, 32, 4]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT16, [16, 32, 4]), + helper.make_tensor_value_info("C", TensorProto.FLOAT16, [16, 32, 4]), ], - initializers) - + initializers, + ) + onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. @@ -51,4 +51,4 @@ def GenerateModel(model_name): onnx.save(model, model_name) -GenerateModel('layer_norm_with_cast_2.onnx') \ No newline at end of file +GenerateModel("layer_norm_with_cast_2.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_3.py b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_3.py index 36c7b9b56d7c0..a2f3928b5eaf2 100644 --- a/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_3.py +++ b/onnxruntime/test/testdata/transform/fusion/layer_norm_with_cast_3.py @@ -1,7 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper def GenerateModel(model_name): @@ -19,23 +17,24 @@ def GenerateModel(model_name): ] initializers = [ # initializers - helper.make_tensor('pow_in_2', TensorProto.FLOAT, [], [2]), - helper.make_tensor('const_0', TensorProto.FLOAT, [], [0]), - helper.make_tensor('gamma', TensorProto.FLOAT16, [4], [1, 2, 3, 4]), - helper.make_tensor('beta', TensorProto.FLOAT16, [4], [1, 2, 3, 4]), + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_0", TensorProto.FLOAT, [], [0]), + helper.make_tensor("gamma", TensorProto.FLOAT16, [4], [1, 2, 3, 4]), + helper.make_tensor("beta", TensorProto.FLOAT16, [4], [1, 2, 3, 4]), ] graph = helper.make_graph( nodes, - "LayerNormWithCast3", #name + "LayerNormWithCast3", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT16, [16, 32, 4]), + helper.make_tensor_value_info("C", TensorProto.FLOAT16, [16, 32, 4]), ], - initializers) - + initializers, + ) + onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. @@ -48,7 +47,9 @@ def GenerateModel(model_name): model = helper.make_model(graph, opset_imports=opsets) onnx.save(model, model_name) -GenerateModel('layer_norm_with_cast_3.onnx') + +GenerateModel("layer_norm_with_cast_3.onnx") + def GenerateModel2(model_name): nodes = [ # LayerNormWithCast4 subgraph @@ -67,23 +68,24 @@ def GenerateModel2(model_name): ] initializers = [ # initializers - helper.make_tensor('pow_in_2', TensorProto.FLOAT, [], [2]), - helper.make_tensor('const_0', TensorProto.FLOAT, [], [0]), - helper.make_tensor('gamma', TensorProto.FLOAT16, [4], [1, 2, 3, 4]), - helper.make_tensor('beta', TensorProto.FLOAT16, [4], [1, 2, 3, 4]), + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_0", TensorProto.FLOAT, [], [0]), + helper.make_tensor("gamma", TensorProto.FLOAT16, [4], [1, 2, 3, 4]), + helper.make_tensor("beta", TensorProto.FLOAT16, [4], [1, 2, 3, 4]), ] graph = helper.make_graph( nodes, - "LayerNormWithCast4", #name + "LayerNormWithCast4", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT16, [16, 32, 4]), + helper.make_tensor_value_info("A", TensorProto.FLOAT16, [16, 32, 4]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT16, [16, 32, 4]), + helper.make_tensor_value_info("C", TensorProto.FLOAT16, [16, 32, 4]), ], - initializers) - + initializers, + ) + onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. @@ -96,4 +98,5 @@ def GenerateModel2(model_name): model = helper.make_model(graph, opset_imports=opsets) onnx.save(model, model_name) -GenerateModel2('layer_norm_with_cast_4.onnx') + +GenerateModel2("layer_norm_with_cast_4.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py index 1a270043baa65..7bba71723b2c8 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_integer_to_float.py @@ -1,31 +1,71 @@ -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum +import onnx +from onnx import TensorProto, helper + + def MakeSubGraph(suffix, has_bias): mul_bottom_output = "mul_output" + suffix if has_bias else "output" + suffix nodes = [ - helper.make_node("MatMulInteger", ["a_quantized", "b_quantized" + suffix, "a_zp", "b_zp" + suffix], ["matmul_output_int32" + suffix], "MatMulInteger" + suffix), - helper.make_node("Mul", ["a_scale", "b_scale" + suffix], ["multiplier" + suffix], "mul_right" + suffix), - helper.make_node("Cast", ["matmul_output_int32" + suffix], ["matmul_output_float" + suffix], "cast" + suffix, to=1), - helper.make_node("Mul", ["matmul_output_float" + suffix, "multiplier" + suffix], [mul_bottom_output], "mul_bottom" + suffix), - ] - + helper.make_node( + "MatMulInteger", + ["a_quantized", "b_quantized" + suffix, "a_zp", "b_zp" + suffix], + ["matmul_output_int32" + suffix], + "MatMulInteger" + suffix, + ), + helper.make_node( + "Mul", + ["a_scale", "b_scale" + suffix], + ["multiplier" + suffix], + "mul_right" + suffix, + ), + helper.make_node( + "Cast", + ["matmul_output_int32" + suffix], + ["matmul_output_float" + suffix], + "cast" + suffix, + to=1, + ), + helper.make_node( + "Mul", + ["matmul_output_float" + suffix, "multiplier" + suffix], + [mul_bottom_output], + "mul_bottom" + suffix, + ), + ] + if has_bias: - nodes.extend([helper.make_node("Add", [mul_bottom_output, "bias" + suffix], ["output" + suffix], "bias_add" + suffix),]) - + nodes.extend( + [ + helper.make_node( + "Add", + [mul_bottom_output, "bias" + suffix], + ["output" + suffix], + "bias_add" + suffix, + ), + ] + ) + return nodes + def MakeInitializer(suffix): return [ - helper.make_tensor('b_quantized' + suffix, TensorProto.UINT8, [2,3], [2, 4, 5, 6, 7, 8]), - helper.make_tensor('b_zp' + suffix, TensorProto.UINT8, [], [128]), - helper.make_tensor('b_scale' + suffix, TensorProto.FLOAT, [], [1.8]), - ] + helper.make_tensor("b_quantized" + suffix, TensorProto.UINT8, [2, 3], [2, 4, 5, 6, 7, 8]), + helper.make_tensor("b_zp" + suffix, TensorProto.UINT8, [], [128]), + helper.make_tensor("b_scale" + suffix, TensorProto.FLOAT, [], [1.8]), + ] + def GenerateModel(model_name): - nodes = [helper.make_node("DynamicQuantizeLinear", ["input"], ["a_quantized", "a_scale", "a_zp"], "DynamicQuantizeLinear"),] + nodes = [ + helper.make_node( + "DynamicQuantizeLinear", + ["input"], + ["a_quantized", "a_scale", "a_zp"], + "DynamicQuantizeLinear", + ), + ] nodes.extend(MakeSubGraph("_1", True)) nodes.extend(MakeSubGraph("_2", True)) nodes.extend(MakeSubGraph("_3", False)) @@ -34,30 +74,34 @@ def GenerateModel(model_name): initializers.extend(MakeInitializer("_1")) initializers.extend(MakeInitializer("_3")) - initializers.extend([ - helper.make_tensor('bias_1', TensorProto.FLOAT, [3], [2, 4, 5]), - helper.make_tensor('bias_2', TensorProto.FLOAT, [3,3], [1, 2, 3, 4, 5, 6, 7, 8, 9]), - ]) + initializers.extend( + [ + helper.make_tensor("bias_1", TensorProto.FLOAT, [3], [2, 4, 5]), + helper.make_tensor("bias_2", TensorProto.FLOAT, [3, 3], [1, 2, 3, 4, 5, 6, 7, 8, 9]), + ] + ) graph = helper.make_graph( nodes, - "MatMulIntegerToFloat_fusion", #name - [ # inputs - helper.make_tensor_value_info('input', TensorProto.FLOAT, [3, 2]), - # matrix b corresponding inputs for subgraph 2 - helper.make_tensor_value_info('b_quantized_2', TensorProto.UINT8, [2, 3]), - helper.make_tensor_value_info('b_zp_2', TensorProto.UINT8, [1]), - helper.make_tensor_value_info('b_scale_2', TensorProto.FLOAT, [1]), + "MatMulIntegerToFloat_fusion", # name + [ # inputs + helper.make_tensor_value_info("input", TensorProto.FLOAT, [3, 2]), + # matrix b corresponding inputs for subgraph 2 + helper.make_tensor_value_info("b_quantized_2", TensorProto.UINT8, [2, 3]), + helper.make_tensor_value_info("b_zp_2", TensorProto.UINT8, [1]), + helper.make_tensor_value_info("b_scale_2", TensorProto.FLOAT, [1]), ], [ # outputs - helper.make_tensor_value_info('output_1', TensorProto.FLOAT, [3, 3]), - helper.make_tensor_value_info('output_2', TensorProto.FLOAT, [3, 3]), - helper.make_tensor_value_info('output_3', TensorProto.FLOAT, [3, 3]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, [3, 3]), + helper.make_tensor_value_info("output_2", TensorProto.FLOAT, [3, 3]), + helper.make_tensor_value_info("output_3", TensorProto.FLOAT, [3, 3]), ], - initializers) + initializers, + ) model = helper.make_model(graph) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('matmul_integer_to_float.onnx') \ No newline at end of file + GenerateModel("matmul_integer_to_float.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py index 850ff6bc22fcb..e4692894ea5f1 100644 --- a/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/matmul_scale_gen.py @@ -1,8 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto - +from onnx import OperatorSetIdProto, TensorProto, helper onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -17,20 +14,14 @@ def save(model_path, nodes, inputs, outputs, initializers, opsets=opsets): - graph = helper.make_graph( - nodes, - "MatMulScaleTest", - inputs, outputs, initializers) + graph = helper.make_graph(nodes, "MatMulScaleTest", inputs, outputs, initializers) - model = helper.make_model( - graph, opset_imports=opsets, producer_name="onnxruntime-test") + model = helper.make_model(graph, opset_imports=opsets, producer_name="onnxruntime-test") onnx.save(model, model_path) -def gen(model_path, - use_transpose_matmul, - scale_input_0, scale_input_1, scale_output): +def gen(model_path, use_transpose_matmul, scale_input_0, scale_input_1, scale_output): matmul_op = "FusedMatMul" if use_transpose_matmul else "MatMul" matmul_domain = "com.microsoft" if use_transpose_matmul else "" matmul_attrs = {"alpha": scale_value} if use_transpose_matmul else {} @@ -38,48 +29,47 @@ def gen(model_path, nodes = [] if scale_input_0: - nodes.append(helper.make_node( - "Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0")) + nodes.append(helper.make_node("Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0")) if scale_input_1: - nodes.append(helper.make_node( - "Div", ["input_1", "scale_reciprocal"], ["scaled_input_1"], "scale input_1")) - - nodes.append(helper.make_node( - matmul_op, - [ - "scaled_input_0" if scale_input_0 else "input_0", - "scaled_input_1" if scale_input_1 else "input_1" - ], - [ - "unscaled_output" if scale_output else "output" - ], - matmul_op, - "", - matmul_domain, - **matmul_attrs)) + nodes.append( + helper.make_node( + "Div", + ["input_1", "scale_reciprocal"], + ["scaled_input_1"], + "scale input_1", + ) + ) + + nodes.append( + helper.make_node( + matmul_op, + [ + "scaled_input_0" if scale_input_0 else "input_0", + "scaled_input_1" if scale_input_1 else "input_1", + ], + ["unscaled_output" if scale_output else "output"], + matmul_op, + "", + matmul_domain, + **matmul_attrs + ) + ) if scale_output: - nodes.append(helper.make_node( - "Mul", ["scale", "unscaled_output"], ["output"], "scale output")) + nodes.append(helper.make_node("Mul", ["scale", "unscaled_output"], ["output"], "scale output")) initializers = [ helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]), - helper.make_tensor("scale_reciprocal", - TensorProto.FLOAT, [], [1/scale_value]) + helper.make_tensor("scale_reciprocal", TensorProto.FLOAT, [], [1 / scale_value]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, [2, 'M', 'K']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [2, 'K', 'N']) + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [2, "M", "K"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, "K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, "M", "N"])] save(model_path, nodes, inputs, outputs, initializers) @@ -99,85 +89,69 @@ def gen_unfusable(model_path, unfusable_type): matmul_op = "MatMul" if unfusable_type == UNFUSABLE_DIV_NOT_SCALE: - scale_node = helper.make_node( - "Div", ["scale", "input_0"], ["scaled_input_0"], "scale input_0") + scale_node = helper.make_node("Div", ["scale", "input_0"], ["scaled_input_0"], "scale input_0") elif unfusable_type == UNFUSABLE_SCALE_NOT_SCALAR: - scale_node = helper.make_node( - "Mul", ["scale_non_scalar", "input_0"], ["scaled_input_0"], "scale input_0") + scale_node = helper.make_node("Mul", ["scale_non_scalar", "input_0"], ["scaled_input_0"], "scale input_0") elif unfusable_type == UNFUSABLE_SCALE_NOT_CONSTANT: - scale_node = helper.make_node( - "Mul", ["input_0", "input_0"], ["scaled_input_0"], "scale input_0") + scale_node = helper.make_node("Mul", ["input_0", "input_0"], ["scaled_input_0"], "scale input_0") else: raise ValueError("Invalid unfusable_type: {}".format(unfusable_type)) nodes = [ scale_node, - helper.make_node( - matmul_op, ["scaled_input_0", "input_1"], ["output"], matmul_op) + helper.make_node(matmul_op, ["scaled_input_0", "input_1"], ["output"], matmul_op), ] initializers = [ helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]), - helper.make_tensor("scale_non_scalar", TensorProto.FLOAT, - [2, 1, 1], [scale_value, scale_value]) + helper.make_tensor("scale_non_scalar", TensorProto.FLOAT, [2, 1, 1], [scale_value, scale_value]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, [2, 'M', 'K']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [2, 'K', 'N']) + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [2, "M", "K"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, "K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, "M", "N"])] save(model_path, nodes, inputs, outputs, initializers) -gen_unfusable("matmul_scale_unfusable_div_not_scale.onnx", - UNFUSABLE_DIV_NOT_SCALE) -gen_unfusable("matmul_scale_unfusable_scale_not_scalar.onnx", - UNFUSABLE_SCALE_NOT_SCALAR) -gen_unfusable("matmul_scale_unfusable_scale_not_constant.onnx", - UNFUSABLE_SCALE_NOT_CONSTANT) +gen_unfusable("matmul_scale_unfusable_div_not_scale.onnx", UNFUSABLE_DIV_NOT_SCALE) +gen_unfusable("matmul_scale_unfusable_scale_not_scalar.onnx", UNFUSABLE_SCALE_NOT_SCALAR) +gen_unfusable("matmul_scale_unfusable_scale_not_constant.onnx", UNFUSABLE_SCALE_NOT_CONSTANT) def gen_reused_input_scale(model_path): matmul_op = "MatMul" nodes = [ + helper.make_node("Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0"), helper.make_node( - "Mul", ["input_0", "scale"], ["scaled_input_0"], - "scale input_0"), - helper.make_node( - matmul_op, ["scaled_input_0", "input_1"], ["output_0"], - "MatMul input_0 and input_1"), + matmul_op, + ["scaled_input_0", "input_1"], + ["output_0"], + "MatMul input_0 and input_1", + ), helper.make_node( - matmul_op, ["scaled_input_0", "input_2"], ["output_1"], - "MatMul input_0 and input_2") + matmul_op, + ["scaled_input_0", "input_2"], + ["output_1"], + "MatMul input_0 and input_2", + ), ] - initializers = [ - helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value]) - ] + initializers = [helper.make_tensor("scale", TensorProto.FLOAT, [], [scale_value])] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, [2, 'M', 'K']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [2, 'K', 'N']), - helper.make_tensor_value_info( - "input_2", TensorProto.FLOAT, [2, 'K', 'N']) + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [2, "M", "K"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, "K", "N"]), + helper.make_tensor_value_info("input_2", TensorProto.FLOAT, [2, "K", "N"]), ] outputs = [ - helper.make_tensor_value_info( - "output_0", TensorProto.FLOAT, [2, 'M', 'N']), - helper.make_tensor_value_info( - "output_1", TensorProto.FLOAT, [2, 'M', 'N']) + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, [2, "M", "N"]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, [2, "M", "N"]), ] save(model_path, nodes, inputs, outputs, initializers) @@ -190,28 +164,24 @@ def gen_int32(model_path): matmul_op = "MatMul" nodes = [ + helper.make_node("Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0"), helper.make_node( - "Mul", ["input_0", "scale"], ["scaled_input_0"], - "scale input_0"), - helper.make_node( - matmul_op, ["scaled_input_0", "input_1"], ["output_0"], - "MatMul input_0 and input_1"), + matmul_op, + ["scaled_input_0", "input_1"], + ["output_0"], + "MatMul input_0 and input_1", + ), ] - initializers = [ - helper.make_tensor("scale", TensorProto.INT32, [], [int(scale_value)]) - ] + initializers = [helper.make_tensor("scale", TensorProto.INT32, [], [int(scale_value)])] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.INT32, [2, 'M', 'K']), - helper.make_tensor_value_info( - "input_1", TensorProto.INT32, [2, 'K', 'N']), + helper.make_tensor_value_info("input_0", TensorProto.INT32, [2, "M", "K"]), + helper.make_tensor_value_info("input_1", TensorProto.INT32, [2, "K", "N"]), ] outputs = [ - helper.make_tensor_value_info( - "output_0", TensorProto.INT32, [2, 'M', 'N']), + helper.make_tensor_value_info("output_0", TensorProto.INT32, [2, "M", "N"]), ] save(model_path, nodes, inputs, outputs, initializers) @@ -223,28 +193,24 @@ def gen_int32(model_path): def gen_scale_input(model_path): nodes = [ + helper.make_node("Mul", ["input_0", "scale"], ["scaled_input_0"], "scale input_0"), helper.make_node( - "Mul", ["input_0", "scale"], ["scaled_input_0"], - "scale input_0"), - helper.make_node( - "MatMul", ["scaled_input_0", "input_1"], ["output_0"], - "MatMul input_0 and input_1"), + "MatMul", + ["scaled_input_0", "input_1"], + ["output_0"], + "MatMul input_0 and input_1", + ), ] - initializers = [ - helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0]) - ] + initializers = [helper.make_tensor("scale", TensorProto.FLOAT, [1], [1.0])] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, []), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [1, 'K']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, []), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [1, "K"]), ] outputs = [ - helper.make_tensor_value_info( - "output_0", TensorProto.FLOAT, ['K']), + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["K"]), ] onnxdomain = OperatorSetIdProto() @@ -254,4 +220,4 @@ def gen_scale_input(model_path): save(model_path, nodes, inputs, outputs, initializers, opsets) -gen_scale_input("matmul_scale_with_scale_input.onnx") \ No newline at end of file +gen_scale_input("matmul_scale_with_scale_input.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/not_where.py b/onnxruntime/test/testdata/transform/fusion/not_where.py index ba3b3a5feff19..7e48164d5161a 100644 --- a/onnxruntime/test/testdata/transform/fusion/not_where.py +++ b/onnxruntime/test/testdata/transform/fusion/not_where.py @@ -1,21 +1,22 @@ -import onnx -from onnx import helper -from onnx import TensorProto, OperatorSetIdProto from enum import Enum +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper + opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets + def GenerateModel(model_name): nodes = [ # subgraph @@ -34,30 +35,32 @@ def GenerateModel(model_name): ] inputs = [ # inputs - helper.make_tensor_value_info('X', TensorProto.BOOL, ['M', 'K']), - ] + helper.make_tensor_value_info("X", TensorProto.BOOL, ["M", "K"]), + ] initializers = [ - helper.make_tensor('v0', TensorProto.FLOAT, [1], [1.0]), - helper.make_tensor('v1', TensorProto.FLOAT, [1], [-1.0]), - ] + helper.make_tensor("v0", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("v1", TensorProto.FLOAT, [1], [-1.0]), + ] graph = helper.make_graph( nodes, - "NotWhere", #name + "NotWhere", # name inputs, [ # outputs - helper.make_tensor_value_info('not_X_2', TensorProto.BOOL, ['M', 'K']), - helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('Y2', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('Y3', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('Y4', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('Y5', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info("not_X_2", TensorProto.BOOL, ["M", "K"]), + helper.make_tensor_value_info("Y1", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("Y2", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("Y3", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("Y4", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("Y5", TensorProto.FLOAT, ["M", "K"]), ], - initializers) + initializers, + ) model = helper.make_model(graph, **kwargs) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('not_where.onnx') \ No newline at end of file + GenerateModel("not_where.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py index 42d8cb68ebaf6..1eaf194164732 100644 --- a/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/reshape_fusion_gen.py @@ -1,14 +1,15 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper + def save_model(graph, file_name): - model = helper.make_model(graph) - onnx.checker.check_model(model) - onnx.save(model, file_name) + model = helper.make_model(graph) + onnx.checker.check_model(model) + onnx.save(model, file_name) + graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), helper.make_node("Gather", ["shape1_out", "indices0"], ["gather0_out"], "gather0", axis=0), @@ -17,433 +18,552 @@ def save_model(graph, file_name): helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), helper.make_node("Unsqueeze", ["gather3_out"], ["unsqueeze3_out"], "unsqueeze3", axes=[0]), - - helper.make_node("Concat", ["unsqueeze0_out", "a1", "unsqueeze2_out", "unsqueeze3_out", "a4"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Concat", + ["unsqueeze0_out", "a1", "unsqueeze2_out", "unsqueeze3_out", "a4"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, ['unk_0', 256, 'unk_2', 'unk_3']), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, ["unk_0", 256, "unk_2", "unk_3"]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_1', 128, 'unk_2', 'unk_3', 'unk_4']), - helper.make_tensor_value_info('gather3_out', TensorProto.INT64, []), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, ["unk_1", 128, "unk_2", "unk_3", "unk_4"]), + helper.make_tensor_value_info("gather3_out", TensorProto.INT64, []), ], [ # initializers - helper.make_tensor('a1', TensorProto.INT64, [1], [128]), - helper.make_tensor('a4', TensorProto.INT64, [1], [-1]), - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices2', TensorProto.INT64, [], [2]), - helper.make_tensor('indices3', TensorProto.INT64, [], [3]), - ] + helper.make_tensor("a1", TensorProto.INT64, [1], [128]), + helper.make_tensor("a4", TensorProto.INT64, [1], [-1]), + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices2", TensorProto.INT64, [], [2]), + helper.make_tensor("indices3", TensorProto.INT64, [], [3]), + ], ) -save_model(graph, 'reshape_fusion_internal_nodes_reused.onnx') +save_model(graph, "reshape_fusion_internal_nodes_reused.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "a"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "a"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), - helper.make_tensor_value_info('gather0_out', TensorProto.INT64, []), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"]), + helper.make_tensor_value_info("gather0_out", TensorProto.INT64, []), ], [ # initializers - helper.make_tensor('a', TensorProto.INT64, [1], [-1]), - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - ] + helper.make_tensor("a", TensorProto.INT64, [1], [-1]), + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + ], ) -save_model(graph, 'reshape_fusion_internal_node_is_graph_output.onnx') +save_model(graph, "reshape_fusion_internal_node_is_graph_output.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), helper.make_node("Gather", ["shape2_out", "indices2"], ["gather2_out"], "gather2", axis=0), helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), - helper.make_node("Concat", ["a", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, ["unk_0", "unk_1", "unk_2"]), ], [ # initializers - helper.make_tensor('a', TensorProto.INT64, [2], [1, 200]), - helper.make_tensor('indices2', TensorProto.INT64, [], [1]), - ] + helper.make_tensor("a", TensorProto.INT64, [2], [1, 200]), + helper.make_tensor("indices2", TensorProto.INT64, [], [1]), + ], ) -save_model(graph, 'reshape_fusion_multiple_values_in_initializer_tensor_1.onnx') +save_model(graph, "reshape_fusion_multiple_values_in_initializer_tensor_1.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), helper.make_node("Gather", ["shape2_out", "indices2"], ["gather2_out"], "gather2", axis=0), helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), - helper.make_node("Concat", ["a", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, ["unk_0", "unk_1", "unk_2"]), ], [ # initializers - helper.make_tensor('a', TensorProto.INT64, [2], [1, 200]), - helper.make_tensor('indices2', TensorProto.INT64, [], [2]), - ] + helper.make_tensor("a", TensorProto.INT64, [2], [1, 200]), + helper.make_tensor("indices2", TensorProto.INT64, [], [2]), + ], ) -save_model(graph, 'reshape_fusion_multiple_values_in_initializer_tensor_2.onnx') +save_model(graph, "reshape_fusion_multiple_values_in_initializer_tensor_2.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["AnotherInput"], ["shape2_out"], "shape2"), helper.make_node("Gather", ["shape2_out", "indices2"], ["gather2_out"], "gather2", axis=0), helper.make_node("Unsqueeze", ["gather2_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), - helper.make_node("Concat", ["a", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), - helper.make_tensor_value_info('AnotherInput', TensorProto.FLOAT, ['input_dim_0', 'input_dim_1', 'input_dim_2']), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info( + "AnotherInput", + TensorProto.FLOAT, + ["input_dim_0", "input_dim_1", "input_dim_2"], + ), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, ["unk_0", "unk_1", "unk_2"]), ], [ # initializers - helper.make_tensor('a', TensorProto.INT64, [2], [1, 200]), - helper.make_tensor('indices2', TensorProto.INT64, [], [2]), - ] + helper.make_tensor("a", TensorProto.INT64, [2], [1, 200]), + helper.make_tensor("indices2", TensorProto.INT64, [], [2]), + ], ) -save_model(graph, 'reshape_fusion_input_is_graph_input.onnx') +save_model(graph, "reshape_fusion_input_is_graph_input.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Concat", ["a"], ["concat_out"], "concat", axis=0), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [2, 3, 4]), - helper.make_tensor_value_info('a', TensorProto.INT64, [3]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [2, 3, 4]), + helper.make_tensor_value_info("a", TensorProto.INT64, [3]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, ['unk_0', 'unk_1', 'unk_2']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, ["unk_0", "unk_1", "unk_2"]), ], [ # initializers - helper.make_tensor('a', TensorProto.INT64, [3], [1, 1, 2*3*4]), - ] + helper.make_tensor("a", TensorProto.INT64, [3], [1, 1, 2 * 3 * 4]), + ], ) -save_model(graph, 'reshape_fusion_overridable_initializer.onnx') +save_model(graph, "reshape_fusion_overridable_initializer.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["query"], ["shape0_out"], "shape0"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Concat", ["a", "unsqueeze0_out"], ["concat_out"], "concat", axis=0), helper.make_node("Reshape", ["doc_word_mask", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('query', TensorProto.FLOAT, [1, 50]), - helper.make_tensor_value_info('doc_word_mask', TensorProto.FLOAT, [1, 200, 50]), + helper.make_tensor_value_info("query", TensorProto.FLOAT, [1, 50]), + helper.make_tensor_value_info("doc_word_mask", TensorProto.FLOAT, [1, 200, 50]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"]), ], [ # initializers - helper.make_tensor('a', TensorProto.INT64, [1], [-1]), - helper.make_tensor('indices0', TensorProto.INT64, [], [1]), - ] + helper.make_tensor("a", TensorProto.INT64, [1], [-1]), + helper.make_tensor("indices0", TensorProto.INT64, [], [1]), + ], ) -save_model(graph, 'reshape_fusion_with_graph_inputs.onnx') +save_model(graph, "reshape_fusion_with_graph_inputs.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), - helper.make_node("Slice", ["shape2_out", "slice_starts", "slice_ends"], ["slice_out"], "slice1"), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Slice", + ["shape2_out", "slice_starts", "slice_ends"], + ["slice_out"], + "slice1", + ), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"]), ], [ # initializers - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - helper.make_tensor('slice_starts', TensorProto.INT64, [1], [2]), - helper.make_tensor('slice_ends', TensorProto.INT64, [1], [3]) - ] + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + helper.make_tensor("slice_starts", TensorProto.INT64, [1], [2]), + helper.make_tensor("slice_ends", TensorProto.INT64, [1], [3]), + ], ) -save_model(graph, 'reshape_fusion_concat_subgraph.onnx') +save_model(graph, "reshape_fusion_concat_subgraph.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), - helper.make_node("Slice", ["shape2_out"], ["slice_out"], "slice1", starts = [2], ends = [3]), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], ["concat_out"], "concat", axis=0), + helper.make_node("Slice", ["shape2_out"], ["slice_out"], "slice1", starts=[2], ends=[3]), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"]), ], [ # initializers - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - ] + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + ], ) # Save this model without checking -onnx.save(helper.make_model(graph), 'reshape_fusion_with_slice1.onnx') +onnx.save(helper.make_model(graph), "reshape_fusion_with_slice1.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Shape", ["unsqueeze0_out"], ["dummy_out"], "dummy"), - helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), - helper.make_node("Slice", ["shape2_out", "slice_starts", "slice_ends"], ["slice_out"], "slice1"), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Slice", + ["shape2_out", "slice_starts", "slice_ends"], + ["slice_out"], + "slice1", + ), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "slice_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']), - helper.make_tensor_value_info('slice_out', TensorProto.INT64, [1]), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"]), + helper.make_tensor_value_info("slice_out", TensorProto.INT64, [1]), ], [ # initializers - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - helper.make_tensor('slice_starts', TensorProto.INT64, [1], [2]), - helper.make_tensor('slice_ends', TensorProto.INT64, [1], [3]) - ] + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + helper.make_tensor("slice_starts", TensorProto.INT64, [1], [2]), + helper.make_tensor("slice_ends", TensorProto.INT64, [1], [3]), + ], ) -save_model(graph, 'reshape_fusion_concat_subgraph_multiple_outputs.onnx') +save_model(graph, "reshape_fusion_concat_subgraph_multiple_outputs.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), - helper.make_node("Slice", ["shape2_out", "slice_starts", "slice_ends"], ["slice_out"], "slice1"), - helper.make_node("Pad", ["slice_out", "pads"], ["pad0_out"], "pad0", mode = "constant"), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "pad0_out"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Slice", + ["shape2_out", "slice_starts", "slice_ends"], + ["slice_out"], + "slice1", + ), + helper.make_node("Pad", ["slice_out", "pads"], ["pad0_out"], "pad0", mode="constant"), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "pad0_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), - ], - [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']) + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], + [helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"])], # outputs [ # initializers - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - helper.make_tensor('pads', TensorProto.INT64, [2], [1, 0]), - helper.make_tensor('slice_starts', TensorProto.INT64, [1], [2]), - helper.make_tensor('slice_ends', TensorProto.INT64, [1], [3]) - ] + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + helper.make_tensor("pads", TensorProto.INT64, [2], [1, 0]), + helper.make_tensor("slice_starts", TensorProto.INT64, [1], [2]), + helper.make_tensor("slice_ends", TensorProto.INT64, [1], [3]), + ], ) -save_model(graph, 'reshape_fusion_concat_subgraph_not_triggered.onnx') +save_model(graph, "reshape_fusion_concat_subgraph_not_triggered.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), - helper.make_node("Slice", ["shape2_out", "slice_starts", "slice_ends"], ["slice_out"], "slice1"), + helper.make_node( + "Slice", + ["shape2_out", "slice_starts", "slice_ends"], + ["slice_out"], + "slice1", + ), helper.make_node("Squeeze", ["slice_out"], ["squeeze0_out"], "squeeze0", axes=[0]), helper.make_node("Div", ["squeeze0_out", "div_init"], ["div_out"], "div"), helper.make_node("Unsqueeze", ["div_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "unsqueeze2_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), - ], - [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']) + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], + [helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"])], # outputs [ # initializers - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - helper.make_tensor('div_init', TensorProto.INT64, [], [1]), - helper.make_tensor('slice_starts', TensorProto.INT64, [1], [2]), - helper.make_tensor('slice_ends', TensorProto.INT64, [1], [3]) - ] + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + helper.make_tensor("div_init", TensorProto.INT64, [], [1]), + helper.make_tensor("slice_starts", TensorProto.INT64, [1], [2]), + helper.make_tensor("slice_ends", TensorProto.INT64, [1], [3]), + ], ) -save_model(graph, 'reshape_fusion_concat_subgraph_div.onnx') +save_model(graph, "reshape_fusion_concat_subgraph_div.onnx") graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Shape", ["SubgraphRoot"], ["shape0_out"], "shape0"), helper.make_node("Shape", ["SubgraphRoot"], ["shape1_out"], "shape1"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Gather", ["shape1_out", "indices1"], ["gather1_out"], "gather1", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), helper.make_node("Unsqueeze", ["gather1_out"], ["unsqueeze1_out"], "unsqueeze1", axes=[0]), - helper.make_node("Shape", ["SubgraphRoot"], ["shape2_out"], "shape2"), - helper.make_node("Slice", ["shape2_out", "slice_starts_0", "slice_ends_0"], ["slice0_out"], "slice0"), + helper.make_node( + "Slice", + ["shape2_out", "slice_starts_0", "slice_ends_0"], + ["slice0_out"], + "slice0", + ), helper.make_node("Squeeze", ["slice0_out"], ["squeeze0_out"], "squeeze0", axes=[0]), helper.make_node("Shape", ["SubgraphRoot"], ["shape3_out"], "shape3"), - helper.make_node("Slice", ["shape3_out", "slice_starts_1", "slice_ends_1"], ["slice1_out"], "slice1"), + helper.make_node( + "Slice", + ["shape3_out", "slice_starts_1", "slice_ends_1"], + ["slice1_out"], + "slice1", + ), helper.make_node("Squeeze", ["slice1_out"], ["squeeze1_out"], "squeeze1", axes=[0]), helper.make_node("Mul", ["squeeze0_out", "squeeze1_out"], ["mul_out"], "mul"), helper.make_node("Unsqueeze", ["mul_out"], ["unsqueeze2_out"], "unsqueeze2", axes=[0]), - - helper.make_node("Concat", ["unsqueeze0_out", "unsqueeze1_out", "unsqueeze2_out"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Concat", + ["unsqueeze0_out", "unsqueeze1_out", "unsqueeze2_out"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("Reshape", ["SubgraphRoot", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('SubgraphRoot', TensorProto.FLOAT, [10, 20, 30]), - ], - [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [10, 20, 'unk']) + helper.make_tensor_value_info("SubgraphRoot", TensorProto.FLOAT, [10, 20, 30]), ], + [helper.make_tensor_value_info("Result", TensorProto.FLOAT, [10, 20, "unk"])], # outputs [ # initializers - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('indices1', TensorProto.INT64, [], [1]), - helper.make_tensor('slice_starts_0', TensorProto.INT64, [1], [2]), - helper.make_tensor('slice_ends_0', TensorProto.INT64, [1], [3]), - helper.make_tensor('slice_starts_1', TensorProto.INT64, [1], [1]), - helper.make_tensor('slice_ends_1', TensorProto.INT64, [1], [2]) - ] + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("indices1", TensorProto.INT64, [], [1]), + helper.make_tensor("slice_starts_0", TensorProto.INT64, [1], [2]), + helper.make_tensor("slice_ends_0", TensorProto.INT64, [1], [3]), + helper.make_tensor("slice_starts_1", TensorProto.INT64, [1], [1]), + helper.make_tensor("slice_ends_1", TensorProto.INT64, [1], [2]), + ], ) -save_model(graph, 'reshape_fusion_concat_subgraph_mul.onnx') +save_model(graph, "reshape_fusion_concat_subgraph_mul.onnx") matmul_weights = [ - -0.04888916015625, 0.0143280029296875, 0.066650390625,-0.0343017578125, - -0.0010356903076171875, -0.00048232078552246094, 0.07470703125, -0.04736328125, - 0.01454925537109375, -0.0086669921875, -0.051971435546875, -0.0201568603515625, - 0.040435791015625, -0.019256591796875, 0.0205078125, 0.0111541748046875, - 0.0071868896484375, -0.0298309326171875, -0.0306549072265625, -0.0225372314453125, - -0.04193115234375, 0.07073974609375, -0.048065185546875, 0.0198822021484375, - -0.035552978515625, -0.022796630859375, 0.03839111328125, 0.007099151611328125, - -0.0080108642578125, -0.0017957687377929688, 0.0266265869140625,-0.028289794921875, - 0.0032901763916015625, 0.0208740234375, -0.01529693603515625, -0.046600341796875, - -0.034637451171875, 0.011322021484375, -0.026458740234375, 0.04656982421875, - -0.0091705322265625, 0.017913818359375, -0.019256591796875, -0.001216888427734375, - -0.08245849609375, -0.023162841796875, -0.04132080078125, -0.03363037109375, - 0.0029315948486328125, 0.03173828125, -0.004024505615234375, 0.04534912109375, - -0.0036163330078125, -0.03912353515625, -0.00800323486328125, 0.058197021484375, - 0.05572509765625, 0.01165771484375, 0.06756591796875, 0.05816650390625, - -0.0654296875, -0.0241851806640625, 0.0205535888671875, -0.031707763671875 + -0.04888916015625, + 0.0143280029296875, + 0.066650390625, + -0.0343017578125, + -0.0010356903076171875, + -0.00048232078552246094, + 0.07470703125, + -0.04736328125, + 0.01454925537109375, + -0.0086669921875, + -0.051971435546875, + -0.0201568603515625, + 0.040435791015625, + -0.019256591796875, + 0.0205078125, + 0.0111541748046875, + 0.0071868896484375, + -0.0298309326171875, + -0.0306549072265625, + -0.0225372314453125, + -0.04193115234375, + 0.07073974609375, + -0.048065185546875, + 0.0198822021484375, + -0.035552978515625, + -0.022796630859375, + 0.03839111328125, + 0.007099151611328125, + -0.0080108642578125, + -0.0017957687377929688, + 0.0266265869140625, + -0.028289794921875, + 0.0032901763916015625, + 0.0208740234375, + -0.01529693603515625, + -0.046600341796875, + -0.034637451171875, + 0.011322021484375, + -0.026458740234375, + 0.04656982421875, + -0.0091705322265625, + 0.017913818359375, + -0.019256591796875, + -0.001216888427734375, + -0.08245849609375, + -0.023162841796875, + -0.04132080078125, + -0.03363037109375, + 0.0029315948486328125, + 0.03173828125, + -0.004024505615234375, + 0.04534912109375, + -0.0036163330078125, + -0.03912353515625, + -0.00800323486328125, + 0.058197021484375, + 0.05572509765625, + 0.01165771484375, + 0.06756591796875, + 0.05816650390625, + -0.0654296875, + -0.0241851806640625, + 0.0205535888671875, + -0.031707763671875, ] -add_weight = [-0.23681640625, -0.16552734375, 0.2191162109375, -0.1756591796875, - -0.03460693359375, -0.05316162109375, -0.336181640625, -0.253662109375] +add_weight = [ + -0.23681640625, + -0.16552734375, + 0.2191162109375, + -0.1756591796875, + -0.03460693359375, + -0.05316162109375, + -0.336181640625, + -0.253662109375, +] graph = helper.make_graph( - [ # nodes + [ # nodes helper.make_node("Add", ["Input", "Bias"], ["add0_out"], "add0"), helper.make_node("Shape", ["add0_out"], ["shape0_out"], "shape0"), helper.make_node("Gather", ["shape0_out", "indices0"], ["gather0_out"], "gather0", axis=0), helper.make_node("Unsqueeze", ["gather0_out"], ["unsqueeze0_out"], "unsqueeze0", axes=[0]), - helper.make_node("Concat", ["unsqueeze0_out", "dim_-1", "dim_2", "dim_4"], ["concat_out"], "concat", axis=0), + helper.make_node( + "Concat", + ["unsqueeze0_out", "dim_-1", "dim_2", "dim_4"], + ["concat_out"], + "concat", + axis=0, + ), helper.make_node("MatMul", ["add0_out", "matmul_weight"], ["matmul_out"], "matmul"), helper.make_node("Add", ["matmul_out", "add_weight"], ["add1_out"], "add1"), helper.make_node("Reshape", ["add1_out", "concat_out"], ["Result"], "reshape"), ], - "Reshape_Fusion", #name + "Reshape_Fusion", # name [ # inputs - helper.make_tensor_value_info('Input', TensorProto.FLOAT, [1, 8]), + helper.make_tensor_value_info("Input", TensorProto.FLOAT, [1, 8]), ], [ # outputs - helper.make_tensor_value_info('Result', TensorProto.FLOAT, [1, -1, 2, 4]), + helper.make_tensor_value_info("Result", TensorProto.FLOAT, [1, -1, 2, 4]), ], [ # initializers - helper.make_tensor('Bias', TensorProto.FLOAT, [8], add_weight), - helper.make_tensor('dim_-1', TensorProto.INT64, [1], [-1]), - helper.make_tensor('dim_2', TensorProto.INT64, [1], [2]), - helper.make_tensor('dim_4', TensorProto.INT64, [1], [4]), - helper.make_tensor('indices0', TensorProto.INT64, [], [0]), - helper.make_tensor('matmul_weight', TensorProto.FLOAT, [8, 8], matmul_weights), - helper.make_tensor('add_weight', TensorProto.FLOAT, [8], add_weight), - ] + helper.make_tensor("Bias", TensorProto.FLOAT, [8], add_weight), + helper.make_tensor("dim_-1", TensorProto.INT64, [1], [-1]), + helper.make_tensor("dim_2", TensorProto.INT64, [1], [2]), + helper.make_tensor("dim_4", TensorProto.INT64, [1], [4]), + helper.make_tensor("indices0", TensorProto.INT64, [], [0]), + helper.make_tensor("matmul_weight", TensorProto.FLOAT, [8, 8], matmul_weights), + helper.make_tensor("add_weight", TensorProto.FLOAT, [8], add_weight), + ], ) -save_model(graph, 'reshape_fusion_distillbert.onnx') - +save_model(graph, "reshape_fusion_distillbert.onnx") diff --git a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py index c88e3903e4aa0..07f3411d4a129 100644 --- a/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/skip_layer_norm_gen.py @@ -1,12 +1,12 @@ -import onnx -from onnx import helper -from onnx import TensorProto from enum import Enum +import onnx +from onnx import TensorProto, helper + class Format(Enum): - Format1 = 1, - Format2 = 2, + Format1 = (1,) + Format2 = (2,) Format3 = 3 @@ -25,32 +25,42 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap ] initializers = [ # initializers - helper.make_tensor('pow_in_2', TensorProto.FLOAT, [], [2]), - helper.make_tensor('const_e12', TensorProto.FLOAT, [], [1e-12]), - helper.make_tensor('gamma', TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), - helper.make_tensor('beta', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + helper.make_tensor("pow_in_2", TensorProto.FLOAT, [], [2]), + helper.make_tensor("const_e12", TensorProto.FLOAT, [], [1e-12]), + helper.make_tensor("gamma", TensorProto.FLOAT, [4], [1.0, 2.0, 3.0, 4.0]), + helper.make_tensor("beta", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), ] if format is Format.Format1: - nodes.extend([ - helper.make_node("Add", ["A", "bias"], ["add3_out"], "add3"), - helper.make_node("Add", ["add3_out", "B"], ["ln_in"], "add2"), - ]) - initializers.extend([ - helper.make_tensor('bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - ]) + nodes.extend( + [ + helper.make_node("Add", ["A", "bias"], ["add3_out"], "add3"), + helper.make_node("Add", ["add3_out", "B"], ["ln_in"], "add2"), + ] + ) + initializers.extend( + [ + helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + ] + ) elif format is Format.Format2: - nodes.extend([ - helper.make_node("Add", ["B", "bias"], ["add3_out"], "add3"), - helper.make_node("Add", ["A", "add3_out"], ["ln_in"], "add2"), - ]) - initializers.extend([ - helper.make_tensor('bias', TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), - ]) + nodes.extend( + [ + helper.make_node("Add", ["B", "bias"], ["add3_out"], "add3"), + helper.make_node("Add", ["A", "add3_out"], ["ln_in"], "add2"), + ] + ) + initializers.extend( + [ + helper.make_tensor("bias", TensorProto.FLOAT, [4], [0.1, 0.2, 0.3, 0.4]), + ] + ) elif format is Format.Format3: - nodes.extend([ - helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"), - ]) + nodes.extend( + [ + helper.make_node("Add", ["A", "B"], ["ln_in"], "add2"), + ] + ) if multi_output_add: neg_input = "ln_in" if format is Format.Format3 else "add3_out" @@ -58,16 +68,17 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap graph = helper.make_graph( nodes, - "SkipLayerNorm_format3", #name + "SkipLayerNorm_format3", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, [16, 32, 4]), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [16, 32, 4]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT, [16, 32, 4]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, [16, 32, 4]), ], - initializers) - + initializers, + ) + if add_output_in_graph_output: extra_output = "ln_in" if format is Format.Format3 else "add3_out" graph.output.extend([helper.make_tensor_value_info(extra_output, TensorProto.FLOAT, [16, 32, 4])]) @@ -76,13 +87,25 @@ def GenerateModel(format, model_name, multi_output_add=False, add_output_in_grap onnx.save(model, model_name) -GenerateModel(Format.Format1, 'skip_layer_norm_format1.onnx') -GenerateModel(Format.Format2, 'skip_layer_norm_format2.onnx') -GenerateModel(Format.Format3, 'skip_layer_norm_format3.onnx') -GenerateModel(Format.Format1, 'skip_layer_norm_format1_partial.onnx', multi_output_add = True) -GenerateModel(Format.Format2, 'skip_layer_norm_format2_partial.onnx', multi_output_add = True) -GenerateModel(Format.Format3, 'skip_layer_norm_format3_no_fusion.onnx', multi_output_add = True) +GenerateModel(Format.Format1, "skip_layer_norm_format1.onnx") +GenerateModel(Format.Format2, "skip_layer_norm_format2.onnx") +GenerateModel(Format.Format3, "skip_layer_norm_format3.onnx") +GenerateModel(Format.Format1, "skip_layer_norm_format1_partial.onnx", multi_output_add=True) +GenerateModel(Format.Format2, "skip_layer_norm_format2_partial.onnx", multi_output_add=True) +GenerateModel(Format.Format3, "skip_layer_norm_format3_no_fusion.onnx", multi_output_add=True) -GenerateModel(Format.Format1, 'skip_layer_norm_format1_graph_output.onnx', add_output_in_graph_output = True) -GenerateModel(Format.Format2, 'skip_layer_norm_format2_graph_output.onnx', add_output_in_graph_output = True) -GenerateModel(Format.Format3, 'skip_layer_norm_format3_graph_output.onnx', add_output_in_graph_output = True) \ No newline at end of file +GenerateModel( + Format.Format1, + "skip_layer_norm_format1_graph_output.onnx", + add_output_in_graph_output=True, +) +GenerateModel( + Format.Format2, + "skip_layer_norm_format2_graph_output.onnx", + add_output_in_graph_output=True, +) +GenerateModel( + Format.Format3, + "skip_layer_norm_format3_graph_output.onnx", + add_output_in_graph_output=True, +) diff --git a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py index 2904a618bbe00..b99a97bee2491 100644 --- a/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py +++ b/onnxruntime/test/testdata/transform/fusion/transpose_matmul_gen.py @@ -1,7 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -14,23 +12,16 @@ def save(model_path, nodes, inputs, outputs, initializers): - graph = helper.make_graph( - nodes, - "TransposeMatMulTest", - inputs, outputs, initializers) + graph = helper.make_graph(nodes, "TransposeMatMulTest", inputs, outputs, initializers) - model = helper.make_model( - graph, opset_imports=opsets, producer_name="onnxruntime-test") + model = helper.make_model(graph, opset_imports=opsets, producer_name="onnxruntime-test") onnx.save(model, model_path) def gen_from_transpose_scale_matmul(model_path): nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"]), + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"]), helper.make_node( "FusedMatMul", ["transposed_input_0", "input_1"], @@ -38,138 +29,107 @@ def gen_from_transpose_scale_matmul(model_path): "FusedMatMul", "", msdomain.domain, - alpha=3.0, transA=1) + alpha=3.0, + transA=1, + ), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, ['K', 'N']) + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, ['M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["M", "N"])] save(model_path, nodes, inputs, outputs, []) -gen_from_transpose_scale_matmul( - "transpose_matmul_2d_fusion_from_transpose_scale_matmul.onnx") +gen_from_transpose_scale_matmul("transpose_matmul_2d_fusion_from_transpose_scale_matmul.onnx") def gen_invalid_default_perm(model_path): nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"]), - helper.make_node( - "MatMul", - ["transposed_input_0", "input_1"], - ["output"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"]), + helper.make_node("MatMul", ["transposed_input_0", "input_1"], ["output"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['K', 'M', 3, 2]), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [2, 3, 'K', 'N']) + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["K", "M", 3, 2]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, 3, "K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 3, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, "M", "N"])] save(model_path, nodes, inputs, outputs, []) -gen_invalid_default_perm( - "transpose_matmul_4d_fusion_invalid_default_perm.onnx") +gen_invalid_default_perm("transpose_matmul_4d_fusion_invalid_default_perm.onnx") def gen_with_preserved_transpose(model_path): nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"]), - helper.make_node( - "MatMul", - ["transposed_input_0", "input_1"], - ["output_0"]), - helper.make_node( - "Identity", - ["transposed_input_0"], - ["output_1"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"]), + helper.make_node("MatMul", ["transposed_input_0", "input_1"], ["output_0"]), + helper.make_node("Identity", ["transposed_input_0"], ["output_1"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['K', 'M']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, ['K', 'N']) + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["K", "M"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["K", "N"]), ] outputs = [ - helper.make_tensor_value_info( - "output_0", TensorProto.FLOAT, ['M', 'N']), - helper.make_tensor_value_info( - "output_1", TensorProto.FLOAT, ['M', 'K']) + helper.make_tensor_value_info("output_0", TensorProto.FLOAT, ["M", "N"]), + helper.make_tensor_value_info("output_1", TensorProto.FLOAT, ["M", "K"]), ] save(model_path, nodes, inputs, outputs, []) -gen_with_preserved_transpose( - "transpose_matmul_2d_fusion_with_preserved_transpose.onnx") +gen_with_preserved_transpose("transpose_matmul_2d_fusion_with_preserved_transpose.onnx") def gen_transpose_fusion_with_cast(model_path): - cast_1 = helper.make_node( - "Cast", - ["input_1"], - ["casted_input_1"], - "Cast_1", - to = TensorProto.FLOAT16) + cast_1 = helper.make_node("Cast", ["input_1"], ["casted_input_1"], "Cast_1", to=TensorProto.FLOAT16) transpose_0 = helper.make_node( "Transpose", ["input_0"], ["transposed_input_0"], "Transpose_0", - perm = [0, 1, 3, 2]) + perm=[0, 1, 3, 2], + ) cast_0 = helper.make_node( "Cast", ["transposed_input_0"], ["transposed_casted_input_0"], "Cast_0", - to = TensorProto.FLOAT16) + to=TensorProto.FLOAT16, + ) matmul_0 = helper.make_node( "MatMul", ["transposed_casted_input_0", "casted_input_1"], ["output_0"], - "MatMul_0") + "MatMul_0", + ) nodes = [transpose_0, cast_0, cast_1, matmul_0] - input_0 = helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2, 'N', 'N']) - input_1 = helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, 'N', 'N']) + input_0 = helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [3, 2, "N", "N"]) + input_1 = helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, 2, "N", "N"]) inputs = [input_0, input_1] - output_0 = helper.make_tensor_value_info("output_0", TensorProto.FLOAT16, [3, 2, 'N', 'N']) + output_0 = helper.make_tensor_value_info("output_0", TensorProto.FLOAT16, [3, 2, "N", "N"]) outputs = [output_0] # Testcase0: First input of MatMul is transposed save(model_path + "0.onnx", nodes, inputs, outputs, []) # Testcase1: Re-arragne nodes so that the transpose is on second input of matmul transpose_1 = helper.make_node( - "Transpose", - ["input_1"], - ["transposed_input_1"], - "Transpose_1", - perm = [0, 1, 3, 2]) + "Transpose", + ["input_1"], + ["transposed_input_1"], + "Transpose_1", + perm=[0, 1, 3, 2], + ) cast_1.input[0] = "transposed_input_1" cast_1.output[0] = "transposed_casted_input_1" cast_0.input[0] = "input_0" @@ -188,27 +148,25 @@ def gen_transpose_fusion_with_cast(model_path): # Testcase3: Create a second MatMul node using the outputs from the same Cast nodes as before # with each Cast node feeding more than one node. - nodes.append(helper.make_node( + nodes.append( + helper.make_node( "MatMul", ["transposed_casted_input_0", "transposed_casted_input_1"], ["output_1"], - "MatMul_1")) - output_1 = helper.make_tensor_value_info("output_1", TensorProto.FLOAT16, [3, 2, 'N', 'N']) + "MatMul_1", + ) + ) + output_1 = helper.make_tensor_value_info("output_1", TensorProto.FLOAT16, [3, 2, "N", "N"]) outputs.append(output_1) save(model_path + "3.onnx", nodes, inputs, outputs, []) # Testcase4: The second MatMul uses transposed inputs without cast. nodes.pop() outputs.pop() - matmul_1 = helper.make_node( - "MatMul", - ["transposed_input_0", "transposed_input_1"], - ["output_1"], - "MatMul_1") + matmul_1 = helper.make_node("MatMul", ["transposed_input_0", "transposed_input_1"], ["output_1"], "MatMul_1") nodes.append(matmul_1) - outputs.append(helper.make_tensor_value_info( - "output_1", TensorProto.FLOAT, [3, 2, 'N', 'N'])) + outputs.append(helper.make_tensor_value_info("output_1", TensorProto.FLOAT, [3, 2, "N", "N"])) save(model_path + "4.onnx", nodes, inputs, outputs, []) # Testcase5: Each MatMul uses outputs from a Cast and a Transpose @@ -219,33 +177,22 @@ def gen_transpose_fusion_with_cast(model_path): output_1.type.tensor_type.elem_type = TensorProto.FLOAT save(model_path + "5.onnx", nodes, inputs, outputs, []) -gen_transpose_fusion_with_cast( - "transpose_cast_matmul_4d_fusion") + +gen_transpose_fusion_with_cast("transpose_cast_matmul_4d_fusion") + def gen_transpose_fusion_invalid_datatype(model_path, datatype): nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [0, 1, 3, 2]), - helper.make_node( - "MatMul", - ["transposed_input_0", "input_1"], - ["output"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[0, 1, 3, 2]), + helper.make_node("MatMul", ["transposed_input_0", "input_1"], ["output"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", datatype, [2, 3, 'K', 'M']), - helper.make_tensor_value_info( - "input_1", datatype, [2, 3, 'K', 'N']) + helper.make_tensor_value_info("input_0", datatype, [2, 3, "K", "M"]), + helper.make_tensor_value_info("input_1", datatype, [2, 3, "K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", datatype, [2, 3, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", datatype, [2, 3, "M", "N"])] save(model_path, nodes, inputs, outputs, []) @@ -256,71 +203,36 @@ def gen_transpose_fusion_invalid_datatype(model_path, datatype): def gen_transpose_matmul_trans_batch_fusion(model_path): nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [1, 2, 0]), - helper.make_node( - "Transpose", - ["input_1"], - ["transposed_input_1"], - perm = [0, 2, 1]), - helper.make_node( - "MatMul", - ["transposed_input_0", "transposed_input_1"], - ["output"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[1, 2, 0]), + helper.make_node("Transpose", ["input_1"], ["transposed_input_1"], perm=[0, 2, 1]), + helper.make_node("MatMul", ["transposed_input_0", "transposed_input_1"], ["output"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['K', 3, 'M']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [3, 'N', 'K']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["K", 3, "M"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [3, "N", "K"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [3, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [3, "M", "N"])] save(model_path + "1.onnx", nodes, inputs, outputs, []) nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [1, 2, 0, 3]), - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_1"], - perm = [1, 2, 3, 0]), - helper.make_node( - "MatMul", - ["transposed_input_0", "transposed_input_1"], - ["output"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[1, 2, 0, 3]), + helper.make_node("Transpose", ["input_0"], ["transposed_input_1"], perm=[1, 2, 3, 0]), + helper.make_node("MatMul", ["transposed_input_0", "transposed_input_1"], ["output"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['M', 2, 3, 'K']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["M", 2, 3, "K"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 3, 'M', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, "M", "M"])] save(model_path + "2.onnx", nodes, inputs, outputs, []) nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [1, 2, 3, 0]), + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[1, 2, 3, 0]), helper.make_node( "FusedMatMul", ["transposed_input_0", "input_1"], @@ -328,90 +240,56 @@ def gen_transpose_matmul_trans_batch_fusion(model_path): "FusedMatMul", "", msdomain.domain, - alpha=3.0, transA=1, transBatchB=1) + alpha=3.0, + transA=1, + transBatchB=1, + ), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['M', 2, 3, 'K']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, ['K', 2, 3, 'N']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["M", 2, 3, "K"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, ["K", 2, 3, "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 3, 'M', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, "M", "M"])] save(model_path + "3.onnx", nodes, inputs, outputs, []) -gen_transpose_matmul_trans_batch_fusion( - "transpose_matmul_trans_batch_fusion") +gen_transpose_matmul_trans_batch_fusion("transpose_matmul_trans_batch_fusion") def gen_transpose_matmul_trans_batch_fusion_invalid_cases(model_path): nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [1, 2, 0]), - helper.make_node( - "MatMul", - ["transposed_input_0", "input_1"], - ["output"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[1, 2, 0]), + helper.make_node("MatMul", ["transposed_input_0", "input_1"], ["output"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['K', 3, 'M']), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [2, 3, 'K', 'N']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["K", 3, "M"]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, 3, "K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 3, 'M', 'N']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, "M", "N"])] save(model_path + "1.onnx", nodes, inputs, outputs, []) nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [0, 2, 1, 3]), - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_1"], - perm = [0, 2, 3, 1]), - helper.make_node( - "MatMul", - ["transposed_input_0", "transposed_input_1"], - ["output"]) + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[0, 2, 1, 3]), + helper.make_node("Transpose", ["input_0"], ["transposed_input_1"], perm=[0, 2, 3, 1]), + helper.make_node("MatMul", ["transposed_input_0", "transposed_input_1"], ["output"]), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, [2, 'M', 3, 'K']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, [2, "M", 3, "K"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 3, 'M', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, "M", "M"])] save(model_path + "2.onnx", nodes, inputs, outputs, []) nodes = [ - helper.make_node( - "Transpose", - ["input_0"], - ["transposed_input_0"], - perm = [1, 2, 3, 0]), + helper.make_node("Transpose", ["input_0"], ["transposed_input_0"], perm=[1, 2, 3, 0]), helper.make_node( "FusedMatMul", ["transposed_input_0", "input_1"], @@ -419,23 +297,19 @@ def gen_transpose_matmul_trans_batch_fusion_invalid_cases(model_path): "FusedMatMul", "", msdomain.domain, - alpha=3.0, transBatchA=1) + alpha=3.0, + transBatchA=1, + ), ] inputs = [ - helper.make_tensor_value_info( - "input_0", TensorProto.FLOAT, ['K', 'M', 2, 3]), - helper.make_tensor_value_info( - "input_1", TensorProto.FLOAT, [2, 3, 'K', 'N']), + helper.make_tensor_value_info("input_0", TensorProto.FLOAT, ["K", "M", 2, 3]), + helper.make_tensor_value_info("input_1", TensorProto.FLOAT, [2, 3, "K", "N"]), ] - outputs = [ - helper.make_tensor_value_info( - "output", TensorProto.FLOAT, [2, 3, 'M', 'M']) - ] + outputs = [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 3, "M", "M"])] save(model_path + "3.onnx", nodes, inputs, outputs, []) -gen_transpose_matmul_trans_batch_fusion_invalid_cases( - "transpose_matmul_trans_batch_fusion_invalid_case") +gen_transpose_matmul_trans_batch_fusion_invalid_cases("transpose_matmul_trans_batch_fusion_invalid_case") diff --git a/onnxruntime/test/testdata/transform/id-elim.py b/onnxruntime/test/testdata/transform/id-elim.py index 105e3a1071e2e..838fbb1f4a798 100644 --- a/onnxruntime/test/testdata/transform/id-elim.py +++ b/onnxruntime/test/testdata/transform/id-elim.py @@ -1,41 +1,34 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [4, 4]) -X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [4, 4]) -Y1 = helper.make_tensor_value_info('output1', TensorProto.INT64, [4, 4]) -Y2 = helper.make_tensor_value_info('output2', TensorProto.INT64, [4, 4]) +X1 = helper.make_tensor_value_info("x1", TensorProto.INT64, [4, 4]) +X2 = helper.make_tensor_value_info("x2", TensorProto.INT64, [4, 4]) +Y1 = helper.make_tensor_value_info("output1", TensorProto.INT64, [4, 4]) +Y2 = helper.make_tensor_value_info("output2", TensorProto.INT64, [4, 4]) -add1 = helper.make_node('Add', ['x1', 'x2'], ['add1'], name='add1') -add2 = helper.make_node('Add', ['x1', 'x2'], ['add2'], name='add2') -id1 = helper.make_node('Identity', ['add1'], ['output1'], name='id1') -id2 = helper.make_node('Identity', ['add2'], ['output2'], name='id2') +add1 = helper.make_node("Add", ["x1", "x2"], ["add1"], name="add1") +add2 = helper.make_node("Add", ["x1", "x2"], ["add2"], name="add2") +id1 = helper.make_node("Identity", ["add1"], ["output1"], name="id1") +id2 = helper.make_node("Identity", ["add2"], ["output2"], name="id2") # Create the graph (GraphProto) -graph_def = helper.make_graph( - [add1, add2, id1, id2], - 'identity_elimination_model', - [X1, X2], - [Y1, Y2] -) +graph_def = helper.make_graph([add1, add2, id1, id2], "identity_elimination_model", [X1, X2], [Y1, Y2]) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'id-elim.onnx') +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "id-elim.onnx") diff --git a/onnxruntime/test/testdata/transform/id-scan9_sum.py b/onnxruntime/test/testdata/transform/id-scan9_sum.py index 798c4afd39862..f2a7de656c8ee 100644 --- a/onnxruntime/test/testdata/transform/id-scan9_sum.py +++ b/onnxruntime/test/testdata/transform/id-scan9_sum.py @@ -1,61 +1,41 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper -initial = helper.make_tensor_value_info('initial', TensorProto.FLOAT, [2]) -x = helper.make_tensor_value_info('x', TensorProto.FLOAT, [3, 2]) -y = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3, 2]) -z = helper.make_tensor_value_info('z', TensorProto.FLOAT, [3, 2]) +initial = helper.make_tensor_value_info("initial", TensorProto.FLOAT, [2]) +x = helper.make_tensor_value_info("x", TensorProto.FLOAT, [3, 2]) +y = helper.make_tensor_value_info("y", TensorProto.FLOAT, [3, 2]) +z = helper.make_tensor_value_info("z", TensorProto.FLOAT, [3, 2]) -sum_in = helper.make_tensor_value_info('sum_in', TensorProto.FLOAT, [2]) -next = helper.make_tensor_value_info('next', TensorProto.FLOAT, [2]) -sum_out = helper.make_tensor_value_info('sum_out', TensorProto.FLOAT, [2]) -scan_out = helper.make_tensor_value_info('scan_out', TensorProto.FLOAT, [2]) +sum_in = helper.make_tensor_value_info("sum_in", TensorProto.FLOAT, [2]) +next = helper.make_tensor_value_info("next", TensorProto.FLOAT, [2]) +sum_out = helper.make_tensor_value_info("sum_out", TensorProto.FLOAT, [2]) +scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [2]) -add_node = helper.make_node( - 'Add', - inputs=['sum_in', 'next'], - outputs=['sum_out'] -) -id_node = helper.make_node( - 'Identity', - inputs=['sum_out'], - outputs=['scan_out'] -) -scan_body = helper.make_graph( - [add_node, id_node], - 'scan_body', - [sum_in, next], - [sum_out, scan_out] -) +add_node = helper.make_node("Add", inputs=["sum_in", "next"], outputs=["sum_out"]) +id_node = helper.make_node("Identity", inputs=["sum_out"], outputs=["scan_out"]) +scan_body = helper.make_graph([add_node, id_node], "scan_body", [sum_in, next], [sum_out, scan_out]) # create scan op node scan_node = helper.make_node( - 'Scan', - inputs=['initial', 'x'], - outputs=['y', 'z'], + "Scan", + inputs=["initial", "x"], + outputs=["y", "z"], num_scan_inputs=1, - body=scan_body + body=scan_body, ) # Create the graph (GraphProto) -graph_def = helper.make_graph( - [scan_node], - 'test_scan9_sum', - [initial, x], - [y, z] -) +graph_def = helper.make_graph([scan_node], "test_scan9_sum", [initial, x], [y, z]) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 9 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'scan9_sum.onnx') +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "scan9_sum.onnx") diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py index b531f34bc8edb..8f8798750f56f 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/bart_mlp_megatron_basic_test.py @@ -1,94 +1,133 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper hidden_size = 4 weight_dim_to_split = 16 -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) -a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((weight_dim_to_split, hidden_size)) -a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.weight") +a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape( + (weight_dim_to_split, hidden_size) +) +a_weight_initializer = numpy_helper.from_array( + a_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.weight" +) -a_bias_np_vals = (0.01 * np.arange(weight_dim_to_split, dtype=np.float32)) # weight_dim_to_split numbers in total +a_bias_np_vals = 0.01 * np.arange(weight_dim_to_split, dtype=np.float32) # weight_dim_to_split numbers in total a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wi.bias") dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(()) dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio") - + dropout_mode_np_vals = np.array([False], dtype=np.bool).reshape(()) dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode") -b_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((hidden_size, weight_dim_to_split)) -b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.weight") +b_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape( + (hidden_size, weight_dim_to_split) +) +b_weight_initializer = numpy_helper.from_array( + b_weight_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.weight" +) -b_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) # hidden_size numbers in total +b_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) # hidden_size numbers in total b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "encoder.t5_stack.block.1.layer.1.DenseReluDense.wo.bias") -transpose1 = helper.make_node('Transpose', [a_weight_initializer.name], ['transpose1'], name='transpose1', perm=[1,0]) -transpose2 = helper.make_node('Transpose', [b_weight_initializer.name], ['transpose2'], name='transpose2', perm=[1,0]) +transpose1 = helper.make_node( + "Transpose", + [a_weight_initializer.name], + ["transpose1"], + name="transpose1", + perm=[1, 0], +) +transpose2 = helper.make_node( + "Transpose", + [b_weight_initializer.name], + ["transpose2"], + name="transpose2", + perm=[1, 0], +) matmul = helper.make_node( - 'MatMul', # node name - ['input', 'transpose1'], # inputs - ['matmul'], # outputs - name="matmul" + "MatMul", # node name + ["input", "transpose1"], # inputs + ["matmul"], # outputs + name="matmul", ) biasgelu = helper.make_node( - 'BiasGelu', # node name - ['matmul', a_bias_initializer.name], # inputs - ['biasgelu'], # outputs + "BiasGelu", # node name + ["matmul", a_bias_initializer.name], # inputs + ["biasgelu"], # outputs name="biasgelu", - domain="com.microsoft" + domain="com.microsoft", ) -dropout1 = helper.make_node('Dropout', - ["biasgelu", dropout_initializer.name, dropout_mode_initializer.name], - ['dropout1', "dropout1_mask"], - name='dropout1') +dropout1 = helper.make_node( + "Dropout", + ["biasgelu", dropout_initializer.name, dropout_mode_initializer.name], + ["dropout1", "dropout1_mask"], + name="dropout1", +) matmul2 = helper.make_node( - 'MatMul', # node name - ['dropout1', "transpose2"], # inputs - ['matmul2'], # outputs - name="matmul2" + "MatMul", # node name + ["dropout1", "transpose2"], # inputs + ["matmul2"], # outputs + name="matmul2", ) add = helper.make_node( - 'Add', # node name - ['matmul2', b_bias_initializer.name], # inputs - ['add'], # outputs - name="add" + "Add", # node name + ["matmul2", b_bias_initializer.name], # inputs + ["add"], # outputs + name="add", ) -dropout2 = helper.make_node('Dropout', - ["add", dropout_initializer.name, dropout_mode_initializer.name], - ['dropout2', "dropout2_mask"], - name='dropout2') +dropout2 = helper.make_node( + "Dropout", + ["add", dropout_initializer.name, dropout_mode_initializer.name], + ["dropout2", "dropout2_mask"], + name="dropout2", +) identity = helper.make_node( - 'Identity', # node name - ['dropout2'], # inputs - ['output'], # outputs - name="identity" + "Identity", # node name + ["dropout2"], # inputs + ["output"], # outputs + name="identity", ) # Create the graph (GraphProto) graph_def = helper.make_graph( - [transpose1, transpose2, matmul, biasgelu, dropout1, matmul2, add, dropout2, identity], - 'test-model', + [ + transpose1, + transpose2, + matmul, + biasgelu, + dropout1, + matmul2, + add, + dropout2, + identity, + ], + "test-model", [X], [Y], - [a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer, dropout_initializer, dropout_mode_initializer] + [ + a_weight_initializer, + a_bias_initializer, + b_weight_initializer, + b_bias_initializer, + dropout_initializer, + dropout_mode_initializer, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() @@ -96,10 +135,10 @@ msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} +kwargs = {} kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) -onnx.save(model_def, 'bart_mlp_megatron_basic_test.onnx') \ No newline at end of file +onnx.save(model_def, "bart_mlp_megatron_basic_test.onnx") diff --git a/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py index 81cd321c04368..8704079d1c1fd 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/bart_self_attention_megatron_basic_test.py @@ -1,150 +1,254 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper -import numpy as np import random +import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper + batch = 6 hidden_size = 4 attention_head = 2 hidden_per_attention = 2 -relative_attention_num_buckets=32 -input_len=8 -output_len=8 +relative_attention_num_buckets = 32 +input_len = 8 +output_len = 8 -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, [batch, input_len, hidden_size]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, [output_len, batch, hidden_size]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, [batch, input_len, hidden_size]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, [output_len, batch, hidden_size]) q_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, 'encoder.layers.0.self_attn.q_proj.weight') +q_weight_initializer = numpy_helper.from_array(q_weight_np_vals, "encoder.layers.0.self_attn.q_proj.weight") k_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, 'encoder.layers.0.self_attn.k_proj.weight') +k_weight_initializer = numpy_helper.from_array(k_weight_np_vals, "encoder.layers.0.self_attn.k_proj.weight") v_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, 'encoder.layers.0.self_attn.v_proj.weight') +v_weight_initializer = numpy_helper.from_array(v_weight_np_vals, "encoder.layers.0.self_attn.v_proj.weight") -q_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, 'encoder.layers.0.self_attn.q_proj.bias') +q_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +q_bias_initializer = numpy_helper.from_array(q_bias_np_vals, "encoder.layers.0.self_attn.q_proj.bias") -k_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, 'encoder.layers.0.self_attn.k_proj.bias') +k_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +k_bias_initializer = numpy_helper.from_array(k_bias_np_vals, "encoder.layers.0.self_attn.k_proj.bias") -v_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, 'encoder.layers.0.self_attn.v_proj.bias') +v_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +v_bias_initializer = numpy_helper.from_array(v_bias_np_vals, "encoder.layers.0.self_attn.v_proj.bias") -q_shape_initializer = numpy_helper.from_array(np.asarray([input_len, batch*attention_head , hidden_per_attention], dtype=np.int64), 'q_shape') -k_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'k_shape') -v_shape_initializer = numpy_helper.from_array(np.asarray([-1, batch*attention_head , hidden_per_attention], dtype=np.int64), 'v_shape') +q_shape_initializer = numpy_helper.from_array( + np.asarray([input_len, batch * attention_head, hidden_per_attention], dtype=np.int64), + "q_shape", +) +k_shape_initializer = numpy_helper.from_array( + np.asarray([-1, batch * attention_head, hidden_per_attention], dtype=np.int64), + "k_shape", +) +v_shape_initializer = numpy_helper.from_array( + np.asarray([-1, batch * attention_head, hidden_per_attention], dtype=np.int64), + "v_shape", +) mul_np_vals = np.asarray([0.1767766922712326], dtype=np.float32).reshape(()) mul_initializer = numpy_helper.from_array(mul_np_vals, "mul_const") -qk_shape_initializer = numpy_helper.from_array(np.asarray([batch, attention_head , input_len, input_len], dtype=np.int64), 'qk_shape') +qk_shape_initializer = numpy_helper.from_array( + np.asarray([batch, attention_head, input_len, input_len], dtype=np.int64), + "qk_shape", +) -dummy_condition_initializer = numpy_helper.from_array(np.zeros((batch, input_len), dtype=bool), 'dummy_cond') -inf_const_initializer = numpy_helper.from_array(np.asarray([-np.inf], dtype=np.float32), 'inf_const') +dummy_condition_initializer = numpy_helper.from_array(np.zeros((batch, input_len), dtype=bool), "dummy_cond") +inf_const_initializer = numpy_helper.from_array(np.asarray([-np.inf], dtype=np.float32), "inf_const") -where_shape_initializer = numpy_helper.from_array(np.asarray([batch*attention_head , input_len, input_len], dtype=np.int64), 'where_shape') +where_shape_initializer = numpy_helper.from_array( + np.asarray([batch * attention_head, input_len, input_len], dtype=np.int64), + "where_shape", +) dropout_np_vals = np.asarray([0.1], dtype=np.float32).reshape(()) dropout_initializer = numpy_helper.from_array(dropout_np_vals, "ratio") - + dropout_mode_np_vals = np.array([False], dtype=np.bool).reshape(()) dropout_mode_initializer = numpy_helper.from_array(dropout_mode_np_vals, "mode") -shape_initializer3 = numpy_helper.from_array(np.array([input_len, batch, attention_head * hidden_per_attention], dtype=np.int64), 'concat_shape_3') +shape_initializer3 = numpy_helper.from_array( + np.array([input_len, batch, attention_head * hidden_per_attention], dtype=np.int64), + "concat_shape_3", +) -dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -dense_weight_initializer = numpy_helper.from_array(dense_weight_np_vals, 'encoder.layers.0.self_attn.out_proj.weight') +dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape( + (hidden_size, hidden_size) +) +dense_weight_initializer = numpy_helper.from_array(dense_weight_np_vals, "encoder.layers.0.self_attn.out_proj.weight") -dense_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, 'encoder.layers.0.self_attn.out_proj.bias') +dense_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, "encoder.layers.0.self_attn.out_proj.bias") -transpose_ip = helper.make_node('Transpose', ['input'], ['transpose_ip'], name='transpose_ip', perm=[1,0,2]) +transpose_ip = helper.make_node("Transpose", ["input"], ["transpose_ip"], name="transpose_ip", perm=[1, 0, 2]) -transpose_q = helper.make_node('Transpose', [q_weight_initializer.name], ['transpose_q'], name='transpose_q', perm=[1,0]) -transpose_k = helper.make_node('Transpose', [k_weight_initializer.name], ['transpose_k'], name='transpose_k', perm=[1,0]) -transpose_v = helper.make_node('Transpose', [v_weight_initializer.name], ['transpose_v'], name='transpose_v', perm=[1,0]) +transpose_q = helper.make_node( + "Transpose", + [q_weight_initializer.name], + ["transpose_q"], + name="transpose_q", + perm=[1, 0], +) +transpose_k = helper.make_node( + "Transpose", + [k_weight_initializer.name], + ["transpose_k"], + name="transpose_k", + perm=[1, 0], +) +transpose_v = helper.make_node( + "Transpose", + [v_weight_initializer.name], + ["transpose_v"], + name="transpose_v", + perm=[1, 0], +) -matmul_q = helper.make_node('MatMul', ['transpose_ip', 'transpose_q'], ['matmul_q'], name='matmul_q') -matmul_k = helper.make_node('MatMul', ['transpose_ip', 'transpose_k'], ['matmul_k'], name='matmul_k') -matmul_v = helper.make_node('MatMul', ['transpose_ip', 'transpose_v'], ['matmul_v'], name='matmul_v') +matmul_q = helper.make_node("MatMul", ["transpose_ip", "transpose_q"], ["matmul_q"], name="matmul_q") +matmul_k = helper.make_node("MatMul", ["transpose_ip", "transpose_k"], ["matmul_k"], name="matmul_k") +matmul_v = helper.make_node("MatMul", ["transpose_ip", "transpose_v"], ["matmul_v"], name="matmul_v") -add_q = helper.make_node('Add', ['matmul_q', q_bias_initializer.name], ['add_q'], name='add_q') -add_k = helper.make_node('Add', ['matmul_k', k_bias_initializer.name], ['add_k'], name='add_k') -add_v = helper.make_node('Add', ['matmul_v', v_bias_initializer.name], ['add_v'], name='add_v') +add_q = helper.make_node("Add", ["matmul_q", q_bias_initializer.name], ["add_q"], name="add_q") +add_k = helper.make_node("Add", ["matmul_k", k_bias_initializer.name], ["add_k"], name="add_k") +add_v = helper.make_node("Add", ["matmul_v", v_bias_initializer.name], ["add_v"], name="add_v") -mul_q = helper.make_node('Mul', ['add_q' , 'mul_const'], ['mul_q'], name='mul_q') +mul_q = helper.make_node("Mul", ["add_q", "mul_const"], ["mul_q"], name="mul_q") -reshape_q = helper.make_node('Reshape', ['mul_q', q_shape_initializer.name], ['reshape_q'], name='reshape_q') -reshape_k = helper.make_node('Reshape', ['add_k', k_shape_initializer.name], ['reshape_k'], name='reshape_k') -reshape_v = helper.make_node('Reshape', ['add_v', v_shape_initializer.name], ['reshape_v'], name='reshape_v') +reshape_q = helper.make_node("Reshape", ["mul_q", q_shape_initializer.name], ["reshape_q"], name="reshape_q") +reshape_k = helper.make_node("Reshape", ["add_k", k_shape_initializer.name], ["reshape_k"], name="reshape_k") +reshape_v = helper.make_node("Reshape", ["add_v", v_shape_initializer.name], ["reshape_v"], name="reshape_v") -transpose_q2 = helper.make_node('Transpose', ['reshape_q'], ['transpose_q2'], name='transpose_q2', perm=[1,0,2]) -transpose_k2 = helper.make_node('Transpose', ['reshape_k'], ['transpose_k2'], name='transpose_k2', perm=[1,2,0]) -transpose_v2 = helper.make_node('Transpose', ['reshape_v'], ['transpose_v2'], name='transpose_v2', perm=[1,0,2]) +transpose_q2 = helper.make_node("Transpose", ["reshape_q"], ["transpose_q2"], name="transpose_q2", perm=[1, 0, 2]) +transpose_k2 = helper.make_node("Transpose", ["reshape_k"], ["transpose_k2"], name="transpose_k2", perm=[1, 2, 0]) +transpose_v2 = helper.make_node("Transpose", ["reshape_v"], ["transpose_v2"], name="transpose_v2", perm=[1, 0, 2]) -matmul = helper.make_node('MatMul', ['transpose_q2', 'transpose_k2'], ['matmul'], name='matmul') -reshape_qk = helper.make_node("Reshape", ['matmul', qk_shape_initializer.name], ['reshape_qk'], name='reshape_qk') +matmul = helper.make_node("MatMul", ["transpose_q2", "transpose_k2"], ["matmul"], name="matmul") +reshape_qk = helper.make_node("Reshape", ["matmul", qk_shape_initializer.name], ["reshape_qk"], name="reshape_qk") -unsqueeze = helper.make_node('Unsqueeze', [dummy_condition_initializer.name],['unsqueeze_cond'], axes=[1,2], name='unsqueeze_cond') -where = helper.make_node('Where', ['unsqueeze_cond', inf_const_initializer.name, 'reshape_qk'], ['where'], name='where') +unsqueeze = helper.make_node( + "Unsqueeze", + [dummy_condition_initializer.name], + ["unsqueeze_cond"], + axes=[1, 2], + name="unsqueeze_cond", +) +where = helper.make_node( + "Where", + ["unsqueeze_cond", inf_const_initializer.name, "reshape_qk"], + ["where"], + name="where", +) -reshape_where = helper.make_node("Reshape", ['where', where_shape_initializer.name], ['reshape_where'], name='reshape_where') +reshape_where = helper.make_node( + "Reshape", + ["where", where_shape_initializer.name], + ["reshape_where"], + name="reshape_where", +) -softmax = helper.make_node('Softmax', ['reshape_where'], ['softmax'], name='softmax', axis=2) -dropout1 = helper.make_node('Dropout', - ["softmax", dropout_initializer.name, dropout_mode_initializer.name], - ['dropout1', "dropout1_mask"], - name='dropout1') +softmax = helper.make_node("Softmax", ["reshape_where"], ["softmax"], name="softmax", axis=2) +dropout1 = helper.make_node( + "Dropout", + ["softmax", dropout_initializer.name, dropout_mode_initializer.name], + ["dropout1", "dropout1_mask"], + name="dropout1", +) -matmul2 = helper.make_node('MatMul', ['dropout1', 'transpose_v2'], ['matmul2'], name='matmul2') -transpose = helper.make_node('Transpose', ['matmul2'], ['transpose'], name='transpose', perm=[1,0,2]) -reshape = helper.make_node('Reshape', ['transpose', shape_initializer3.name], ['reshape'], name='reshape') +matmul2 = helper.make_node("MatMul", ["dropout1", "transpose_v2"], ["matmul2"], name="matmul2") +transpose = helper.make_node("Transpose", ["matmul2"], ["transpose"], name="transpose", perm=[1, 0, 2]) +reshape = helper.make_node("Reshape", ["transpose", shape_initializer3.name], ["reshape"], name="reshape") -transpose_o_weight = helper.make_node('Transpose', [dense_weight_initializer.name], ['transpose_o_weight'], name='transpose_o_weight', perm=[1,0]) -matmul3 = helper.make_node('MatMul', ['reshape', 'transpose_o_weight'], ['matmul3'], name='matmul3') -add3 = helper.make_node('Add', ['matmul3', dense_bias_initializer.name], ['add3'], name='add3') -identity = helper.make_node('Identity', ['add3'], ['output'], name='identity') +transpose_o_weight = helper.make_node( + "Transpose", + [dense_weight_initializer.name], + ["transpose_o_weight"], + name="transpose_o_weight", + perm=[1, 0], +) +matmul3 = helper.make_node("MatMul", ["reshape", "transpose_o_weight"], ["matmul3"], name="matmul3") +add3 = helper.make_node("Add", ["matmul3", dense_bias_initializer.name], ["add3"], name="add3") +identity = helper.make_node("Identity", ["add3"], ["output"], name="identity") # Create the graph (GraphProto) graph_def = helper.make_graph( - [transpose_ip,transpose_q,transpose_k,transpose_v,matmul_q,matmul_k,matmul_v,add_q,add_k,add_v, - mul_q,reshape_q,reshape_k,reshape_v,transpose_q2,transpose_k2,transpose_v2,matmul,reshape_qk, - unsqueeze,where,reshape_where,softmax,dropout1,matmul2,transpose,reshape,transpose_o_weight, - matmul3,add3,identity], - 'self-attention-megatron-test-model', + [ + transpose_ip, + transpose_q, + transpose_k, + transpose_v, + matmul_q, + matmul_k, + matmul_v, + add_q, + add_k, + add_v, + mul_q, + reshape_q, + reshape_k, + reshape_v, + transpose_q2, + transpose_k2, + transpose_v2, + matmul, + reshape_qk, + unsqueeze, + where, + reshape_where, + softmax, + dropout1, + matmul2, + transpose, + reshape, + transpose_o_weight, + matmul3, + add3, + identity, + ], + "self-attention-megatron-test-model", [X], [Y], - [q_weight_initializer,k_weight_initializer,v_weight_initializer,q_bias_initializer,k_bias_initializer, - v_bias_initializer,q_shape_initializer,k_shape_initializer,v_shape_initializer,mul_initializer, - qk_shape_initializer,dummy_condition_initializer,inf_const_initializer,where_shape_initializer, - dropout_initializer,dropout_mode_initializer,shape_initializer3, - dense_weight_initializer, dense_bias_initializer] + [ + q_weight_initializer, + k_weight_initializer, + v_weight_initializer, + q_bias_initializer, + k_bias_initializer, + v_bias_initializer, + q_shape_initializer, + k_shape_initializer, + v_shape_initializer, + mul_initializer, + qk_shape_initializer, + dummy_condition_initializer, + inf_const_initializer, + where_shape_initializer, + dropout_initializer, + dropout_mode_initializer, + shape_initializer3, + dense_weight_initializer, + dense_bias_initializer, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'bart_self_attention_megatron_basic_test.onnx') - - +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "bart_self_attention_megatron_basic_test.onnx") diff --git a/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py index de718c702e880..b26d384cbb4c9 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/mlp_megatron_basic_test.py @@ -1,84 +1,90 @@ -import onnx -from onnx import helper -from onnx import AttributeProto, TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import AttributeProto, GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper hidden_size = 4 weight_dim_to_split = 16 -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) -a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape((hidden_size, weight_dim_to_split)) -a_weight_initializer = numpy_helper.from_array(a_weight_np_vals, "transformer.layers.0.mlp.dense_h_to_4h.weight_transposed") +a_weight_np_vals = (0.01 * np.arange(hidden_size * weight_dim_to_split, dtype=np.float32)).reshape( + (hidden_size, weight_dim_to_split) +) +a_weight_initializer = numpy_helper.from_array( + a_weight_np_vals, "transformer.layers.0.mlp.dense_h_to_4h.weight_transposed" +) -a_bias_np_vals = (0.01 * np.arange(weight_dim_to_split, dtype=np.float32)) # weight_dim_to_split numbers in total +a_bias_np_vals = 0.01 * np.arange(weight_dim_to_split, dtype=np.float32) # weight_dim_to_split numbers in total a_bias_initializer = numpy_helper.from_array(a_bias_np_vals, "transformer.layers.0.mlp.dense_h_to_4h.bias") -b_weight_np_vals = (0.01 * np.arange(weight_dim_to_split * hidden_size, dtype=np.float32)).reshape((weight_dim_to_split, hidden_size)) -b_weight_initializer = numpy_helper.from_array(b_weight_np_vals, "transformer.layers.0.mlp.dense_4h_to_h.weight_transposed") +b_weight_np_vals = (0.01 * np.arange(weight_dim_to_split * hidden_size, dtype=np.float32)).reshape( + (weight_dim_to_split, hidden_size) +) +b_weight_initializer = numpy_helper.from_array( + b_weight_np_vals, "transformer.layers.0.mlp.dense_4h_to_h.weight_transposed" +) -b_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) # hidden_size numbers in total +b_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) # hidden_size numbers in total b_bias_initializer = numpy_helper.from_array(b_bias_np_vals, "transformer.layers.0.mlp.dense_4h_to_h.bias") matmul = helper.make_node( - 'MatMul', # node name - ['input', a_weight_initializer.name], # inputs - ['matmul'], # outputs - name="matmul" + "MatMul", # node name + ["input", a_weight_initializer.name], # inputs + ["matmul"], # outputs + name="matmul", ) add = helper.make_node( - 'Add', # node name - ['matmul', a_bias_initializer.name], # inputs - ['add'], # outputs - name="add" + "Add", # node name + ["matmul", a_bias_initializer.name], # inputs + ["add"], # outputs + name="add", ) gelu = helper.make_node( - 'Gelu', # node name - ['add'], # inputs - ['gelu'], # outputs + "Gelu", # node name + ["add"], # inputs + ["gelu"], # outputs name="gelu", doc_string="", - domain="com.microsoft" + domain="com.microsoft", ) matmul2 = helper.make_node( - 'MatMul', # node name - ['gelu', b_weight_initializer.name], # inputs - ['matmul2'], # outputs - name="matmul2" + "MatMul", # node name + ["gelu", b_weight_initializer.name], # inputs + ["matmul2"], # outputs + name="matmul2", ) add2 = helper.make_node( - 'Add', # node name - ['matmul2', b_bias_initializer.name], # inputs - ['add2'], # outputs - name="add2" + "Add", # node name + ["matmul2", b_bias_initializer.name], # inputs + ["add2"], # outputs + name="add2", ) -identity = helper.make_node( - 'Identity', # node name - ['add2'], # inputs - ['output'], # outputs - name="identity" -) +identity = helper.make_node("Identity", ["add2"], ["output"], name="identity") # node name # inputs # outputs # Create the graph (GraphProto) graph_def = helper.make_graph( [matmul, add, gelu, matmul2, add2, identity], - 'test-model', + "test-model", [X], [Y], - [a_weight_initializer, a_bias_initializer, b_weight_initializer, b_bias_initializer] + [ + a_weight_initializer, + a_bias_initializer, + b_weight_initializer, + b_bias_initializer, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 10 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() @@ -86,10 +92,10 @@ msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} +kwargs = {} kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) -onnx.save(model_def, 'mlp_megatron_basic_test.onnx') \ No newline at end of file +onnx.save(model_def, "mlp_megatron_basic_test.onnx") diff --git a/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py b/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py index 916de919442bb..5083ceeb434db 100644 --- a/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py +++ b/onnxruntime/test/testdata/transform/model_parallel/self_attention_megatron_basic_test.py @@ -1,8 +1,6 @@ -import onnx -from onnx import helper -from onnx import TensorProto, GraphProto, OperatorSetIdProto -from onnx import numpy_helper import numpy as np +import onnx +from onnx import GraphProto, OperatorSetIdProto, TensorProto, helper, numpy_helper hidden_size = 4 attention_head = 2 @@ -14,79 +12,140 @@ # |->Reshape->Transpose->| | # |->Reshape->Transpose------------------------------------------>| -X = helper.make_tensor_value_info('input', TensorProto.FLOAT, ['batch', 'seqlen', hidden_size]) -X_mask = helper.make_tensor_value_info('mask', TensorProto.FLOAT, ['batch', 1, 'seqlen', 'seqlen']) -Y = helper.make_tensor_value_info('output', TensorProto.FLOAT, ['batch', 'seqlen', hidden_size]) +X = helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) +X_mask = helper.make_tensor_value_info("mask", TensorProto.FLOAT, ["batch", 1, "seqlen", "seqlen"]) +Y = helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seqlen", hidden_size]) -qkv_weight_np_vals = (0.01 * np.arange(hidden_size * (hidden_size * 3), dtype=np.float32)).reshape((hidden_size, hidden_size * 3)) -qkv_weight_initializer = numpy_helper.from_array(qkv_weight_np_vals, 'transformer.attention.query_key_value.weight_transposed') +qkv_weight_np_vals = (0.01 * np.arange(hidden_size * (hidden_size * 3), dtype=np.float32)).reshape( + (hidden_size, hidden_size * 3) +) +qkv_weight_initializer = numpy_helper.from_array( + qkv_weight_np_vals, "transformer.attention.query_key_value.weight_transposed" +) -qkv_bias_np_vals = (0.01 * np.arange(hidden_size * 3, dtype=np.float32)) -qkv_bias_initializer = numpy_helper.from_array(qkv_bias_np_vals, 'transformer.attention.query_key_value.bias') +qkv_bias_np_vals = 0.01 * np.arange(hidden_size * 3, dtype=np.float32) +qkv_bias_initializer = numpy_helper.from_array(qkv_bias_np_vals, "transformer.attention.query_key_value.bias") -dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape((hidden_size, hidden_size)) -dense_weight_initializer = numpy_helper.from_array(dense_weight_np_vals, 'transformer.attention.dense.weight_transposed') +dense_weight_np_vals = (0.01 * np.arange(hidden_size * hidden_size, dtype=np.float32)).reshape( + (hidden_size, hidden_size) +) +dense_weight_initializer = numpy_helper.from_array( + dense_weight_np_vals, "transformer.attention.dense.weight_transposed" +) -dense_bias_np_vals = (0.01 * np.arange(hidden_size, dtype=np.float32)) -dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, 'transformer.attention.dense.bias') +dense_bias_np_vals = 0.01 * np.arange(hidden_size, dtype=np.float32) +dense_bias_initializer = numpy_helper.from_array(dense_bias_np_vals, "transformer.attention.dense.bias") shape_val = np.array([0, 0, attention_head, hidden_per_attention], dtype=np.int64) -shape_initializer = numpy_helper.from_array(shape_val, 'concat_shape_0') +shape_initializer = numpy_helper.from_array(shape_val, "concat_shape_0") shape_val1 = np.array([0, 0, attention_head, hidden_per_attention], dtype=np.int64) -shape_initializer1 = numpy_helper.from_array(shape_val1, 'concat_shape_1') +shape_initializer1 = numpy_helper.from_array(shape_val1, "concat_shape_1") shape_val2 = np.array([0, 0, attention_head, hidden_per_attention], dtype=np.int64) -shape_initializer2 = numpy_helper.from_array(shape_val2, 'concat_shape_2') +shape_initializer2 = numpy_helper.from_array(shape_val2, "concat_shape_2") shape_val3 = np.array([0, 0, hidden_size], dtype=np.int64) -shape_initializer3 = numpy_helper.from_array(shape_val3, 'concat_shape_3') +shape_initializer3 = numpy_helper.from_array(shape_val3, "concat_shape_3") -matmul1 = helper.make_node('MatMul', ['input', qkv_weight_initializer.name], ['matmul1'], name='matmul1') -add1 = helper.make_node('Add', ['matmul1', qkv_bias_initializer.name], ['add1'], name='add1') -split = helper.make_node('Split', ['add1'], ['mixed_query_layer', 'mixed_key_layer', 'mixed_value_layer'], name='split', axis=2) -reshape = helper.make_node('Reshape', ['mixed_query_layer', shape_initializer.name], ['reshape'], name='reshape') -reshape1 = helper.make_node('Reshape', ['mixed_key_layer', shape_initializer1.name], ['reshape1'], name='reshape1') -reshape2 = helper.make_node('Reshape', ['mixed_value_layer', shape_initializer2.name], ['reshape2'], name='reshape2') -transpose = helper.make_node('Transpose', ['reshape'], ['transpose'], name='transpose', perm=[0,2,1,3]) -transpose1 = helper.make_node('Transpose', ['reshape1'], ['transpose1'], name='transpose1', perm=[0,2,3,1]) -transpose2 = helper.make_node('Transpose', ['reshape2'], ['transpose2'], name='transpose2', perm=[0,2,1,3]) -matmul2 = helper.make_node('MatMul', ['transpose', 'transpose1'], ['matmul2'], name='matmul2') +matmul1 = helper.make_node("MatMul", ["input", qkv_weight_initializer.name], ["matmul1"], name="matmul1") +add1 = helper.make_node("Add", ["matmul1", qkv_bias_initializer.name], ["add1"], name="add1") +split = helper.make_node( + "Split", + ["add1"], + ["mixed_query_layer", "mixed_key_layer", "mixed_value_layer"], + name="split", + axis=2, +) +reshape = helper.make_node( + "Reshape", + ["mixed_query_layer", shape_initializer.name], + ["reshape"], + name="reshape", +) +reshape1 = helper.make_node( + "Reshape", + ["mixed_key_layer", shape_initializer1.name], + ["reshape1"], + name="reshape1", +) +reshape2 = helper.make_node( + "Reshape", + ["mixed_value_layer", shape_initializer2.name], + ["reshape2"], + name="reshape2", +) +transpose = helper.make_node("Transpose", ["reshape"], ["transpose"], name="transpose", perm=[0, 2, 1, 3]) +transpose1 = helper.make_node("Transpose", ["reshape1"], ["transpose1"], name="transpose1", perm=[0, 2, 3, 1]) +transpose2 = helper.make_node("Transpose", ["reshape2"], ["transpose2"], name="transpose2", perm=[0, 2, 1, 3]) +matmul2 = helper.make_node("MatMul", ["transpose", "transpose1"], ["matmul2"], name="matmul2") # Use the mask input for below 3 nodes. This is different from the original GPT-2 model, but it's OK for te sub-graph test. -div = helper.make_node('Div', ['matmul2', 'mask'], ['div'], name='div') -mul = helper.make_node('Mul', ['div', 'mask'], ['mul'], name='mul') -sub = helper.make_node('Sub', ['mul', 'mask'], ['sub'], name='sub') -softmax = helper.make_node('Softmax', ['sub'], ['softmax'], name='softmax', axis=3) -dropout1 = helper.make_node('Dropout', ['softmax'], ['dropout1'], name='dropout1') -matmul3 = helper.make_node('MatMul', ['dropout1', 'transpose2'], ['matmul3'], name='matmul3') -transpose3 = helper.make_node('Transpose', ['matmul3'], ['transpose3'], name='transpose3', perm=[0,2,1,3]) -reshape3 = helper.make_node('Reshape', ['transpose3', shape_initializer3.name], ['reshape3'], name='reshape3') -matmul4 = helper.make_node('MatMul', ['reshape3', dense_weight_initializer.name], ['matmul4'], name='matmul4') -add2 = helper.make_node('Add', ['matmul4', dense_bias_initializer.name], ['add2'], name='add2') -dropout2 = helper.make_node('Dropout', ['add2'], ['dropout2'], name='dropout2') +div = helper.make_node("Div", ["matmul2", "mask"], ["div"], name="div") +mul = helper.make_node("Mul", ["div", "mask"], ["mul"], name="mul") +sub = helper.make_node("Sub", ["mul", "mask"], ["sub"], name="sub") +softmax = helper.make_node("Softmax", ["sub"], ["softmax"], name="softmax", axis=3) +dropout1 = helper.make_node("Dropout", ["softmax"], ["dropout1"], name="dropout1") +matmul3 = helper.make_node("MatMul", ["dropout1", "transpose2"], ["matmul3"], name="matmul3") +transpose3 = helper.make_node("Transpose", ["matmul3"], ["transpose3"], name="transpose3", perm=[0, 2, 1, 3]) +reshape3 = helper.make_node("Reshape", ["transpose3", shape_initializer3.name], ["reshape3"], name="reshape3") +matmul4 = helper.make_node("MatMul", ["reshape3", dense_weight_initializer.name], ["matmul4"], name="matmul4") +add2 = helper.make_node("Add", ["matmul4", dense_bias_initializer.name], ["add2"], name="add2") +dropout2 = helper.make_node("Dropout", ["add2"], ["dropout2"], name="dropout2") # Add dummy Identity so during inference dropout2 can be removed for testing. -identity = helper.make_node('Identity', ['dropout2'], ['output'], name='identity') +identity = helper.make_node("Identity", ["dropout2"], ["output"], name="identity") # Create the graph (GraphProto) graph_def = helper.make_graph( - [matmul1, add1, split, reshape, reshape1, reshape2, transpose, transpose1, transpose2, matmul2, div, mul, sub, softmax, dropout1, matmul3, transpose3, reshape3, matmul4, add2, dropout2, identity], - 'self-attention-megatron-test-model', + [ + matmul1, + add1, + split, + reshape, + reshape1, + reshape2, + transpose, + transpose1, + transpose2, + matmul2, + div, + mul, + sub, + softmax, + dropout1, + matmul3, + transpose3, + reshape3, + matmul4, + add2, + dropout2, + identity, + ], + "self-attention-megatron-test-model", [X, X_mask], [Y], - [qkv_weight_initializer, qkv_bias_initializer, dense_weight_initializer, dense_bias_initializer, shape_initializer, shape_initializer1, shape_initializer2, shape_initializer3] + [ + qkv_weight_initializer, + qkv_bias_initializer, + dense_weight_initializer, + dense_bias_initializer, + shape_initializer, + shape_initializer1, + shape_initializer2, + shape_initializer3, + ], ) opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets # Create the model (ModelProto) -model_def = helper.make_model(graph_def, producer_name='onnx-example', **kwargs) -onnx.save(model_def, 'self_attention_megatron_basic_test.onnx') +model_def = helper.make_model(graph_def, producer_name="onnx-example", **kwargs) +onnx.save(model_def, "self_attention_megatron_basic_test.onnx") diff --git a/onnxruntime/test/testdata/transform/noop-add.py b/onnxruntime/test/testdata/transform/noop-add.py index 3e996d42e667d..11dbc7269e816 100644 --- a/onnxruntime/test/testdata/transform/noop-add.py +++ b/onnxruntime/test/testdata/transform/noop-add.py @@ -1,31 +1,31 @@ import onnx -from onnx import helper -from onnx import TensorProto, OperatorSetIdProto +from onnx import OperatorSetIdProto, TensorProto, helper opsets = [] onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 -onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. +onnxdomain.domain = "" # The empty string ("") or absence of this field implies the operator set that is defined as part of the ONNX specification. opsets.append(onnxdomain) msdomain = OperatorSetIdProto() msdomain.version = 1 -msdomain.domain = 'com.microsoft' +msdomain.domain = "com.microsoft" opsets.append(msdomain) -kwargs={} -kwargs['opset_imports'] = opsets +kwargs = {} +kwargs["opset_imports"] = opsets + def GenerateModel(model_name): nodes = [ # subgraph # float helper.make_node("Identity", ["X1"], ["id_1"], "id_1"), - helper.make_node("Add", ["float_1", "id_1"], ["add_1"], "add_1"), + helper.make_node("Add", ["float_1", "id_1"], ["add_1"], "add_1"), helper.make_node("Identity", ["add_1"], ["Y1"], "id_2"), # float_16 helper.make_node("Identity", ["X2"], ["id_3"], "id_3"), helper.make_node("Add", ["float16_1", "id_3"], ["add_2"], "add_2"), - helper.make_node("Identity", ["add_2"], ["Y2"], "id_4"), + helper.make_node("Identity", ["add_2"], ["Y2"], "id_4"), # int64 - flip the input 0 and 1 helper.make_node("Identity", ["X3"], ["id_5"], "id_5"), helper.make_node("Add", ["id_5", "int64_1"], ["add_3"], "add_3"), @@ -34,46 +34,48 @@ def GenerateModel(model_name): helper.make_node("Identity", ["X4"], ["id_7"], "id_7"), helper.make_node("Add", ["id_7", "int64_2"], ["add_4"], "add_4"), helper.make_node("Identity", ["add_4"], ["Y4"], "id_8"), - #float + # float helper.make_node("Identity", ["X5"], ["id_9"], "id_9"), - helper.make_node("Add", ["float_2", "id_9"], ["add_5"], "add_5"), + helper.make_node("Add", ["float_2", "id_9"], ["add_5"], "add_5"), helper.make_node("Identity", ["add_5"], ["Y5"], "id_10"), ] inputs = [ # inputs - helper.make_tensor_value_info('X1', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('X2', TensorProto.FLOAT16, ['M', 'K']), - helper.make_tensor_value_info('X3', TensorProto.INT64, ['M', 'K']), - helper.make_tensor_value_info('X4', TensorProto.INT64, ['M', 'K']), - helper.make_tensor_value_info('X5', TensorProto.FLOAT, ['M', 'K']), - ] + helper.make_tensor_value_info("X1", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("X2", TensorProto.FLOAT16, ["M", "K"]), + helper.make_tensor_value_info("X3", TensorProto.INT64, ["M", "K"]), + helper.make_tensor_value_info("X4", TensorProto.INT64, ["M", "K"]), + helper.make_tensor_value_info("X5", TensorProto.FLOAT, ["M", "K"]), + ] initializers = [ - helper.make_tensor('float_1', TensorProto.FLOAT, [1], [0.0]), - helper.make_tensor('float16_1', TensorProto.FLOAT16, [1], [0]), - # int64 - set tensor size to 0 - helper.make_tensor('int64_1', TensorProto.INT64, (), [0]), - # higher rank - helper.make_tensor('int64_2', TensorProto.INT64, [1,1,1], [0]), - #float - set initializer size = 0 - helper.make_tensor('float_2', TensorProto.FLOAT, [0], []), - ] + helper.make_tensor("float_1", TensorProto.FLOAT, [1], [0.0]), + helper.make_tensor("float16_1", TensorProto.FLOAT16, [1], [0]), + # int64 - set tensor size to 0 + helper.make_tensor("int64_1", TensorProto.INT64, (), [0]), + # higher rank + helper.make_tensor("int64_2", TensorProto.INT64, [1, 1, 1], [0]), + # float - set initializer size = 0 + helper.make_tensor("float_2", TensorProto.FLOAT, [0], []), + ] graph = helper.make_graph( nodes, - "NoopAdd", #name + "NoopAdd", # name inputs, [ # outputs - helper.make_tensor_value_info('Y1', TensorProto.FLOAT, ['M', 'K']), - helper.make_tensor_value_info('Y2', TensorProto.FLOAT16, ['M', 'K']), - helper.make_tensor_value_info('Y3', TensorProto.INT64, ['M', 'K']), - helper.make_tensor_value_info('Y4', TensorProto.INT64, ['M', 'K', 1]), - helper.make_tensor_value_info('Y5', TensorProto.FLOAT, ['M', 'K']), + helper.make_tensor_value_info("Y1", TensorProto.FLOAT, ["M", "K"]), + helper.make_tensor_value_info("Y2", TensorProto.FLOAT16, ["M", "K"]), + helper.make_tensor_value_info("Y3", TensorProto.INT64, ["M", "K"]), + helper.make_tensor_value_info("Y4", TensorProto.INT64, ["M", "K", 1]), + helper.make_tensor_value_info("Y5", TensorProto.FLOAT, ["M", "K"]), ], - initializers) + initializers, + ) model = helper.make_model(graph, **kwargs) onnx.save(model, model_name) + if __name__ == "__main__": - GenerateModel('noop-add.onnx') \ No newline at end of file + GenerateModel("noop-add.onnx") diff --git a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py index 8601fbda4f08f..910ff93a32ead 100644 --- a/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py +++ b/onnxruntime/test/testdata/transform/propagate_cast/gen_propagate_cast.py @@ -1,9 +1,8 @@ -import onnx -from onnx import helper -from onnx import TensorProto -from onnx import OperatorSetIdProto import itertools + import numpy as np +import onnx +from onnx import OperatorSetIdProto, TensorProto, helper onnxdomain = OperatorSetIdProto() onnxdomain.version = 12 @@ -22,13 +21,9 @@ def type_to_string(type): def save(model_path, nodes, inputs, outputs, initializers): - graph = helper.make_graph( - nodes, - "CastPropagateTest", - inputs, outputs, initializers) + graph = helper.make_graph(nodes, "CastPropagateTest", inputs, outputs, initializers) - model = helper.make_model( - graph, opset_imports=opsets, producer_name="onnxruntime-test") + model = helper.make_model(graph, opset_imports=opsets, producer_name="onnxruntime-test") onnx.save(model, model_path + ".onnx") @@ -38,95 +33,67 @@ def gen_fuse_back2back_casts(model_path): for (type1, type2) in list(itertools.product([TensorProto.FLOAT, TensorProto.FLOAT16], repeat=2)): nodes = [ - helper.make_node( - "MatMul", - ["input_0", "input_1"], - ["product"], - "MatMul_0"), - helper.make_node( - "Cast", - ["product"], - ["product_cast"], - "Cast_0", - to=type1), - helper.make_node( - "Cast", - ["product_cast"], - ["output"], - "Cast_1", - to=type2) + helper.make_node("MatMul", ["input_0", "input_1"], ["product"], "MatMul_0"), + helper.make_node("Cast", ["product"], ["product_cast"], "Cast_0", to=type1), + helper.make_node("Cast", ["product_cast"], ["output"], "Cast_1", to=type2), ] - input_type = type2 if type1 != type2 else ( - TensorProto.FLOAT16 if type1 == TensorProto.FLOAT else TensorProto.FLOAT) - output_type = input_type if type1 != type2 else ( - TensorProto.FLOAT16 if input_type == TensorProto.FLOAT else TensorProto.FLOAT) + input_type = ( + type2 if type1 != type2 else (TensorProto.FLOAT16 if type1 == TensorProto.FLOAT else TensorProto.FLOAT) + ) + output_type = ( + input_type + if type1 != type2 + else (TensorProto.FLOAT16 if input_type == TensorProto.FLOAT else TensorProto.FLOAT) + ) inputs = [ - helper.make_tensor_value_info( - "input_0", input_type, ['M', 'K']), - helper.make_tensor_value_info( - "input_1", input_type, ['K', 'N']) + helper.make_tensor_value_info("input_0", input_type, ["M", "K"]), + helper.make_tensor_value_info("input_1", input_type, ["K", "N"]), ] outputs = [ - helper.make_tensor_value_info( - "output", output_type, ['M', 'N']), + helper.make_tensor_value_info("output", output_type, ["M", "N"]), ] - save(model_path + "_" + type_to_string(type1) + "_" + - type_to_string(type2), nodes, inputs, outputs, []) + save( + model_path + "_" + type_to_string(type1) + "_" + type_to_string(type2), + nodes, + inputs, + outputs, + [], + ) def gen_fuse_sibling_casts(model_path): for (type1, type2) in list(itertools.product([TensorProto.FLOAT, TensorProto.FLOAT16], repeat=2)): - input_type = type2 if type1 != type2 else ( - TensorProto.FLOAT16 if type1 == TensorProto.FLOAT else TensorProto.FLOAT) + input_type = ( + type2 if type1 != type2 else (TensorProto.FLOAT16 if type1 == TensorProto.FLOAT else TensorProto.FLOAT) + ) nodes = [ - helper.make_node( - "MatMul", - ["input_0", "input_1"], - ["product"], - "MatMul_0"), - helper.make_node( - "Cast", - ["product"], - ["cast_0_output"], - "Cast_0", - to=type1), - helper.make_node( - "Identity", - ["cast_0_output"], - ["output_0"], - "Identity_0"), - helper.make_node( - "Cast", - ["product"], - ["cast_1_output"], - "Cast_1", - to=type2), - helper.make_node( - "Identity", - ["cast_1_output"], - ["output_1"], - "Identity_1") + helper.make_node("MatMul", ["input_0", "input_1"], ["product"], "MatMul_0"), + helper.make_node("Cast", ["product"], ["cast_0_output"], "Cast_0", to=type1), + helper.make_node("Identity", ["cast_0_output"], ["output_0"], "Identity_0"), + helper.make_node("Cast", ["product"], ["cast_1_output"], "Cast_1", to=type2), + helper.make_node("Identity", ["cast_1_output"], ["output_1"], "Identity_1"), ] inputs = [ - helper.make_tensor_value_info( - "input_0", input_type, ['M', 'K']), - helper.make_tensor_value_info( - "input_1", input_type, ['K', 'N']) + helper.make_tensor_value_info("input_0", input_type, ["M", "K"]), + helper.make_tensor_value_info("input_1", input_type, ["K", "N"]), ] outputs = [ - helper.make_tensor_value_info( - "output_0", type1, ['M', 'N']), - helper.make_tensor_value_info( - "output_1", type2, ['M', 'N']) + helper.make_tensor_value_info("output_0", type1, ["M", "N"]), + helper.make_tensor_value_info("output_1", type2, ["M", "N"]), ] - save(model_path + "_" + type_to_string(type1) + "_" + - type_to_string(type2), nodes, inputs, outputs, []) + save( + model_path + "_" + type_to_string(type1) + "_" + type_to_string(type2), + nodes, + inputs, + outputs, + [], + ) def flip_type(type, flip=True): @@ -134,65 +101,60 @@ def flip_type(type, flip=True): def do_cast_inputs(input_0, input_1, nodes, input_cast_type): - nodes.extend([helper.make_node( - "Cast", - [input_0], - ["cast_"+input_0], - "Cast_0", - to=input_cast_type), - helper.make_node( - "Cast", - [input_1], - ["cast_"+input_1], - "Cast_1", - to=input_cast_type)]) - return "cast_"+input_0, "cast_"+input_1 + nodes.extend( + [ + helper.make_node("Cast", [input_0], ["cast_" + input_0], "Cast_0", to=input_cast_type), + helper.make_node("Cast", [input_1], ["cast_" + input_1], "Cast_1", to=input_cast_type), + ] + ) + return "cast_" + input_0, "cast_" + input_1 def do_transpose_inputs(input_0, input_1, nodes): - nodes.extend([helper.make_node("Transpose", [input_0], ["input_transpose_0"], "Transpose_0"), - helper.make_node("Transpose", [input_1], ["input_transpose_1"], "Transpose_1")]) + nodes.extend( + [ + helper.make_node("Transpose", [input_0], ["input_transpose_0"], "Transpose_0"), + helper.make_node("Transpose", [input_1], ["input_transpose_1"], "Transpose_1"), + ] + ) return "input_transpose_0", "input_transpose_1" def do_cast_product(product, nodes, product_type): - nodes.insert(1, helper.make_node( - "Cast", - [product], - [product+"_cast"], - "Cast_2", - to=product_type)) - return product+"_cast" + nodes.insert( + 1, + helper.make_node("Cast", [product], [product + "_cast"], "Cast_2", to=product_type), + ) + return product + "_cast" def do_transpose_product(product, nodes): if transpose_product: - nodes.append(helper.make_node("Transpose", [product], [ - product+"_transpose"], "Transpose_2")) - return product+"_transpose" + nodes.append(helper.make_node("Transpose", [product], [product + "_transpose"], "Transpose_2")) + return product + "_transpose" def do_cast_sum(sum, nodes, type): - nodes.append(helper.make_node( - "Cast", - [sum], - ["cast_"+sum], - "Cast_3", - to=type)) - return "cast_"+sum + nodes.append(helper.make_node("Cast", [sum], ["cast_" + sum], "Cast_3", to=type)) + return "cast_" + sum def do_cast_input2(input_2, nodes, type): - nodes.append(helper.make_node( - "Cast", - [input_2], - ["cast_"+input_2], - "Cast_4", - to=type)) - return "cast_"+input_2 - - -def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2, transpose_inputs_before_cast=False): + nodes.append(helper.make_node("Cast", [input_2], ["cast_" + input_2], "Cast_4", to=type)) + return "cast_" + input_2 + + +def gen_propagate_cast_test_model( + model_path, + transpose_inputs, + transpose_product, + cast_inputs, + cast_product, + insert_add, + cast_sum, + cast_input2, + transpose_inputs_before_cast=False, +): input_0 = "input_0" input_1 = "input_1" product = "product" @@ -211,13 +173,7 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc input_type = flip_type(input_type) if transpose_inputs: input_0, input_1 = do_transpose_inputs(input_0, input_1, nodes) - nodes.append(helper.make_node( - "MatMul", - [input_0, - input_1], - [product], - "MatMul_0") - ) + nodes.append(helper.make_node("MatMul", [input_0, input_1], [product], "MatMul_0")) if transpose_product: product = do_transpose_product(product, nodes) @@ -226,53 +182,50 @@ def gen_propagate_cast_test_model(model_path, transpose_inputs, transpose_produc product_type = flip_type(product_type) inputs = [ - helper.make_tensor_value_info( - "input_0", input_type, ['N', 'N']), - helper.make_tensor_value_info( - "input_1", input_type, ['N', 'N']) + helper.make_tensor_value_info("input_0", input_type, ["N", "N"]), + helper.make_tensor_value_info("input_1", input_type, ["N", "N"]), ] if insert_add: input_2 = "input_2" add_input_type = flip_type(product_type, cast_input2) - inputs.append(helper.make_tensor_value_info( - input_2, add_input_type, ['N', 'N'])) + inputs.append(helper.make_tensor_value_info(input_2, add_input_type, ["N", "N"])) output = "sum" output_type = product_type if cast_input2: - input_2 = do_cast_input2( - input_2, nodes, flip_type(add_input_type)) - nodes.append(helper.make_node( - "Add", [product, input_2], [output], "Add_0")) + input_2 = do_cast_input2(input_2, nodes, flip_type(add_input_type)) + nodes.append(helper.make_node("Add", [product, input_2], [output], "Add_0")) if cast_sum: output = do_cast_sum(output, nodes, flip_type(output_type)) output_type = flip_type(output_type) else: output = product output_type = product_type - outputs = [ - helper.make_tensor_value_info( - output, output_type, ['N', 'N']) - ] - - save(model_path + ("_transpose_inputs" if transpose_inputs else "") + - ("_transpose_product" if transpose_product else "") + - ("_cast_inputs" if cast_inputs else "") + - ("_cast_product" if cast_product else "") + - ("_cast_input2" if cast_input2 else "") + - ("_cast_sum" if cast_sum else ""), - nodes, inputs, outputs, []) + outputs = [helper.make_tensor_value_info(output, output_type, ["N", "N"])] + + save( + model_path + + ("_transpose_inputs" if transpose_inputs else "") + + ("_transpose_product" if transpose_product else "") + + ("_cast_inputs" if cast_inputs else "") + + ("_cast_product" if cast_product else "") + + ("_cast_input2" if cast_input2 else "") + + ("_cast_sum" if cast_sum else ""), + nodes, + inputs, + outputs, + [], + ) def gen_matmul_two_products(model_path, transpose, transpose_before_cast, second_matmul, cast_inputs): def do_transpose(output_0, output_1, transpose, nodes): - nodes.append(helper.make_node("Transpose", [output_0], [ - "transpose_0_"+output_0], "Transpose_0")) - output_0 = "transpose_0_"+output_0 + nodes.append(helper.make_node("Transpose", [output_0], ["transpose_0_" + output_0], "Transpose_0")) + output_0 = "transpose_0_" + output_0 if transpose > 1: - nodes.append(helper.make_node("Transpose", [output_1], [ - "transpose_1_"+output_1], "Transpose_1")) - output_1 = "transpose_1_"+output_1 + nodes.append(helper.make_node("Transpose", [output_1], ["transpose_1_" + output_1], "Transpose_1")) + output_1 = "transpose_1_" + output_1 return output_0, output_1 + input_type = flip_type(TensorProto.FLOAT, cast_inputs) input_0 = "input_0" input_1 = "input_1" @@ -281,71 +234,63 @@ def do_transpose(output_0, output_1, transpose, nodes): output_1 = "product" outputs = [] nodes = [] - cast_count=0 + cast_count = 0 inputs = [ - helper.make_tensor_value_info( - "input_0", input_type, ['M', 'K']), - helper.make_tensor_value_info( - "input_1", input_type, ['K', 'N']) + helper.make_tensor_value_info("input_0", input_type, ["M", "K"]), + helper.make_tensor_value_info("input_1", input_type, ["K", "N"]), ] if cast_inputs: input_type = flip_type(input_type) input_0, input_1 = do_cast_inputs(input_0, input_1, nodes, input_type) - cast_count +=2 + cast_count += 2 output0_type = input_type output1_type = input_type - nodes.append(helper.make_node( - "MatMul", - [input_0, input_1], - [output], - "MatMul_0")) + nodes.append(helper.make_node("MatMul", [input_0, input_1], [output], "MatMul_0")) if second_matmul: - nodes.append(helper.make_node( - "MatMul", - [input_0, input_1], - ["second_"+output], - "MatMul_1")) - outputs.append(helper.make_tensor_value_info( - "second_"+output, input_type, ['M', 'N'])) + nodes.append(helper.make_node("MatMul", [input_0, input_1], ["second_" + output], "MatMul_1")) + outputs.append(helper.make_tensor_value_info("second_" + output, input_type, ["M", "N"])) if add_products: - nodes.append(helper.make_node( - "Add", - [output, "second_"+output], - ["sum"], - "Add_0")) - outputs.append(helper.make_tensor_value_info( - "sum", input_type, ['M', 'N'])) + nodes.append(helper.make_node("Add", [output, "second_" + output], ["sum"], "Add_0")) + outputs.append(helper.make_tensor_value_info("sum", input_type, ["M", "N"])) if transpose > 0 and transpose_before_cast: output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes) output0_type = flip_type(output0_type) - nodes.append(helper.make_node( - "Cast", - [output_0], - ["cast_"+str(cast_count)+"_"+output_0], - "Cast_"+str(cast_count), - to=output0_type)) - output_0 = "cast_"+str(cast_count)+"_"+output_0 + nodes.append( + helper.make_node( + "Cast", + [output_0], + ["cast_" + str(cast_count) + "_" + output_0], + "Cast_" + str(cast_count), + to=output0_type, + ) + ) + output_0 = "cast_" + str(cast_count) + "_" + output_0 cast_count += 1 if second_matmul: - nodes.append(helper.make_node( - "Cast", - [output_1], - ["cast_"+str(cast_count)+"_"+output_1], - "Cast_"+str(cast_count), - to=TensorProto.FLOAT16)) - output_1 = "cast_"+str(cast_count)+"_"+output_1 + nodes.append( + helper.make_node( + "Cast", + [output_1], + ["cast_" + str(cast_count) + "_" + output_1], + "Cast_" + str(cast_count), + to=TensorProto.FLOAT16, + ) + ) + output_1 = "cast_" + str(cast_count) + "_" + output_1 output1_type = flip_type(output1_type) if transpose > 0 and not transpose_before_cast: output_0, output_1 = do_transpose(output_0, output_1, transpose, nodes) - outputs.extend([ - helper.make_tensor_value_info( - output_0, output0_type, ['M', 'N']), - helper.make_tensor_value_info( - output_1, output1_type, ['M', 'N']) - ]) + outputs.extend( + [ + helper.make_tensor_value_info(output_0, output0_type, ["M", "N"]), + helper.make_tensor_value_info(output_1, output1_type, ["M", "N"]), + ] + ) model_path += "_cast_inputs" if cast_inputs else "" - model_path += ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose > 0 else "" + model_path += ( + ("_transpose_before_cast" if transpose_before_cast else "_transpose_after_cast") if transpose > 0 else "" + ) model_path += "_transpose" if transpose > 1 else "" model_path += "_second_matmul" if second_matmul else "" model_path += "_add_products" if add_products else "" @@ -353,71 +298,103 @@ def do_transpose(output_0, output_1, transpose, nodes): def gen_bool_to_float16_cast(model_path): - X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]) - X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]) - X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT, [1, 1]) - Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]) + X1 = helper.make_tensor_value_info("x1", TensorProto.INT64, [1, 1]) + X2 = helper.make_tensor_value_info("x2", TensorProto.INT64, [1, 1]) + X3 = helper.make_tensor_value_info("x3", TensorProto.FLOAT, [1, 1]) + Y = helper.make_tensor_value_info("output", TensorProto.FLOAT16, [1, 1]) - less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') - cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT16) - cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT16) - add1 = helper.make_node('Add', ['cast1', 'cast2'], ['output']) + less1 = helper.make_node("Less", ["x1", "x2"], ["less1"], name="less1") + cast1 = helper.make_node("Cast", ["less1"], ["cast1"], name="cast1", to=TensorProto.FLOAT16) + cast2 = helper.make_node("Cast", ["x3"], ["cast2"], name="cast2", to=TensorProto.FLOAT16) + add1 = helper.make_node("Add", ["cast1", "cast2"], ["output"]) save(model_path, [less1, cast1, cast2, add1], [X1, X2, X3], [Y], []) def gen_bool_to_float_cast(model_path): - X1 = helper.make_tensor_value_info('x1', TensorProto.INT64, [1, 1]) - X2 = helper.make_tensor_value_info('x2', TensorProto.INT64, [1, 1]) - X3 = helper.make_tensor_value_info('x3', TensorProto.FLOAT16, [1, 1]) - Y = helper.make_tensor_value_info('output', TensorProto.FLOAT16, [1, 1]) + X1 = helper.make_tensor_value_info("x1", TensorProto.INT64, [1, 1]) + X2 = helper.make_tensor_value_info("x2", TensorProto.INT64, [1, 1]) + X3 = helper.make_tensor_value_info("x3", TensorProto.FLOAT16, [1, 1]) + Y = helper.make_tensor_value_info("output", TensorProto.FLOAT16, [1, 1]) - less1 = helper.make_node('Less', ['x1', 'x2'], ['less1'], name='less1') - cast1 = helper.make_node('Cast', ['less1'], ['cast1'], name='cast1', to=TensorProto.FLOAT) - cast2 = helper.make_node('Cast', ['x3'], ['cast2'], name='cast2', to=TensorProto.FLOAT) - add1 = helper.make_node('Add', ['cast1', 'cast2'], ['add1']) - cast3 = helper.make_node('Cast', ['add1'], ['output'], name='cast3', to=TensorProto.FLOAT16) + less1 = helper.make_node("Less", ["x1", "x2"], ["less1"], name="less1") + cast1 = helper.make_node("Cast", ["less1"], ["cast1"], name="cast1", to=TensorProto.FLOAT) + cast2 = helper.make_node("Cast", ["x3"], ["cast2"], name="cast2", to=TensorProto.FLOAT) + add1 = helper.make_node("Add", ["cast1", "cast2"], ["add1"]) + cast3 = helper.make_node("Cast", ["add1"], ["output"], name="cast3", to=TensorProto.FLOAT16) save(model_path, [less1, cast1, cast2, cast3, add1], [X1, X2, X3], [Y], []) def gen_one_input_one_output_test(op, model_path, axes_attribute=False): - X = helper.make_tensor_value_info('x', TensorProto.FLOAT16, [2, 2]) + X = helper.make_tensor_value_info("x", TensorProto.FLOAT16, [2, 2]) output_shape = [2, 2] - if (op=="Unsqueeze"): + if op == "Unsqueeze": output_shape.append(1) - Y = helper.make_tensor_value_info('y', TensorProto.FLOAT16, output_shape) - node_inputs=[] - graph_inputs=[X] - cast1 = helper.make_node('Cast', ['x'], ['cast1'], name='cast1', to=TensorProto.FLOAT) - node_inputs.insert(0, 'cast1') + Y = helper.make_tensor_value_info("y", TensorProto.FLOAT16, output_shape) + node_inputs = [] + graph_inputs = [X] + cast1 = helper.make_node("Cast", ["x"], ["cast1"], name="cast1", to=TensorProto.FLOAT) + node_inputs.insert(0, "cast1") if axes_attribute: - node = helper.make_node(op, node_inputs, ['op_output'], name=op+str(1), axes=np.array([2]).astype(np.int64)) + node = helper.make_node( + op, + node_inputs, + ["op_output"], + name=op + str(1), + axes=np.array([2]).astype(np.int64), + ) else: - node = helper.make_node(op, node_inputs, ['op_output'], name=op+str(1)) - cast2 = helper.make_node('Cast', ['op_output'], [ - 'y'], name='cast2', to=TensorProto.FLOAT16) + node = helper.make_node(op, node_inputs, ["op_output"], name=op + str(1)) + cast2 = helper.make_node("Cast", ["op_output"], ["y"], name="cast2", to=TensorProto.FLOAT16) save(model_path, [cast1, node, cast2], graph_inputs, [Y], []) -for (transpose_inputs, transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) in list(itertools.product([False, True], repeat=7)): +for ( + transpose_inputs, + transpose_product, + cast_inputs, + cast_product, + insert_add, + cast_sum, + cast_input2, +) in list(itertools.product([False, True], repeat=7)): if not insert_add and (cast_sum or cast_input2): continue if cast_inputs or cast_product or cast_sum: - gen_propagate_cast_test_model("matmul_add" if insert_add else "matmul", transpose_inputs, - transpose_product, cast_inputs, cast_product, insert_add, cast_sum, cast_input2) + gen_propagate_cast_test_model( + "matmul_add" if insert_add else "matmul", + transpose_inputs, + transpose_product, + cast_inputs, + cast_product, + insert_add, + cast_sum, + cast_input2, + ) gen_fuse_sibling_casts("fuse_sibling_casts") gen_fuse_back2back_casts("fuse_back2back_casts") -for (transpose, transpose_before_cast, second_matmul, add_products, cast_inputs) in list(itertools.product([0, 1, 2], [False, True], [False, True], [False, True], [False, True])): +for ( + transpose, + transpose_before_cast, + second_matmul, + add_products, + cast_inputs, +) in list(itertools.product([0, 1, 2], [False, True], [False, True], [False, True], [False, True])): if not transpose and transpose_before_cast: continue if not second_matmul and add_products: continue - gen_matmul_two_products("matmul_two_outputs", transpose, - transpose_before_cast, second_matmul, cast_inputs) - + gen_matmul_two_products( + "matmul_two_outputs", + transpose, + transpose_before_cast, + second_matmul, + cast_inputs, + ) + gen_bool_to_float16_cast("negative_test_case_bool_fp16_cast") gen_bool_to_float_cast("negative_test_case_bool_fp_cast") diff --git a/onnxruntime/test/testdata/transform/qdq_conv_gen.py b/onnxruntime/test/testdata/transform/qdq_conv_gen.py index 9d26b42e820b6..a8c4d64bb2999 100644 --- a/onnxruntime/test/testdata/transform/qdq_conv_gen.py +++ b/onnxruntime/test/testdata/transform/qdq_conv_gen.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper # Generate a basic QDQ Conv model with `num_convs` Conv nodes and their surrounding DQ/Q nodes @@ -11,45 +10,73 @@ def GenerateModel(model_path, num_convs): outputs = [] for i in range(num_convs): + def name(base): return f"{base}_{i}" - nodes.extend([ - helper.make_node("DequantizeLinear", [name("X"), name("Scale"), name("Zero_point_uint8")], [name("input_DQ")], name("input_DQ")), - helper.make_node("DequantizeLinear", [name("W"), name("Scale"), name("Zero_point_uint8")], [name("conv_weight_DQ")], name("conv_weight_DQ")), - helper.make_node("DequantizeLinear", [name("Bias"), name("Scale"), name("Zero_point_int32")], [name("conv_bias_DQ")], name("conv_bias_DQ")), - helper.make_node("Conv", [name("input_DQ"), name("conv_weight_DQ"), name("conv_bias_DQ")], [name("conv_output")], name("conv")), - helper.make_node("QuantizeLinear", [name("conv_output"), name("Scale"), name("Zero_point_uint8")], [name("Y")], name("output_Q")), - ]) - - initializers.extend([ - helper.make_tensor(name('Scale'), TensorProto.FLOAT, [1], [256.0]), - helper.make_tensor(name('Zero_point_uint8'), TensorProto.UINT8, [1], [0]), - helper.make_tensor(name('Zero_point_int32'), TensorProto.INT32, [1], [0]), - helper.make_tensor(name('W'), TensorProto.UINT8, [1, 1, 3, 3], [128] * 9), - helper.make_tensor(name('Bias'), TensorProto.INT32, [1], [64]), - ]) - - inputs.extend([ - helper.make_tensor_value_info(name('X'), TensorProto.UINT8, [1, 1, 5, 5]), - ]) - - outputs.extend([ - helper.make_tensor_value_info(name('Y'), TensorProto.UINT8, [1, 1, 3, 3]), - ]) - - graph = helper.make_graph( - nodes, - f"QDQ_Conv_x_{num_convs}", - inputs, - outputs, - initializers - ) + nodes.extend( + [ + helper.make_node( + "DequantizeLinear", + [name("X"), name("Scale"), name("Zero_point_uint8")], + [name("input_DQ")], + name("input_DQ"), + ), + helper.make_node( + "DequantizeLinear", + [name("W"), name("Scale"), name("Zero_point_uint8")], + [name("conv_weight_DQ")], + name("conv_weight_DQ"), + ), + helper.make_node( + "DequantizeLinear", + [name("Bias"), name("Scale"), name("Zero_point_int32")], + [name("conv_bias_DQ")], + name("conv_bias_DQ"), + ), + helper.make_node( + "Conv", + [name("input_DQ"), name("conv_weight_DQ"), name("conv_bias_DQ")], + [name("conv_output")], + name("conv"), + ), + helper.make_node( + "QuantizeLinear", + [name("conv_output"), name("Scale"), name("Zero_point_uint8")], + [name("Y")], + name("output_Q"), + ), + ] + ) + + initializers.extend( + [ + helper.make_tensor(name("Scale"), TensorProto.FLOAT, [1], [256.0]), + helper.make_tensor(name("Zero_point_uint8"), TensorProto.UINT8, [1], [0]), + helper.make_tensor(name("Zero_point_int32"), TensorProto.INT32, [1], [0]), + helper.make_tensor(name("W"), TensorProto.UINT8, [1, 1, 3, 3], [128] * 9), + helper.make_tensor(name("Bias"), TensorProto.INT32, [1], [64]), + ] + ) + + inputs.extend( + [ + helper.make_tensor_value_info(name("X"), TensorProto.UINT8, [1, 1, 5, 5]), + ] + ) + + outputs.extend( + [ + helper.make_tensor_value_info(name("Y"), TensorProto.UINT8, [1, 1, 3, 3]), + ] + ) + + graph = helper.make_graph(nodes, f"QDQ_Conv_x_{num_convs}", inputs, outputs, initializers) model = helper.make_model(graph) onnx.save(model, model_path) if __name__ == "__main__": - GenerateModel('qdq_conv.onnx', 1) - GenerateModel('runtime_optimization/qdq_convs.onnx', 3) + GenerateModel("qdq_conv.onnx", 1) + GenerateModel("runtime_optimization/qdq_convs.onnx", 3) diff --git a/onnxruntime/test/testdata/transform/runtime_optimization/add_with_surrounding_identities_gen.py b/onnxruntime/test/testdata/transform/runtime_optimization/add_with_surrounding_identities_gen.py index 7b64268b86e9d..1b81315cd80a3 100644 --- a/onnxruntime/test/testdata/transform/runtime_optimization/add_with_surrounding_identities_gen.py +++ b/onnxruntime/test/testdata/transform/runtime_optimization/add_with_surrounding_identities_gen.py @@ -1,6 +1,5 @@ import onnx -from onnx import helper -from onnx import TensorProto +from onnx import TensorProto, helper graph = helper.make_graph( [ # nodes @@ -11,14 +10,14 @@ ], "AddWithSurroundingIdentities", # name [ # inputs - helper.make_tensor_value_info('A', TensorProto.FLOAT, [1]), - helper.make_tensor_value_info('B', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("A", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("B", TensorProto.FLOAT, [1]), ], [ # outputs - helper.make_tensor_value_info('C', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("C", TensorProto.FLOAT, [1]), ], - [ # initializers - ]) + [], # initializers +) model = helper.make_model(graph) -onnx.save(model, r'add_with_surrounding_identities.onnx') +onnx.save(model, r"add_with_surrounding_identities.onnx") diff --git a/orttraining/orttraining/eager/opgen/onnxgen.py b/orttraining/orttraining/eager/opgen/onnxgen.py index 5579750a75548..87c4036f48b0a 100755 --- a/orttraining/orttraining/eager/opgen/onnxgen.py +++ b/orttraining/orttraining/eager/opgen/onnxgen.py @@ -7,96 +7,98 @@ from sys import argv from onnx import defs -out_file = path.join( - path.dirname(path.realpath(__file__)), - 'opgen', - 'onnxops.py') +out_file = path.join(path.dirname(path.realpath(__file__)), "opgen", "onnxops.py") onnx_ops = {} for schema in defs.get_all_schemas_with_history(): - key = schema.name.lower() - if schema.deprecated: - continue - if key not in onnx_ops or \ - onnx_ops[key].since_version < schema.since_version: - onnx_ops[key] = schema + key = schema.name.lower() + if schema.deprecated: + continue + if key not in onnx_ops or onnx_ops[key].since_version < schema.since_version: + onnx_ops[key] = schema + def convert_to_aten_type(onnx_type_strs): - type_map = {'tensor(float16)' : 'at::kHalf', - 'tensor(float)' : 'at::kFloat', - 'tensor(double)' : 'at::kDouble', - 'tensor(bfloat16)' : 'at::kBFloat16', - 'tensor(int32)' : 'at::kInt', - 'tensor(int16)' : 'at::kShort', - 'tensor(int8)' : 'at::kByte', - 'tensor(int64)' : 'at::kLong', - 'tensor(bool)' : 'at::kBool', - } - result = set({}) - for onnx_type in onnx_type_strs: - # ONNX has more types, like tensor(string), ignore those types at this momemnt - if onnx_type in type_map: - result.add(type_map[onnx_type]) - return result - -with open(out_file, 'wt') as fp: - def write(s): fp.write(s) - def writeline(s = ''): fp.write(s + '\n') - - writeline(f'# AUTO-GENERATED CODE! - DO NOT EDIT!') - writeline(f'# $ python {" ".join(argv)}') - writeline() - - writeline('from opgen.generator import ONNXAttr, ONNXOp, AttrType') - writeline() - - for op_name, schema in sorted(onnx_ops.items()): - writeline(f'class {schema.name}(ONNXOp):') - writeline(f' """') - doc_str = schema.doc.strip('\r\n') - for doc_line in str.splitlines(doc_str, keepends=False): - writeline(f' {doc_line}') - writeline(f' """') + type_map = { + "tensor(float16)": "at::kHalf", + "tensor(float)": "at::kFloat", + "tensor(double)": "at::kDouble", + "tensor(bfloat16)": "at::kBFloat16", + "tensor(int32)": "at::kInt", + "tensor(int16)": "at::kShort", + "tensor(int8)": "at::kByte", + "tensor(int64)": "at::kLong", + "tensor(bool)": "at::kBool", + } + result = set({}) + for onnx_type in onnx_type_strs: + # ONNX has more types, like tensor(string), ignore those types at this momemnt + if onnx_type in type_map: + result.add(type_map[onnx_type]) + return result + + +with open(out_file, "wt") as fp: + + def write(s): + fp.write(s) + + def writeline(s=""): + fp.write(s + "\n") + + writeline(f"# AUTO-GENERATED CODE! - DO NOT EDIT!") + writeline(f'# $ python {" ".join(argv)}') writeline() - write(' def __init__(self') - - for input in schema.inputs: - write(f', {input.name}') - - if len(schema.attributes) > 0: - writeline(',') - for i, (k, attr) in enumerate(schema.attributes.items()): - write(f' {attr.name}=None') - if i < len(schema.attributes) - 1: - writeline(', ') - - writeline('):') - write(f' super().__init__(\'{schema.name}\', {len(schema.outputs)}') - writeline(',') - write(' ') - input_types = [] - for input in schema.inputs: - input_types.append(convert_to_aten_type(input.types)) - write(str(input_types)) - if len(schema.inputs) > 0: - writeline(',') - input_names = ','.join([input.name for input in schema.inputs]) - write(f' {input_names}') - - - if len(schema.attributes) > 0: - writeline(',') - for i, (k, attr) in enumerate(schema.attributes.items()): - write(f' {attr.name}=ONNXAttr({attr.name}, {attr.type})') - if i < len(schema.attributes) - 1: - writeline(', ') - - writeline(')') + + writeline("from opgen.generator import ONNXAttr, ONNXOp, AttrType") writeline() - writeline('onnx_ops = {') - for i, (op_name, schema) in enumerate(onnx_ops.items()): - writeline(f' \'{op_name}\': {schema.name},') - write('}') + for op_name, schema in sorted(onnx_ops.items()): + writeline(f"class {schema.name}(ONNXOp):") + writeline(f' """') + doc_str = schema.doc.strip("\r\n") + for doc_line in str.splitlines(doc_str, keepends=False): + writeline(f" {doc_line}") + writeline(f' """') + writeline() + write(" def __init__(self") + + for input in schema.inputs: + write(f", {input.name}") + + if len(schema.attributes) > 0: + writeline(",") + for i, (k, attr) in enumerate(schema.attributes.items()): + write(f" {attr.name}=None") + if i < len(schema.attributes) - 1: + writeline(", ") + + writeline("):") + write(f" super().__init__('{schema.name}', {len(schema.outputs)}") + writeline(",") + write(" ") + input_types = [] + for input in schema.inputs: + input_types.append(convert_to_aten_type(input.types)) + write(str(input_types)) + if len(schema.inputs) > 0: + writeline(",") + input_names = ",".join([input.name for input in schema.inputs]) + write(f" {input_names}") + + if len(schema.attributes) > 0: + writeline(",") + for i, (k, attr) in enumerate(schema.attributes.items()): + write(f" {attr.name}=ONNXAttr({attr.name}, {attr.type})") + if i < len(schema.attributes) - 1: + writeline(", ") + + writeline(")") + writeline() + + writeline("onnx_ops = {") + for i, (op_name, schema) in enumerate(onnx_ops.items()): + writeline(f" '{op_name}': {schema.name},") + write("}") -print(f'File updated: {out_file}') \ No newline at end of file +print(f"File updated: {out_file}") diff --git a/orttraining/orttraining/eager/opgen/opgen.py b/orttraining/orttraining/eager/opgen/opgen.py index e0f267a9fb6de..77192035dfe87 100755 --- a/orttraining/orttraining/eager/opgen/opgen.py +++ b/orttraining/orttraining/eager/opgen/opgen.py @@ -11,13 +11,15 @@ from importlib.machinery import SourceFileLoader import argparse -parser = argparse.ArgumentParser(description='Generate ORT ATen operations') -parser.add_argument('--ops_module', type=str, - help='Python module containing the Onnx Operation signature and list of ops to map') -parser.add_argument('--output_file', default=None, type=str, help='Output file [default to std out]') -parser.add_argument('--header_file', type=str, - help='Header file which contains ATen / Pytorch operation signature') -parser.add_argument('--custom_ops', action='store_true', help='Whether we are generating code for custom ops or native operation') +parser = argparse.ArgumentParser(description="Generate ORT ATen operations") +parser.add_argument( + "--ops_module", type=str, help="Python module containing the Onnx Operation signature and list of ops to map" +) +parser.add_argument("--output_file", default=None, type=str, help="Output file [default to std out]") +parser.add_argument("--header_file", type=str, help="Header file which contains ATen / Pytorch operation signature") +parser.add_argument( + "--custom_ops", action="store_true", help="Whether we are generating code for custom ops or native operation" +) args = parser.parse_args() ops_module = SourceFileLoader("opgen.customop", args.ops_module).load_module() @@ -28,7 +30,7 @@ print(f"INFO: Using RegistrationDeclarations from: {regdecs_path}") output = sys.stdout if args.output_file: - output = open(args.output_file, 'wt') + output = open(args.output_file, "wt") with CPPParser(regdecs_path) as parser, SourceWriter(output) as writer: - ortgen.run(parser, writer) + ortgen.run(parser, writer) diff --git a/orttraining/orttraining/eager/opgen/opgen/__init__.py b/orttraining/orttraining/eager/opgen/opgen/__init__.py index 6fcf0de4918d2..5b7f7a925cc05 100644 --- a/orttraining/orttraining/eager/opgen/opgen/__init__.py +++ b/orttraining/orttraining/eager/opgen/opgen/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. \ No newline at end of file +# Licensed under the MIT License. diff --git a/orttraining/orttraining/eager/opgen/opgen/ast.py b/orttraining/orttraining/eager/opgen/opgen/ast.py index 672d30a57b8a8..140822b72b67e 100644 --- a/orttraining/orttraining/eager/opgen/opgen/ast.py +++ b/orttraining/orttraining/eager/opgen/opgen/ast.py @@ -5,341 +5,412 @@ from typing import TextIO, List, Union from opgen.lexer import Token + class Node(object): - def __init__(self): - self.tokens = [] + def __init__(self): + self.tokens = [] + + def write(self, writer: TextIO): + raise NotImplementedError(self.write) + + def __str__(self): + writer = io.StringIO() + self.write(writer) + return writer.getvalue() - def write(self, writer: TextIO): - raise NotImplementedError(self.write) - def __str__(self): - writer = io.StringIO() - self.write(writer) - return writer.getvalue() +# region Syntax List -#region Syntax List class SyntaxListMember(Node): - def __init__(self, member: Node, trailing_separator: Token = None): - super().__init__() - self.member = member - self.trailing_separator = trailing_separator + def __init__(self, member: Node, trailing_separator: Token = None): + super().__init__() + self.member = member + self.trailing_separator = trailing_separator + + def write(self, writer: TextIO): + self.member.write(writer) + if self.trailing_separator: + writer.write(self.trailing_separator.value) + writer.write(" ") - def write(self, writer: TextIO): - self.member.write(writer) - if self.trailing_separator: - writer.write(self.trailing_separator.value) - writer.write(" ") class SyntaxList(Node): - open_token: Token - members: List[SyntaxListMember] - close_token: Token - - def __init__(self): - super().__init__() - self.open_token = None - self.members = [] - self.close_token = None - - def __iter__(self): - return self.members.__iter__() - - def __getitem__(self, key): - return self.members.__getitem__(key) - - def __len__(self): - return len(self.members) - - def append(self, member: Node, trailing_separator: Token): - self.members.append(SyntaxListMember(member, trailing_separator)) - - def write(self, writer: TextIO): - if self.open_token: - writer.write(self.open_token.value) - for member in self.members: - member.write(writer) - if self.close_token: - writer.write(self.close_token.value) - -#endregion - -#region Expressions - -class Expression(Node): pass + open_token: Token + members: List[SyntaxListMember] + close_token: Token + + def __init__(self): + super().__init__() + self.open_token = None + self.members = [] + self.close_token = None + + def __iter__(self): + return self.members.__iter__() + + def __getitem__(self, key): + return self.members.__getitem__(key) + + def __len__(self): + return len(self.members) + + def append(self, member: Node, trailing_separator: Token): + self.members.append(SyntaxListMember(member, trailing_separator)) + + def write(self, writer: TextIO): + if self.open_token: + writer.write(self.open_token.value) + for member in self.members: + member.write(writer) + if self.close_token: + writer.write(self.close_token.value) + + +# endregion + +# region Expressions + + +class Expression(Node): + pass + class LiteralExpression(Expression): - def __init__(self, token: Token): - super().__init__() - self.token = token + def __init__(self, token: Token): + super().__init__() + self.token = token + + def write(self, writer: TextIO): + writer.write(self.token.value) - def write(self, writer: TextIO): - writer.write(self.token.value) class ArrayExpression(Expression): - def __init__(self, elements: SyntaxList): - self.elements = elements + def __init__(self, elements: SyntaxList): + self.elements = elements + -#endregion +# endregion + +# region Types -#region Types class Type(Node): - def _desugar_self(self) -> "Type": - return self + def _desugar_self(self) -> "Type": + return self + + def desugar(self) -> "Type": + desugared = self + while True: + _desugared = desugared._desugar_self() + if _desugared == desugared: + return desugared + desugared = _desugared - def desugar(self) -> "Type": - desugared = self - while True: - _desugared = desugared._desugar_self() - if _desugared == desugared: - return desugared - desugared = _desugared class ExpressionType(Type): - def __init__(self, expression: Expression): - super().__init__() - self.expression = expression + def __init__(self, expression: Expression): + super().__init__() + self.expression = expression + + def write(self, writer: TextIO): + self.expression.write(writer) - def write(self, writer: TextIO): - self.expression.write(writer) class ConcreteType(Type): - def __init__(self, identifier_tokens: Union[Token, List[Token]]): - super().__init__() - if isinstance(identifier_tokens, Token): - self.identifier_tokens = [identifier_tokens] - else: - self.identifier_tokens = identifier_tokens + def __init__(self, identifier_tokens: Union[Token, List[Token]]): + super().__init__() + if isinstance(identifier_tokens, Token): + self.identifier_tokens = [identifier_tokens] + else: + self.identifier_tokens = identifier_tokens + + def write(self, writer: TextIO): + for identifier_token in self.identifier_tokens: + writer.write(identifier_token.value) - def write(self, writer: TextIO): - for identifier_token in self.identifier_tokens: - writer.write(identifier_token.value) class ConstType(Type): - def __init__(self, const_token: Token, inner_type: Type): - super().__init__() - self.const_token = const_token - self.inner_type = inner_type + def __init__(self, const_token: Token, inner_type: Type): + super().__init__() + self.const_token = const_token + self.inner_type = inner_type + + def write(self, writer: TextIO): + writer.write(self.const_token.value) + writer.write(" ") + self.inner_type.write(writer) - def write(self, writer: TextIO): - writer.write(self.const_token.value) - writer.write(" ") - self.inner_type.write(writer) + def _desugar_self(self) -> Type: + return self.inner_type - def _desugar_self(self) -> Type: - return self.inner_type class ReferenceType(Type): - def __init__(self, inner_type: Type, reference_token: Token): - super().__init__() - self.inner_type = inner_type - self.reference_token = reference_token + def __init__(self, inner_type: Type, reference_token: Token): + super().__init__() + self.inner_type = inner_type + self.reference_token = reference_token - def write(self, writer: TextIO): - self.inner_type.write(writer) - writer.write(self.reference_token.value) + def write(self, writer: TextIO): + self.inner_type.write(writer) + writer.write(self.reference_token.value) + + def _desugar_self(self) -> Type: + return self.inner_type - def _desugar_self(self) -> Type: - return self.inner_type class ModifiedType(Type): - def __init__(self, base_type: Type): - super().__init__() - self.base_type = base_type + def __init__(self, base_type: Type): + super().__init__() + self.base_type = base_type + + def _desugar_self(self) -> Type: + return self.base_type - def _desugar_self(self) -> Type: - return self.base_type class OptionalType(ModifiedType): - def __init__(self, base_type: Type, token: Token): - super().__init__(base_type) - self.token = token + def __init__(self, base_type: Type, token: Token): + super().__init__(base_type) + self.token = token + + def write(self, writer: TextIO): + self.base_type.write(writer) + writer.write(self.token.value) - def write(self, writer: TextIO): - self.base_type.write(writer) - writer.write(self.token.value) class ArrayType(ModifiedType): - def __init__( - self, - base_type: Type, - open_token: Token, - length_token: Token, - close_token: Token): - super().__init__(base_type) - self.open_token = open_token - self.length_token = length_token - self.close_token = close_token - - def write(self, writer: TextIO): - self.base_type.write(writer) - writer.write(self.open_token.value) - if self.length_token: - writer.write(self.length_token.value) - writer.write(self.close_token.value) + def __init__(self, base_type: Type, open_token: Token, length_token: Token, close_token: Token): + super().__init__(base_type) + self.open_token = open_token + self.length_token = length_token + self.close_token = close_token + + def write(self, writer: TextIO): + self.base_type.write(writer) + writer.write(self.open_token.value) + if self.length_token: + writer.write(self.length_token.value) + writer.write(self.close_token.value) + class TemplateType(Type): - def __init__( - self, - identifier_tokens: Union[Token, List[Token]], - type_arguments: SyntaxList): - super().__init__() - if isinstance(identifier_tokens, Token): - self.identifier_tokens = [identifier_tokens] - else: - self.identifier_tokens = identifier_tokens - self.type_arguments = type_arguments - - def write(self, writer: TextIO): - for identifier_token in self.identifier_tokens: - writer.write(identifier_token.value) - self.type_arguments.write(writer) + def __init__(self, identifier_tokens: Union[Token, List[Token]], type_arguments: SyntaxList): + super().__init__() + if isinstance(identifier_tokens, Token): + self.identifier_tokens = [identifier_tokens] + else: + self.identifier_tokens = identifier_tokens + self.type_arguments = type_arguments + + def write(self, writer: TextIO): + for identifier_token in self.identifier_tokens: + writer.write(identifier_token.value) + self.type_arguments.write(writer) + class TupleMemberType(Type): - def __init__(self, element_type: Type, element_name: Token): - super().__init__() - self.element_type = element_type - self.element_name = element_name + def __init__(self, element_type: Type, element_name: Token): + super().__init__() + self.element_type = element_type + self.element_name = element_name - def write(self, writer: TextIO): - self.element_type.write(writer) + def write(self, writer: TextIO): + self.element_type.write(writer) + + def _desugar_self(self) -> Type: + return self.element_name - def _desugar_self(self) -> Type: - return self.element_name class TupleType(Type): - def __init__(self, elements: SyntaxList): - super().__init__() - self.elements = elements + def __init__(self, elements: SyntaxList): + super().__init__() + self.elements = elements + + def write(self, writer: TextIO): + self.elements.write(writer) - def write(self, writer: TextIO): - self.elements.write(writer) class AliasInfo(Node): - before_set: List[str] - after_set: List[str] - tokens: List[Token] - - def __init__(self): - super().__init__() - self.before_set = [] - self.after_set = [] - self.tokens = [] - self.is_writable = False - - def __str__(self): - buffer = io.StringIO() - self.write(buffer) - return buffer.getvalue() - - def __eq__(self, obj): - return isinstance(obj, AliasInfo) and str(self) == str(obj) - - def __ne__(self, obj): - return not self.__eq__(obj) - - def write(self, writer: TextIO): - writer.write("(") - writer.write("|".join(self.before_set)) - if self.is_writable: - writer.write("!") - writer.write(" -> ") - writer.write("|".join(self.after_set)) - writer.write(")") + before_set: List[str] + after_set: List[str] + tokens: List[Token] + + def __init__(self): + super().__init__() + self.before_set = [] + self.after_set = [] + self.tokens = [] + self.is_writable = False + + def __str__(self): + buffer = io.StringIO() + self.write(buffer) + return buffer.getvalue() + + def __eq__(self, obj): + return isinstance(obj, AliasInfo) and str(self) == str(obj) + + def __ne__(self, obj): + return not self.__eq__(obj) + + def write(self, writer: TextIO): + writer.write("(") + writer.write("|".join(self.before_set)) + if self.is_writable: + writer.write("!") + writer.write(" -> ") + writer.write("|".join(self.after_set)) + writer.write(")") + class AliasInfoType(Type): - def __init__(self, inner_type: Type, alias_info: AliasInfo): - super().__init__() - self.inner_type = inner_type - self.alias_info = alias_info - self.inner_type.alias_info = alias_info + def __init__(self, inner_type: Type, alias_info: AliasInfo): + super().__init__() + self.inner_type = inner_type + self.alias_info = alias_info + self.inner_type.alias_info = alias_info - def write(self, writer: TextIO): - self.inner_type.write(writer) - self.alias_info.write(writer) + def write(self, writer: TextIO): + self.inner_type.write(writer) + self.alias_info.write(writer) + + def _desugar_self(self) -> Type: + return self.inner_type - def _desugar_self(self) -> Type: - return self.inner_type class KWArgsSentinelType(Type): - def __init__(self, token: Token): - super().__init__() - self.token = token - - def write(self, writer: TextIO): - writer.write(self.token.value) - -class TensorType(ConcreteType): pass -class IntType(ConcreteType): pass -class FloatType(ConcreteType): pass -class BoolType(ConcreteType): pass -class StrType(ConcreteType): pass -class ScalarType(ConcreteType): pass -class ScalarTypeType(ConcreteType): pass -class DimnameType(ConcreteType): pass -class GeneratorType(ConcreteType): pass -class TensorOptionsType(ConcreteType): pass -class LayoutType(ConcreteType): pass -class DeviceType(ConcreteType): pass -class MemoryFormatType(ConcreteType): pass -class QSchemeType(ConcreteType): pass -class StorageType(ConcreteType): pass -class ConstQuantizerPtrType(ConcreteType): pass -class StreamType(ConcreteType): pass - -#region Decls - -class Decl(Node): pass + def __init__(self, token: Token): + super().__init__() + self.token = token + + def write(self, writer: TextIO): + writer.write(self.token.value) + + +class TensorType(ConcreteType): + pass + + +class IntType(ConcreteType): + pass + + +class FloatType(ConcreteType): + pass + + +class BoolType(ConcreteType): + pass + + +class StrType(ConcreteType): + pass + + +class ScalarType(ConcreteType): + pass + + +class ScalarTypeType(ConcreteType): + pass + + +class DimnameType(ConcreteType): + pass + + +class GeneratorType(ConcreteType): + pass + + +class TensorOptionsType(ConcreteType): + pass + + +class LayoutType(ConcreteType): + pass + + +class DeviceType(ConcreteType): + pass + + +class MemoryFormatType(ConcreteType): + pass + + +class QSchemeType(ConcreteType): + pass + + +class StorageType(ConcreteType): + pass + + +class ConstQuantizerPtrType(ConcreteType): + pass + + +class StreamType(ConcreteType): + pass + + +# region Decls + + +class Decl(Node): + pass + class ParameterDecl(Decl): - def __init__( - self, - parameter_type: Type, - identifier: Token = None, - equals: Token = None, - default_value: Expression = None): - super().__init__() - self.parameter_type = parameter_type - self.identifier = identifier - self.equals = equals - self.default_value = default_value - - def write(self, writer: TextIO): - self.parameter_type.write(writer) - if self.identifier: - writer.write(" ") - writer.write(self.identifier.value) + def __init__( + self, parameter_type: Type, identifier: Token = None, equals: Token = None, default_value: Expression = None + ): + super().__init__() + self.parameter_type = parameter_type + self.identifier = identifier + self.equals = equals + self.default_value = default_value + + def write(self, writer: TextIO): + self.parameter_type.write(writer) + if self.identifier: + writer.write(" ") + writer.write(self.identifier.value) + class FunctionDecl(Decl): - def __init__( - self, - identifier: Token, - parameters: SyntaxList, - return_type: Type = None, - semicolon: Token = None, - arrow: Token = None): - super().__init__() - self.is_leaf = False - self.identifier = identifier - self.return_type = return_type - self.parameters = parameters - self.semicolon = semicolon - self.arrow = arrow - - def get_parameter(self, identifier: str) -> ParameterDecl: - for param in self.parameters: - id = param.member.identifier - if id and id.value == identifier: - return param.member - return None + def __init__( + self, + identifier: Token, + parameters: SyntaxList, + return_type: Type = None, + semicolon: Token = None, + arrow: Token = None, + ): + super().__init__() + self.is_leaf = False + self.identifier = identifier + self.return_type = return_type + self.parameters = parameters + self.semicolon = semicolon + self.arrow = arrow + + def get_parameter(self, identifier: str) -> ParameterDecl: + for param in self.parameters: + id = param.member.identifier + if id and id.value == identifier: + return param.member + return None + class TranslationUnitDecl(Decl): - def __init__(self, decls: List[FunctionDecl]): - super().__init__() - self.decls = decls + def __init__(self, decls: List[FunctionDecl]): + super().__init__() + self.decls = decls + + def __iter__(self): + return self.decls.__iter__() - def __iter__(self): - return self.decls.__iter__() -#endregion \ No newline at end of file +# endregion diff --git a/orttraining/orttraining/eager/opgen/opgen/atenops.py b/orttraining/orttraining/eager/opgen/opgen/atenops.py index a7360e4df3d37..5ff1990a56774 100644 --- a/orttraining/orttraining/eager/opgen/opgen/atenops.py +++ b/orttraining/orttraining/eager/opgen/opgen/atenops.py @@ -1,10 +1,11 @@ from copy import deepcopy -from opgen.generator import \ - ORTGen as ORTGen, \ - ONNXOp as ONNXOp, \ - SignatureOnly as SignatureOnly, \ - MakeTorchFallback as MakeTorchFallback +from opgen.generator import ( + ORTGen as ORTGen, + ONNXOp as ONNXOp, + SignatureOnly as SignatureOnly, + MakeTorchFallback as MakeTorchFallback, +) from opgen.onnxops import * @@ -13,103 +14,145 @@ TORCH_API_CHANGE_VERSION = "1.11.1" -kMSDomain = 'onnxruntime::kMSDomain' +kMSDomain = "onnxruntime::kMSDomain" + class ReluGrad(ONNXOp): - def __init__(self, dY, X): - super().__init__('ReluGrad', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], dY, X) - self.domain = kMSDomain + def __init__(self, dY, X): + super().__init__( + "ReluGrad", + 1, + [{"at::kHalf", "at::kFloat", "at::kBFloat16"}, {"at::kHalf", "at::kFloat", "at::kBFloat16"}], + dY, + X, + ) + self.domain = kMSDomain + class Gelu(ONNXOp): - def __init__(self, X): - super().__init__('Gelu', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], X) - self.domain = kMSDomain + def __init__(self, X): + super().__init__("Gelu", 1, [{"at::kHalf", "at::kFloat", "at::kBFloat16"}], X) + self.domain = kMSDomain + class GeluGrad(ONNXOp): - def __init__(self, dY, X): - super().__init__('GeluGrad', 1, [{'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], dY, X) - self.domain = kMSDomain + def __init__(self, dY, X): + super().__init__( + "GeluGrad", + 1, + [{"at::kHalf", "at::kFloat", "at::kBFloat16"}, {"at::kHalf", "at::kFloat", "at::kBFloat16"}], + dY, + X, + ) + self.domain = kMSDomain + ops = {} type_promotion_ops = [] for binary_op, onnx_op in { - 'add': Add('self', Mul('alpha', 'other')), - 'sub': Sub('self', Mul('alpha', 'other')), - 'mul': Mul('self', 'other'), - 'div': Div('self', 'other')}.items(): - for dtype in ['Tensor', 'Scalar']: - for variant in ['', '_']: - name = f'aten::{binary_op}{variant}.{dtype}' - if name not in ops: - ops[f'aten::{binary_op}{variant}.{dtype}'] = deepcopy(onnx_op) - type_promotion_ops.append(f'aten::{binary_op}{variant}.{dtype}') + "add": Add("self", Mul("alpha", "other")), + "sub": Sub("self", Mul("alpha", "other")), + "mul": Mul("self", "other"), + "div": Div("self", "other"), +}.items(): + for dtype in ["Tensor", "Scalar"]: + for variant in ["", "_"]: + name = f"aten::{binary_op}{variant}.{dtype}" + if name not in ops: + ops[f"aten::{binary_op}{variant}.{dtype}"] = deepcopy(onnx_op) + type_promotion_ops.append(f"aten::{binary_op}{variant}.{dtype}") for unary_op in [ - 'abs','acos','acosh', 'asinh', 'atanh', 'asin', 'atan', 'ceil', 'cos', - 'cosh', 'erf', 'exp', 'floor', 'isnan', 'log', 'reciprocal', 'neg', 'round', - 'relu', 'selu', 'sigmoid', 'sin', 'sinh', 'sqrt', 'tan', 'tanh', 'nonzero', - 'sign', 'hardsigmoid', 'isinf', 'det']: - aten_name = f'aten::{unary_op}' - onnx_op = onnx_ops[unary_op]('self') - ops[aten_name] = onnx_op - # produce the in-place variant as well for ops that support it - if unary_op not in ['isnan', 'nonzero', 'min', 'max', 'isinf', 'det']: - ops[f'{aten_name}_'] = onnx_op + "abs", + "acos", + "acosh", + "asinh", + "atanh", + "asin", + "atan", + "ceil", + "cos", + "cosh", + "erf", + "exp", + "floor", + "isnan", + "log", + "reciprocal", + "neg", + "round", + "relu", + "selu", + "sigmoid", + "sin", + "sinh", + "sqrt", + "tan", + "tanh", + "nonzero", + "sign", + "hardsigmoid", + "isinf", + "det", +]: + aten_name = f"aten::{unary_op}" + onnx_op = onnx_ops[unary_op]("self") + ops[aten_name] = onnx_op + # produce the in-place variant as well for ops that support it + if unary_op not in ["isnan", "nonzero", "min", "max", "isinf", "det"]: + ops[f"{aten_name}_"] = onnx_op hand_implemented = { - 'aten::empty.memory_format': SignatureOnly(), - 'aten::empty_strided': SignatureOnly(), - 'aten::zero_': SignatureOnly(), - 'aten::copy_': SignatureOnly(), - 'aten::_reshape_alias': SignatureOnly(), - 'aten::view': SignatureOnly(), - 'aten::_copy_from_and_resize' : SignatureOnly(), - 'aten::as_strided' : SignatureOnly(), - # manually implement Slice using stride and offset. - 'aten::slice.Tensor' : SignatureOnly(), - - 'aten::addmm': Gemm('mat1', 'mat2', 'self', alpha='alpha', beta='beta'), - 'aten::add_.Tensor': SignatureOnly(), - 'aten::t': Transpose('self'), - 'aten::mm': MatMul('self', 'mat2'), - 'aten::zeros_like': ConstantOfShape(Shape('self')), #the default constant is 0, so don't need to speicify attribute - - 'aten::sum.dim_IntList': ReduceSum('self', 'dim', keepdims='keepdim'), - 'aten::threshold_backward': ReluGrad('grad_output', 'self'), - - 'aten::fmod.Scalar': Mod('self', 'other', fmod=1), - 'aten::fmod.Tensor': Mod('self', 'other', fmod=1), - - 'aten::softshrink': Shrink('self', bias='lambd', lambd='lambd'), #yes, bias is set to 'lambd' - 'aten::hardshrink': Shrink('self', bias=0, lambd='lambd'), - 'aten::gelu' : Gelu('self'), - 'aten::max' : ReduceMax('self', keepdims=1), - 'aten::min' : ReduceMin('self', keepdims=1), - 'aten::_cat': Concat('tensors', 'dim'), - 'aten::fill_.Scalar': ConstantOfShape('self', value='value'), - - 'aten::ne.Scalar':MakeTorchFallback(), - 'aten::ne.Scalar_out': MakeTorchFallback(), - 'aten::ne.Tensor_out': MakeTorchFallback(), - 'aten::eq.Tensor': MakeTorchFallback(), - 'aten::eq.Tensor_out':MakeTorchFallback(), - 'aten::bitwise_and.Tensor_out' : MakeTorchFallback(), - 'aten::masked_select' : MakeTorchFallback(), - 'aten::_local_scalar_dense' : MakeTorchFallback(), - 'aten::gt.Scalar_out' : MakeTorchFallback(), + "aten::empty.memory_format": SignatureOnly(), + "aten::empty_strided": SignatureOnly(), + "aten::zero_": SignatureOnly(), + "aten::copy_": SignatureOnly(), + "aten::_reshape_alias": SignatureOnly(), + "aten::view": SignatureOnly(), + "aten::_copy_from_and_resize": SignatureOnly(), + "aten::as_strided": SignatureOnly(), + # manually implement Slice using stride and offset. + "aten::slice.Tensor": SignatureOnly(), + "aten::addmm": Gemm("mat1", "mat2", "self", alpha="alpha", beta="beta"), + "aten::add_.Tensor": SignatureOnly(), + "aten::t": Transpose("self"), + "aten::mm": MatMul("self", "mat2"), + "aten::zeros_like": ConstantOfShape( + Shape("self") + ), # the default constant is 0, so don't need to speicify attribute + "aten::sum.dim_IntList": ReduceSum("self", "dim", keepdims="keepdim"), + "aten::threshold_backward": ReluGrad("grad_output", "self"), + "aten::fmod.Scalar": Mod("self", "other", fmod=1), + "aten::fmod.Tensor": Mod("self", "other", fmod=1), + "aten::softshrink": Shrink("self", bias="lambd", lambd="lambd"), # yes, bias is set to 'lambd' + "aten::hardshrink": Shrink("self", bias=0, lambd="lambd"), + "aten::gelu": Gelu("self"), + "aten::max": ReduceMax("self", keepdims=1), + "aten::min": ReduceMin("self", keepdims=1), + "aten::_cat": Concat("tensors", "dim"), + "aten::fill_.Scalar": ConstantOfShape("self", value="value"), + "aten::ne.Scalar": MakeTorchFallback(), + "aten::ne.Scalar_out": MakeTorchFallback(), + "aten::ne.Tensor_out": MakeTorchFallback(), + "aten::eq.Tensor": MakeTorchFallback(), + "aten::eq.Tensor_out": MakeTorchFallback(), + "aten::bitwise_and.Tensor_out": MakeTorchFallback(), + "aten::masked_select": MakeTorchFallback(), + "aten::_local_scalar_dense": MakeTorchFallback(), + "aten::gt.Scalar_out": MakeTorchFallback(), } # Signature of gelu_backward was changed in this commit id 983ba5e585485ed61a0c0012ef6944f5685e3d97 and PR 61439 # This is done to make sure it is backward and future compatible if version.parse(torch.__version__) < version.parse(TORCH_API_CHANGE_VERSION): - hand_implemented['aten::gelu_backward'] = GeluGrad('grad', 'self') + hand_implemented["aten::gelu_backward"] = GeluGrad("grad", "self") else: - hand_implemented['aten::gelu_backward'] = GeluGrad('grad_output', 'self') + hand_implemented["aten::gelu_backward"] = GeluGrad("grad_output", "self") -ops = {**ops, **hand_implemented} +ops = {**ops, **hand_implemented} # TODO: this is a temporary allowlist for ops need type promotion # Need to enhance the support for onnx type constrains to automatically # resolve whether the op need type promotion. # Will remove this list in the future. -type_promotion_ops = (*type_promotion_ops, 'aten::gelu_backward') +type_promotion_ops = (*type_promotion_ops, "aten::gelu_backward") diff --git a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py index d2631277cc541..4ba9d0af6baa7 100644 --- a/orttraining/orttraining/eager/opgen/opgen/custom_ops.py +++ b/orttraining/orttraining/eager/opgen/opgen/custom_ops.py @@ -2,17 +2,18 @@ from opgen.generator import AttrType, ONNXAttr -from opgen.generator import \ - ORTGen as ORTGen, \ - ONNXOp as ONNXOp, \ - SignatureOnly as SignatureOnly, \ - MakeTorchFallback as MakeTorchFallback +from opgen.generator import ( + ORTGen as ORTGen, + ONNXOp as ONNXOp, + SignatureOnly as SignatureOnly, + MakeTorchFallback as MakeTorchFallback, +) from opgen.onnxops import * ops = { - 'gemm': Gemm('A', 'B', 'C', 'alpha', 'beta', 'transA', 'transB'), - 'batchnorm_inplace': BatchNormalization('X', 'scale', 'B', 'input_mean', 'input_var', 'epsilon', 'momentum', 1) + "gemm": Gemm("A", "B", "C", "alpha", "beta", "transA", "transB"), + "batchnorm_inplace": BatchNormalization("X", "scale", "B", "input_mean", "input_var", "epsilon", "momentum", 1), } type_promotion_ops = {} diff --git a/orttraining/orttraining/eager/opgen/opgen/generator.py b/orttraining/orttraining/eager/opgen/opgen/generator.py index 6bfc75434485d..fa4532013de58 100644 --- a/orttraining/orttraining/eager/opgen/opgen/generator.py +++ b/orttraining/orttraining/eager/opgen/opgen/generator.py @@ -11,622 +11,643 @@ import opgen.ast as ast import opgen.writer as writer + class Outputs: - def __init__(self, count: int): - self.count = count - self.name = None + def __init__(self, count: int): + self.count = count + self.name = None + + def __str__(self): + return self.name if self.name else f"" - def __str__(self): - return self.name if self.name else f'' class AttrType: - FLOAT = 'at::ScalarType::Float' - FLOATS = '' - INT = 'at::ScalarType::Int' - INTS = '' - STRING = 'const char*' - STRINGS = '' - TENSOR = 'at::Tensor' - LONG = 'at::ScalarType::Long' + FLOAT = "at::ScalarType::Float" + FLOATS = "" + INT = "at::ScalarType::Int" + INTS = "" + STRING = "const char*" + STRINGS = "" + TENSOR = "at::Tensor" + LONG = "at::ScalarType::Long" + class ONNXAttr: - def __init__(self, value, type: AttrType=None): - self.value = value - self.type = type + def __init__(self, value, type: AttrType = None): + self.value = value + self.type = type + class ONNXOpEvalContext: - ops: List['ONNXOp'] + ops: List["ONNXOp"] - def __init__(self): - self.ops = [] + def __init__(self): + self.ops = [] + + def prepare_outputs(self): + for i, op in enumerate(self.ops): + op.outputs.name = f"ort_outputs_{i}_{op.name}" - def prepare_outputs(self): - for i, op in enumerate(self.ops): - op.outputs.name = f'ort_outputs_{i}_{op.name}' class ONNXOp: - def __init__(self, - name: str, - outputs: int, - input_types: List, - *inputs: Union[str, Outputs], - **attributes: Optional[Union[str, Outputs]]): - self.name = name - self.outputs = Outputs(outputs) - self.inputs = inputs - self.attributes = attributes - self.domain = None - self.input_types = input_types + def __init__( + self, + name: str, + outputs: int, + input_types: List, + *inputs: Union[str, Outputs], + **attributes: Optional[Union[str, Outputs]], + ): + self.name = name + self.outputs = Outputs(outputs) + self.inputs = inputs + self.attributes = attributes + self.domain = None + self.input_types = input_types + + def eval(self, ctx: ONNXOpEvalContext): + evaluated_inputs = [] - def eval(self, ctx: ONNXOpEvalContext): - evaluated_inputs = [] + for i in self.inputs: + if isinstance(i, ONNXOp): + i = i.eval(ctx) + evaluated_inputs.append(i) - for i in self.inputs: - if isinstance(i, ONNXOp): - i = i.eval(ctx) - evaluated_inputs.append(i) + self.inputs = evaluated_inputs - self.inputs = evaluated_inputs + ctx.ops.append(self) - ctx.ops.append(self) + return self.outputs - return self.outputs class SignatureOnly(ONNXOp): - def __init__(self): super().__init__(None, 0, []) + def __init__(self): + super().__init__(None, 0, []) + class MakeTorchFallback(ONNXOp): - def __init__(self): super().__init__(None, 0, []) + def __init__(self): + super().__init__(None, 0, []) + class FunctionGenerationError(NotImplementedError): - def __init__(self, cpp_func: ast.FunctionDecl, message: str): - super().__init__(f'{message} ({cpp_func.identifier})') + def __init__(self, cpp_func: ast.FunctionDecl, message: str): + super().__init__(f"{message} ({cpp_func.identifier})") + class MappedOpFunction: - def __init__( - self, - op_namespace: str, - mapped_op_name: str, - onnx_op: ONNXOp, - cpp_func: ast.FunctionDecl, - signature_only: bool, - make_torch_fallback: bool): - self.op_namespace = op_namespace - self.mapped_op_name = mapped_op_name - self.onnx_op = onnx_op - self.cpp_func = cpp_func - self.signature_only = signature_only - self.make_torch_fallback = make_torch_fallback + def __init__( + self, + op_namespace: str, + mapped_op_name: str, + onnx_op: ONNXOp, + cpp_func: ast.FunctionDecl, + signature_only: bool, + make_torch_fallback: bool, + ): + self.op_namespace = op_namespace + self.mapped_op_name = mapped_op_name + self.onnx_op = onnx_op + self.cpp_func = cpp_func + self.signature_only = signature_only + self.make_torch_fallback = make_torch_fallback + class ORTGen: - _mapped_ops: Dict[str, ONNXOp] - _custom_ops: bool - - def __init__( - self, - ops: Optional[Dict[str, ONNXOp]] = None, - custom_ops : bool = False, - type_promotion_ops : List = ()): - self._mapped_ops = {} - if ops: - self.register_many(ops) - self._custom_ops = custom_ops - self.type_promotion_ops = type_promotion_ops - - def register(self, aten_name: str, onnx_op: ONNXOp): - self._mapped_ops[aten_name] = onnx_op - - def register_many(self, ops: Dict[str, ONNXOp]): - for k, v in ops.items(): - self.register(k, v) - - def run(self, cpp_parser: parser.CPPParser, writer: writer.SourceWriter): - self._write_file_prelude(writer) - - generated_funcs = [] - current_ns = None - - for mapped_func in self._parse_mapped_function_decls(cpp_parser): - del self._mapped_ops[mapped_func.mapped_op_name] - generated_funcs.append(mapped_func) - - ns = mapped_func.op_namespace - if current_ns and current_ns != ns: + _mapped_ops: Dict[str, ONNXOp] + _custom_ops: bool + + def __init__( + self, ops: Optional[Dict[str, ONNXOp]] = None, custom_ops: bool = False, type_promotion_ops: List = () + ): + self._mapped_ops = {} + if ops: + self.register_many(ops) + self._custom_ops = custom_ops + self.type_promotion_ops = type_promotion_ops + + def register(self, aten_name: str, onnx_op: ONNXOp): + self._mapped_ops[aten_name] = onnx_op + + def register_many(self, ops: Dict[str, ONNXOp]): + for k, v in ops.items(): + self.register(k, v) + + def run(self, cpp_parser: parser.CPPParser, writer: writer.SourceWriter): + self._write_file_prelude(writer) + + generated_funcs = [] current_ns = None - writer.pop_namespace() - if ns != current_ns: - current_ns = ns + + for mapped_func in self._parse_mapped_function_decls(cpp_parser): + del self._mapped_ops[mapped_func.mapped_op_name] + generated_funcs.append(mapped_func) + + ns = mapped_func.op_namespace + if current_ns and current_ns != ns: + current_ns = None + writer.pop_namespace() + if ns != current_ns: + current_ns = ns + writer.writeline() + writer.push_namespace(ns) + + writer.writeline() + if mapped_func.cpp_func.torch_func: + writer.writeline(f"// {mapped_func.cpp_func.torch_func.torch_schema}") + + self._write_function_signature(writer, mapped_func.cpp_func) + if mapped_func.signature_only: + writer.writeline(";") + else: + writer.writeline(" {") + writer.push_indent() + self._write_function_body(writer, mapped_func) + writer.pop_indent() + writer.writeline("}") + + if current_ns: + current_ns = None + writer.pop_namespace() + + if not self._custom_ops: + self._write_function_registrations(writer, generated_funcs) + else: + self._write_custom_ops_registrations(writer, generated_funcs) + self._write_file_postlude(writer) + + if len(self._mapped_ops) > 0: + raise Exception( + "Torch operation(s) could not be parsed for mapping: " + + ", ".join([f"'{o}'" for o in self._mapped_ops.keys()]) + ) + + def _write_file_prelude(self, writer: writer.SourceWriter): + writer.writeline("// AUTO-GENERATED CODE! - DO NOT EDIT!") + writer.writeline(f'// $ python {" ".join(sys.argv)}') + writer.writeline() + writer.writeline('#include "python/onnxruntime_pybind_state_common.h"') + writer.writeline() + writer.writeline("#include ") + writer.writeline("#include ") writer.writeline() - writer.push_namespace(ns) + writer.writeline("#include ") + writer.writeline() + writer.writeline('#include "ort_tensor.h"') + writer.writeline('#include "ort_aten.h"') + writer.writeline('#include "ort_log.h"') + writer.writeline() + writer.push_namespace("torch_ort") + writer.push_namespace("eager") + writer.writeline() + writer.writeline("using namespace at;") + writer.writeline("using NodeAttributes = onnxruntime::NodeAttributes;") - writer.writeline() - if mapped_func.cpp_func.torch_func: - writer.writeline(f'// {mapped_func.cpp_func.torch_func.torch_schema}') + def _write_file_postlude(self, writer: writer.SourceWriter): + writer.pop_namespaces() - self._write_function_signature(writer, mapped_func.cpp_func) - if mapped_func.signature_only: - writer.writeline(';') - else: - writer.writeline(' {') + def _write_function_signature(self, writer: writer.SourceWriter, cpp_func: ast.FunctionDecl): + cpp_func.return_type.write(writer) + writer.write(f" {cpp_func.identifier.value}(") writer.push_indent() - self._write_function_body(writer, mapped_func) + for param_list_member in cpp_func.parameters: + writer.writeline() + if isinstance(param_list_member.member.parameter_type, ast.KWArgsSentinelType): + writer.write("// ") + param_list_member.write(writer) writer.pop_indent() - writer.writeline('}') - - if current_ns: - current_ns = None - writer.pop_namespace() - - if not self._custom_ops: - self._write_function_registrations(writer, generated_funcs) - else: - self._write_custom_ops_registrations(writer, generated_funcs) - self._write_file_postlude(writer) - - if len(self._mapped_ops) > 0: - raise Exception('Torch operation(s) could not be parsed for mapping: ' + \ - ', '.join([f'\'{o}\'' for o in self._mapped_ops.keys()])) - - def _write_file_prelude(self, writer: writer.SourceWriter): - writer.writeline('// AUTO-GENERATED CODE! - DO NOT EDIT!') - writer.writeline(f'// $ python {" ".join(sys.argv)}') - writer.writeline() - writer.writeline('#include "python/onnxruntime_pybind_state_common.h"') - writer.writeline() - writer.writeline('#include ') - writer.writeline('#include ') - writer.writeline() - writer.writeline('#include ') - writer.writeline() - writer.writeline('#include "ort_tensor.h"') - writer.writeline('#include "ort_aten.h"') - writer.writeline('#include "ort_log.h"') - writer.writeline() - writer.push_namespace('torch_ort') - writer.push_namespace('eager') - writer.writeline() - writer.writeline('using namespace at;') - writer.writeline('using NodeAttributes = onnxruntime::NodeAttributes;') - - def _write_file_postlude(self, writer: writer.SourceWriter): - writer.pop_namespaces() - - def _write_function_signature( - self, - writer: writer.SourceWriter, - cpp_func: ast.FunctionDecl): - cpp_func.return_type.write(writer) - writer.write(f' {cpp_func.identifier.value}(') - writer.push_indent() - for param_list_member in cpp_func.parameters: - writer.writeline() - if isinstance( - param_list_member.member.parameter_type, - ast.KWArgsSentinelType): - writer.write('// ') - param_list_member.write(writer) - writer.pop_indent() - writer.write(')') - - def _write_cpu_fall_back(self, - writer: writer.SourceWriter, - mapped_func: MappedOpFunction): - onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func - #return at::native::call_fallback_fn< - # &at::native::cpu_fallback, - # ATEN_OP(eq_Tensor)>::call(self, other); - writer.writeline('return native::call_fallback_fn<') - writer.push_indent() - writer.writeline('&native::cpu_fallback,') - writer.write('ATEN_OP(') - writer.write(cpp_func.identifier.value) - writer.write(')>::call(') - - params = ', '.join([p.member.identifier.value for p \ - in cpp_func.parameters if p.member.identifier]) - writer.write(params) - writer.writeline(');') - writer.pop_indent() - - - def _write_function_body( - self, - writer: writer.SourceWriter, - mapped_func: MappedOpFunction): - onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func - - assert(len(cpp_func.parameters) > 0) - - # Debug Logging - log_params = ', '.join([p.member.identifier.value for p \ - in cpp_func.parameters if p.member.identifier]) - writer.writeline(f'ORT_LOG_FN({log_params});') - writer.writeline() - - if mapped_func.make_torch_fallback: - return self._write_cpu_fall_back(writer, mapped_func) - - # Eval the outer ONNX op to produce a topologically ordered list of ops - ctx = ONNXOpEvalContext() - onnx_op.eval(ctx) - ctx.prepare_outputs() - - # Fetch the ORT invoker from an at::Tensor.device() - # FIXME: find the first at::Tensor param anywhere in the signature - # instead of simply the first parameter? - first_param = cpp_func.parameters[0].member - # Check if the first parameter is tensorlist and if yes it's size should be > 0 - if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList': - writer.write('assert(') - writer.write(first_param.identifier.value) - writer.writeline('.size()>0);') - - # generate the type check - need_type_check = False - if not self._custom_ops: - for onnx_op_index, onnx_op in enumerate(ctx.ops): - for op_input in onnx_op.inputs: - if not isinstance(op_input, Outputs): - need_type_check = True - break - if need_type_check: - writer.write('if (') - i = 0 - for onnx_op_index, onnx_op in enumerate(ctx.ops): - for idx, op_input in enumerate(onnx_op.inputs): - if isinstance(op_input, Outputs): - continue - writer.writeline(' || ' if i > 0 else '') - if i == 0: - writer.push_indent() - cpp_param = cpp_func.get_parameter(op_input) - supported_types = ','.join([type for type in onnx_op.input_types[idx]]) - writer.write('!IsSupportedType(%s, {%s})' % (cpp_param.identifier.value, supported_types)) - i += 1 - writer.writeline(') {') - self._write_cpu_fall_back(writer, mapped_func) - writer.pop_indent() - writer.writeline('}') - - if not isinstance( - first_param.parameter_type.desugar(), - ast.ConcreteType) or 'Tensor' not in first_param.parameter_type.desugar().identifier_tokens[0].value: - raise FunctionGenerationError( - cpp_func, - 'First parameter must be an at::Tensor') - - writer.write('auto& invoker = GetORTInvoker(') - writer.write(first_param.identifier.value) - if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList': - writer.write('[0]') - writer.writeline('.device());') - writer.writeline() - - # FIXME: warn if we have not consumed all torch parameters (either as - # an ORT input or ORT attribute). - - # Perform kernel fission on the ATen op to yield a chain of ORT Invokes - # e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y)) - - # whether need type promotion - need_type_promotion = False - if mapped_func.mapped_op_name in self.type_promotion_ops: - types_from_tensor = [] - types_from_scalar = [] - for onnx_op_index, onnx_op in enumerate(ctx.ops): - for op_input in onnx_op.inputs: - if isinstance(op_input, Outputs): - continue - cpp_param = cpp_func.get_parameter(op_input) - if cpp_param: - if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Tensor': - types_from_tensor.append(f'{op_input}.scalar_type()') - elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar': - types_from_scalar.append(f'{op_input}.type()') - if len(types_from_tensor) > 0 or len(types_from_scalar) > 0 : - need_type_promotion = True - writer.writeline('auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});' - % (','.join(types_from_tensor), ','.join(types_from_scalar))) + writer.write(")") + + def _write_cpu_fall_back(self, writer: writer.SourceWriter, mapped_func: MappedOpFunction): + onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func + # return at::native::call_fallback_fn< + # &at::native::cpu_fallback, + # ATEN_OP(eq_Tensor)>::call(self, other); + writer.writeline("return native::call_fallback_fn<") + writer.push_indent() + writer.writeline("&native::cpu_fallback,") + writer.write("ATEN_OP(") + writer.write(cpp_func.identifier.value) + writer.write(")>::call(") + + params = ", ".join([p.member.identifier.value for p in cpp_func.parameters if p.member.identifier]) + writer.write(params) + writer.writeline(");") + writer.pop_indent() + + def _write_function_body(self, writer: writer.SourceWriter, mapped_func: MappedOpFunction): + onnx_op, cpp_func = mapped_func.onnx_op, mapped_func.cpp_func + + assert len(cpp_func.parameters) > 0 + + # Debug Logging + log_params = ", ".join([p.member.identifier.value for p in cpp_func.parameters if p.member.identifier]) + writer.writeline(f"ORT_LOG_FN({log_params});") writer.writeline() - for onnx_op_index, onnx_op in enumerate(ctx.ops): - # Torch -> ORT inputs - for op_input in onnx_op.inputs: - if isinstance(op_input, Outputs): - continue - cpp_param = cpp_func.get_parameter(op_input) - writer.write(f'auto ort_input_{op_input} = ') - writer.writeline(f'create_ort_value(invoker, {op_input});') - if need_type_promotion: - type_func_str = 'type()' if cpp_param.parameter_type.desugar().identifier_tokens[0].value == 'Scalar' else 'scalar_type()' - writer.write(f'if ({op_input}.{type_func_str} != *promoted_type)') - writer.writeline('{') - writer.push_indent() - writer.writeline(f'ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);') - writer.pop_indent() - writer.writeline('}') - - # Torch kwargs -> ORT attributes - attrs = { k:v for k, v in onnx_op.attributes.items() if v and v.value } - if len(attrs) > 0: - attrs_arg = 'attrs' + if mapped_func.make_torch_fallback: + return self._write_cpu_fall_back(writer, mapped_func) + + # Eval the outer ONNX op to produce a topologically ordered list of ops + ctx = ONNXOpEvalContext() + onnx_op.eval(ctx) + ctx.prepare_outputs() + + # Fetch the ORT invoker from an at::Tensor.device() + # FIXME: find the first at::Tensor param anywhere in the signature + # instead of simply the first parameter? + first_param = cpp_func.parameters[0].member + # Check if the first parameter is tensorlist and if yes it's size should be > 0 + if first_param.parameter_type.desugar().identifier_tokens[0].value == "TensorList": + writer.write("assert(") + writer.write(first_param.identifier.value) + writer.writeline(".size()>0);") + + # generate the type check + need_type_check = False + if not self._custom_ops: + for onnx_op_index, onnx_op in enumerate(ctx.ops): + for op_input in onnx_op.inputs: + if not isinstance(op_input, Outputs): + need_type_check = True + break + if need_type_check: + writer.write("if (") + i = 0 + for onnx_op_index, onnx_op in enumerate(ctx.ops): + for idx, op_input in enumerate(onnx_op.inputs): + if isinstance(op_input, Outputs): + continue + writer.writeline(" || " if i > 0 else "") + if i == 0: + writer.push_indent() + cpp_param = cpp_func.get_parameter(op_input) + supported_types = ",".join([type for type in onnx_op.input_types[idx]]) + writer.write("!IsSupportedType(%s, {%s})" % (cpp_param.identifier.value, supported_types)) + i += 1 + writer.writeline(") {") + self._write_cpu_fall_back(writer, mapped_func) + writer.pop_indent() + writer.writeline("}") + + if ( + not isinstance(first_param.parameter_type.desugar(), ast.ConcreteType) + or "Tensor" not in first_param.parameter_type.desugar().identifier_tokens[0].value + ): + raise FunctionGenerationError(cpp_func, "First parameter must be an at::Tensor") + + writer.write("auto& invoker = GetORTInvoker(") + writer.write(first_param.identifier.value) + if first_param.parameter_type.desugar().identifier_tokens[0].value == "TensorList": + writer.write("[0]") + writer.writeline(".device());") writer.writeline() - writer.writeline(f'NodeAttributes {attrs_arg}({len(attrs)});') - - for attr_name, attr in attrs.items(): - writer.write(f'{attrs_arg}["{attr_name}"] = ') - writer.writeline('create_ort_attribute(') - writer.push_indent() - writer.write(f'"{attr_name}", {attr.value}') - if attr.type.startswith('at::ScalarType::'): - writer.write(f', {attr.type}') - elif attr.type == AttrType.TENSOR: - writer.write(f', true') - elif attr.type != AttrType.STRING: - raise FunctionGenerationError( - cpp_func, - f'Unsure how how to map ONNX op "{onnx_op.name}" attribute ' + - f'"{attr_name}" of type "{attr.type}" to a call to ' + - 'create_ort_attribute. Please teach generator.py.') - writer.writeline(');') - writer.pop_indent() - attrs_arg = f'&{attrs_arg}' - else: - attrs_arg = 'nullptr' - - # Outputs vector - writer.writeline() - writer.write(f'std::vector {onnx_op.outputs}') - writer.writeline(f'({onnx_op.outputs.count});') - - return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None - in_place_params = {} - - if return_info: - for input_index, op_input in enumerate(onnx_op.inputs): - if isinstance(op_input, Outputs): - continue - - # See if this input is aliased as an in-place tensor - cpp_param = cpp_func.get_parameter(op_input) - if cpp_param: - for torch_p in cpp_param.torch_param: - if isinstance(return_info, ast.TupleType): - for output_index, output_param in enumerate(return_info.elements): - assert isinstance(output_param.member, ast.TupleMemberType), "output_param.member must be of TupleMemberType" - output_alias = self._get_alias_info(output_param.member.element_type) - if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable: - writer.writeline(f'{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op.inputs[input_index]};') - in_place_params[output_index] = cpp_param.identifier.value - break - else: - output_alias = self._get_alias_info(return_info) - if output_alias and self._get_alias_info(torch_p) == output_alias and output_alias.is_writable: - writer.writeline(f'{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[input_index]};') - in_place_params[0] = cpp_param.identifier.value - break - - if len(in_place_params) != 0 and len(in_place_params) != (len(return_info.elements) if isinstance(return_info, ast.TupleType) else 1): - raise Exception(f'Cannot mix and match inplace with non-inplace parameters - function: {cpp_func.identifier.value} ' + - f'in_place_params={in_place_params}, return_elements={return_info.elements}') - - # Perform the invocation - writer.writeline() - if onnx_op_index == 0: - writer.write('auto ') - writer.writeline(f'status = invoker.Invoke("{onnx_op.name}", {{') - writer.push_indent() - for op_input in onnx_op.inputs: - if isinstance(op_input, Outputs): - if op_input.count != 1: - raise FunctionGenerationError( - cpp_func, - 'multiple outputs not supported') - op_input = f'{op_input}[0]' + + # FIXME: warn if we have not consumed all torch parameters (either as + # an ORT input or ORT attribute). + + # Perform kernel fission on the ATen op to yield a chain of ORT Invokes + # e.g. aten::add(x, y, α) -> onnx::Add(x, onnx::Mul(α, y)) + + # whether need type promotion + need_type_promotion = False + if mapped_func.mapped_op_name in self.type_promotion_ops: + types_from_tensor = [] + types_from_scalar = [] + for onnx_op_index, onnx_op in enumerate(ctx.ops): + for op_input in onnx_op.inputs: + if isinstance(op_input, Outputs): + continue + cpp_param = cpp_func.get_parameter(op_input) + if cpp_param: + if cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Tensor": + types_from_tensor.append(f"{op_input}.scalar_type()") + elif cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Scalar": + types_from_scalar.append(f"{op_input}.type()") + if len(types_from_tensor) > 0 or len(types_from_scalar) > 0: + need_type_promotion = True + writer.writeline( + "auto promoted_type = PromoteScalarTypesWithCategory({%s}, {%s});" + % (",".join(types_from_tensor), ",".join(types_from_scalar)) + ) + writer.writeline() + + for onnx_op_index, onnx_op in enumerate(ctx.ops): + # Torch -> ORT inputs + for op_input in onnx_op.inputs: + if isinstance(op_input, Outputs): + continue + cpp_param = cpp_func.get_parameter(op_input) + writer.write(f"auto ort_input_{op_input} = ") + writer.writeline(f"create_ort_value(invoker, {op_input});") + if need_type_promotion: + type_func_str = ( + "type()" + if cpp_param.parameter_type.desugar().identifier_tokens[0].value == "Scalar" + else "scalar_type()" + ) + writer.write(f"if ({op_input}.{type_func_str} != *promoted_type)") + writer.writeline("{") + writer.push_indent() + writer.writeline( + f"ort_input_{op_input} = CastToType(invoker, ort_input_{op_input}, *promoted_type);" + ) + writer.pop_indent() + writer.writeline("}") + + # Torch kwargs -> ORT attributes + attrs = {k: v for k, v in onnx_op.attributes.items() if v and v.value} + if len(attrs) > 0: + attrs_arg = "attrs" + writer.writeline() + writer.writeline(f"NodeAttributes {attrs_arg}({len(attrs)});") + + for attr_name, attr in attrs.items(): + writer.write(f'{attrs_arg}["{attr_name}"] = ') + writer.writeline("create_ort_attribute(") + writer.push_indent() + writer.write(f'"{attr_name}", {attr.value}') + if attr.type.startswith("at::ScalarType::"): + writer.write(f", {attr.type}") + elif attr.type == AttrType.TENSOR: + writer.write(f", true") + elif attr.type != AttrType.STRING: + raise FunctionGenerationError( + cpp_func, + f'Unsure how how to map ONNX op "{onnx_op.name}" attribute ' + + f'"{attr_name}" of type "{attr.type}" to a call to ' + + "create_ort_attribute. Please teach generator.py.", + ) + writer.writeline(");") + writer.pop_indent() + attrs_arg = f"&{attrs_arg}" + else: + attrs_arg = "nullptr" + + # Outputs vector + writer.writeline() + writer.write(f"std::vector {onnx_op.outputs}") + writer.writeline(f"({onnx_op.outputs.count});") + + return_info = cpp_func.torch_func.return_type if cpp_func.torch_func else None + in_place_params = {} + + if return_info: + for input_index, op_input in enumerate(onnx_op.inputs): + if isinstance(op_input, Outputs): + continue + + # See if this input is aliased as an in-place tensor + cpp_param = cpp_func.get_parameter(op_input) + if cpp_param: + for torch_p in cpp_param.torch_param: + if isinstance(return_info, ast.TupleType): + for output_index, output_param in enumerate(return_info.elements): + assert isinstance( + output_param.member, ast.TupleMemberType + ), "output_param.member must be of TupleMemberType" + output_alias = self._get_alias_info(output_param.member.element_type) + if ( + output_alias + and self._get_alias_info(torch_p) == output_alias + and output_alias.is_writable + ): + writer.writeline( + f"{onnx_op.outputs}[{output_index}] = ort_input_{onnx_op.inputs[input_index]};" + ) + in_place_params[output_index] = cpp_param.identifier.value + break + else: + output_alias = self._get_alias_info(return_info) + if ( + output_alias + and self._get_alias_info(torch_p) == output_alias + and output_alias.is_writable + ): + writer.writeline(f"{onnx_op.outputs}[0] = ort_input_{onnx_op.inputs[input_index]};") + in_place_params[0] = cpp_param.identifier.value + break + + if len(in_place_params) != 0 and len(in_place_params) != ( + len(return_info.elements) if isinstance(return_info, ast.TupleType) else 1 + ): + raise Exception( + f"Cannot mix and match inplace with non-inplace parameters - function: {cpp_func.identifier.value} " + + f"in_place_params={in_place_params}, return_elements={return_info.elements}" + ) + + # Perform the invocation + writer.writeline() + if onnx_op_index == 0: + writer.write("auto ") + writer.writeline(f'status = invoker.Invoke("{onnx_op.name}", {{') + writer.push_indent() + for op_input in onnx_op.inputs: + if isinstance(op_input, Outputs): + if op_input.count != 1: + raise FunctionGenerationError(cpp_func, "multiple outputs not supported") + op_input = f"{op_input}[0]" + else: + op_input = f"ort_input_{op_input}" + writer.writeline(f"std::move({op_input}),") + writer.pop_indent() + writer.write(f"}}, {onnx_op.outputs}, {attrs_arg}") + if onnx_op.domain: + writer.write(f", {onnx_op.domain}") + writer.writeline(");") + writer.writeline() + + # Assert invocation + writer.writeline("if (!status.IsOK())") + writer.push_indent() + writer.writeline("throw std::runtime_error(") + writer.push_indent() + writer.writeline('"ORT return failure status:" + status.ErrorMessage());') + writer.pop_indent() + writer.pop_indent() + writer.writeline() + + # We'll potentially return back to Torch from this op + return_outputs = onnx_op.outputs + + # TODO: Pick the right "out" Torch parameter; do not assume the first one + # TODO: Handle mutliple results + # TODO: Assert return type + + if len(in_place_params) == 0: + # tensor options + writer.write(f"at::TensorOptions tensor_options = {first_param.identifier.value}") + if first_param.parameter_type.desugar().identifier_tokens[0].value == "TensorList": + writer.write("[0]") + writer.write(".options()") + if need_type_promotion: + writer.write(".dtype(*promoted_type)") + writer.writeline(";") + + writer.writeline("return aten_tensor_from_ort(") + writer.push_indent() + if ( + isinstance(cpp_func.return_type, ast.TemplateType) + and cpp_func.return_type.identifier_tokens[-1].value == "std::vector" + ): + writer.writeline(f"{return_outputs},") + writer.writeline("tensor_options);") + else: + writer.writeline(f"std::move({return_outputs}[0]),") + writer.writeline("tensor_options);") + writer.pop_indent() + return else: - op_input = f'ort_input_{op_input}' - writer.writeline(f'std::move({op_input}),') - writer.pop_indent() - writer.write(f'}}, {onnx_op.outputs}, {attrs_arg}') - if onnx_op.domain: - writer.write(f', {onnx_op.domain}') - writer.writeline(');') - writer.writeline() - - # Assert invocation - writer.writeline('if (!status.IsOK())') - writer.push_indent() - writer.writeline('throw std::runtime_error(') - writer.push_indent() - writer.writeline('"ORT return failure status:" + status.ErrorMessage());') - writer.pop_indent() - writer.pop_indent() - writer.writeline() - - # We'll potentially return back to Torch from this op - return_outputs = onnx_op.outputs - - # TODO: Pick the right "out" Torch parameter; do not assume the first one - # TODO: Handle mutliple results - # TODO: Assert return type - - if len(in_place_params) == 0: - # tensor options - writer.write(f'at::TensorOptions tensor_options = {first_param.identifier.value}') - if first_param.parameter_type.desugar().identifier_tokens[0].value == 'TensorList': - writer.write('[0]') - writer.write('.options()') - if need_type_promotion: - writer.write('.dtype(*promoted_type)') - writer.writeline(';') - - writer.writeline('return aten_tensor_from_ort(') - writer.push_indent() - if isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::vector': - writer.writeline(f'{return_outputs},') - writer.writeline('tensor_options);') - else: - writer.writeline(f'std::move({return_outputs}[0]),') - writer.writeline('tensor_options);') - writer.pop_indent() - return - else: - if len(in_place_params) == 1: - writer.writeline(f'return {in_place_params[0]};') + if len(in_place_params) == 1: + writer.writeline(f"return {in_place_params[0]};") + else: + if not ( + isinstance(cpp_func.return_type, ast.TemplateType) + and cpp_func.return_type.identifier_tokens[-1].value == "std::tuple" + ): + raise Exception(f"") + tensorRef = "Tensor&," * len(in_place_params) + tensorRef = tensorRef[: len(tensorRef) - 1] + writer.write(f"return std::tuple<{tensorRef}>(") + for index, key in enumerate(sorted(in_place_params)): + if index > 0: + writer.write(", ") + writer.write(in_place_params[key]) + writer.writeline(");") + + def _write_function_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]): + writer.writeline() + writer.writeline("TORCH_LIBRARY_IMPL(aten, ORT, m) {") + writer.push_indent() + + for mapped_func in generated_funcs: + cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func + + if mapped_func.op_namespace: + reg_function_arg = f"{mapped_func.op_namespace}::" + else: + reg_function_arg = "" + reg_function_arg += cpp_func.identifier.value + + writer.write("m.impl(") + reg_function_arg = f"TORCH_FN({reg_function_arg})" + + writer.writeline(f'"{torch_func.identifier.value}", {reg_function_arg});') + + writer.pop_indent() + writer.writeline("}") + writer.writeline() + + def _write_custom_ops_registrations(self, writer: writer.SourceWriter, generated_funcs: List[MappedOpFunction]): + writer.writeline() + writer.writeline("void GenerateCustomOpsBindings(pybind11::module_ m) {") + writer.push_indent() + writer.writeline('ORT_LOG_INFO << "GenerateCustomOpsBindings init";') + + for mapped_func in generated_funcs: + cpp_func = mapped_func.cpp_func + writer.write("m.def(") + writer.writeline(f'"{cpp_func.identifier.value}", &{cpp_func.identifier.value});') + + writer.pop_indent() + writer.writeline("}") + writer.writeline() + + def _get_alias_info(self, torch_type_or_param: Union[ast.Type, ast.ParameterDecl]): + if isinstance(torch_type_or_param, ast.ParameterDecl): + torch_type = torch_type_or_param.parameter_type else: - if not (isinstance(cpp_func.return_type, ast.TemplateType) and cpp_func.return_type.identifier_tokens[-1].value == 'std::tuple'): - raise Exception(f'') - tensorRef = "Tensor&," * len(in_place_params) - tensorRef = tensorRef[:len(tensorRef)-1] - writer.write(f'return std::tuple<{tensorRef}>(') - for index, key in enumerate(sorted(in_place_params)): - if index > 0: - writer.write(', ') - writer.write(in_place_params[key]) - writer.writeline(');') - - def _write_function_registrations( - self, - writer: writer.SourceWriter, - generated_funcs: List[MappedOpFunction]): - writer.writeline() - writer.writeline('TORCH_LIBRARY_IMPL(aten, ORT, m) {') - writer.push_indent() - - for mapped_func in generated_funcs: - cpp_func, torch_func = mapped_func.cpp_func, mapped_func.cpp_func.torch_func - - - if mapped_func.op_namespace: - reg_function_arg = f'{mapped_func.op_namespace}::' - else: - reg_function_arg = '' - reg_function_arg += cpp_func.identifier.value - - writer.write('m.impl(') - reg_function_arg = f'TORCH_FN({reg_function_arg})' - - writer.writeline(f'"{torch_func.identifier.value}", {reg_function_arg});') - - writer.pop_indent() - writer.writeline('}') - writer.writeline() - - def _write_custom_ops_registrations( - self, - writer: writer.SourceWriter, - generated_funcs: List[MappedOpFunction]): - writer.writeline() - writer.writeline('void GenerateCustomOpsBindings(pybind11::module_ m) {') - writer.push_indent() - writer.writeline('ORT_LOG_INFO << "GenerateCustomOpsBindings init";') - - for mapped_func in generated_funcs: - cpp_func = mapped_func.cpp_func - writer.write('m.def(') - writer.writeline(f'"{cpp_func.identifier.value}", &{cpp_func.identifier.value});') - - writer.pop_indent() - writer.writeline('}') - writer.writeline() - - def _get_alias_info(self, torch_type_or_param: Union[ast.Type, ast.ParameterDecl]): - if isinstance(torch_type_or_param, ast.ParameterDecl): - torch_type = torch_type_or_param.parameter_type - else: - torch_type = torch_type_or_param - return getattr(torch_type.desugar(), 'alias_info', None) - - def _parse_mapped_function_decls(self, cpp_parser: parser.CPPParser): - for cpp_func in self._parse_function_decls(cpp_parser): - torch_func = cpp_func.torch_func - if not torch_func: - op_namespace = None - op_name = cpp_func.identifier.value - else: - op_name = torch_func.identifier.value - - try: - op_namespace = op_name[0:op_name.index('::')] - op_namewithoutnamespace = op_name[len(op_namespace) + 2:] - except: - op_namespace = None - op_namewithoutnamespace = op_name - - cpp_func.identifier.value = op_namewithoutnamespace.replace('.', '_') - - onnx_op = self._mapped_ops.get(op_name) - if not onnx_op: - continue - - yield MappedOpFunction( - op_namespace, - op_name, - onnx_op, - cpp_func, - isinstance(onnx_op, SignatureOnly), - isinstance(onnx_op, MakeTorchFallback)) - - def _parse_function_decls(self, cpp_parser: parser.CPPParser): - # Parse the C++ declarations - tu = cpp_parser.parse_translation_unit() - - # Parse the Torch schema from the JSON comment that follows each C++ decl - # and link associated Torch and C++ decls (functions, parameters, returns) - for cpp_func in tu: - hasSchema = False - if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia: - for trivia in cpp_func.semicolon.trailing_trivia: - if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT: - yield self._parse_and_link_torch_function_decl(cpp_func, trivia) - hasSchema = True - break - - if not hasSchema: - # customops might not have torch schema - cpp_func.torch_func = None - yield cpp_func - - def _parse_and_link_torch_function_decl( - self, - cpp_func: ast.FunctionDecl, - torch_schema_comment_trivia: lexer.Token): - metadata = json.loads(torch_schema_comment_trivia.value.lstrip('//')) - schema = metadata['schema'] - - schema_parser = parser.torch_create_from_string(schema) - schema_parser.set_source_location(cpp_func.semicolon.location) - torch_func = schema_parser.parse_function() - - torch_func.torch_schema = schema - torch_func.torch_dispatch = metadata['dispatch'] == 'True' - torch_func.torch_default = metadata['default'] == 'True' - - cpp_func.torch_func = torch_func - - if cpp_func.return_type: - cpp_func.return_type.torch_type = torch_func.return_type - - # Synthesize KWArgsSentinelType in the C++ declaration if we have one - for i, torch_param in enumerate([p.member for p in torch_func.parameters]): - if isinstance(torch_param.parameter_type, ast.KWArgsSentinelType): - cpp_func.parameters.members.insert(i, ast.SyntaxListMember( - torch_param, - lexer.Token(None, lexer.TokenKind.COMMA, ','))) - break - - # Link Torch parameters to their C++ counterparts, special casing - # TensorOptions parameters - for i, cpp_param in enumerate([p.member for p in cpp_func.parameters]): - if not getattr(cpp_param, 'torch_param', None): - cpp_param.torch_param = [] - - torch_param_range = 1 - if isinstance(cpp_param.parameter_type.desugar(), ast.TensorOptionsType): - torch_param_range = 4 - - for j in range(torch_param_range): - torch_param = torch_func.parameters[i + j].member - cpp_param.torch_param.append(torch_param) - - return cpp_func \ No newline at end of file + torch_type = torch_type_or_param + return getattr(torch_type.desugar(), "alias_info", None) + + def _parse_mapped_function_decls(self, cpp_parser: parser.CPPParser): + for cpp_func in self._parse_function_decls(cpp_parser): + torch_func = cpp_func.torch_func + if not torch_func: + op_namespace = None + op_name = cpp_func.identifier.value + else: + op_name = torch_func.identifier.value + + try: + op_namespace = op_name[0 : op_name.index("::")] + op_namewithoutnamespace = op_name[len(op_namespace) + 2 :] + except: + op_namespace = None + op_namewithoutnamespace = op_name + + cpp_func.identifier.value = op_namewithoutnamespace.replace(".", "_") + + onnx_op = self._mapped_ops.get(op_name) + if not onnx_op: + continue + + yield MappedOpFunction( + op_namespace, + op_name, + onnx_op, + cpp_func, + isinstance(onnx_op, SignatureOnly), + isinstance(onnx_op, MakeTorchFallback), + ) + + def _parse_function_decls(self, cpp_parser: parser.CPPParser): + # Parse the C++ declarations + tu = cpp_parser.parse_translation_unit() + + # Parse the Torch schema from the JSON comment that follows each C++ decl + # and link associated Torch and C++ decls (functions, parameters, returns) + for cpp_func in tu: + hasSchema = False + if cpp_func.semicolon and cpp_func.semicolon.trailing_trivia: + for trivia in cpp_func.semicolon.trailing_trivia: + if trivia.kind == lexer.TokenKind.SINGLE_LINE_COMMENT: + yield self._parse_and_link_torch_function_decl(cpp_func, trivia) + hasSchema = True + break + + if not hasSchema: + # customops might not have torch schema + cpp_func.torch_func = None + yield cpp_func + + def _parse_and_link_torch_function_decl(self, cpp_func: ast.FunctionDecl, torch_schema_comment_trivia: lexer.Token): + metadata = json.loads(torch_schema_comment_trivia.value.lstrip("//")) + schema = metadata["schema"] + + schema_parser = parser.torch_create_from_string(schema) + schema_parser.set_source_location(cpp_func.semicolon.location) + torch_func = schema_parser.parse_function() + + torch_func.torch_schema = schema + torch_func.torch_dispatch = metadata["dispatch"] == "True" + torch_func.torch_default = metadata["default"] == "True" + + cpp_func.torch_func = torch_func + + if cpp_func.return_type: + cpp_func.return_type.torch_type = torch_func.return_type + + # Synthesize KWArgsSentinelType in the C++ declaration if we have one + for i, torch_param in enumerate([p.member for p in torch_func.parameters]): + if isinstance(torch_param.parameter_type, ast.KWArgsSentinelType): + cpp_func.parameters.members.insert( + i, ast.SyntaxListMember(torch_param, lexer.Token(None, lexer.TokenKind.COMMA, ",")) + ) + break + + # Link Torch parameters to their C++ counterparts, special casing + # TensorOptions parameters + for i, cpp_param in enumerate([p.member for p in cpp_func.parameters]): + if not getattr(cpp_param, "torch_param", None): + cpp_param.torch_param = [] + + torch_param_range = 1 + if isinstance(cpp_param.parameter_type.desugar(), ast.TensorOptionsType): + torch_param_range = 4 + + for j in range(torch_param_range): + torch_param = torch_func.parameters[i + j].member + cpp_param.torch_param.append(torch_param) + + return cpp_func diff --git a/orttraining/orttraining/eager/opgen/opgen/lexer.py b/orttraining/orttraining/eager/opgen/opgen/lexer.py index c38769435ea32..661d646350f53 100644 --- a/orttraining/orttraining/eager/opgen/opgen/lexer.py +++ b/orttraining/orttraining/eager/opgen/opgen/lexer.py @@ -5,361 +5,365 @@ from abc import ABC from typing import List, Optional, Union, Tuple + class SourceLocation(object): - def __init__(self, - offset: int = 0, - line: int = 1, - column: int = 1): - self.offset = offset - self.line = line - self.column = column + def __init__(self, offset: int = 0, line: int = 1, column: int = 1): + self.offset = offset + self.line = line + self.column = column + + def increment_line(self): + return SourceLocation(self.offset + 1, self.line + 1, 1) - def increment_line(self): - return SourceLocation(self.offset + 1, self.line + 1, 1) + def increment_column(self, count: int = 1): + return SourceLocation(self.offset + count, self.line, self.column + count) - def increment_column(self, count: int = 1): - return SourceLocation(self.offset + count, self.line, self.column + count) + def __str__(self) -> str: + return f"({self.line},{self.column})" - def __str__(self) -> str: - return f"({self.line},{self.column})" + def __repr__(self) -> str: + return f"({self.offset},{self.line},{self.column})" - def __repr__(self) -> str: - return f"({self.offset},{self.line},{self.column})" + def __eq__(self, other) -> bool: + return ( + self.__class__ == other.__class__ + and self.offset == other.offset + and self.line == other.line + and self.column == other.column + ) - def __eq__(self, other) -> bool: - return ( - self.__class__ == other.__class__ and - self.offset == other.offset and - self.line == other.line and - self.column == other.column) class TokenKind(Enum): - UNKNOWN = 1 - EOF = 2 - WHITESPACE = 3 - SINGLE_LINE_COMMENT = 4 - MULTI_LINE_COMMENT = 5 - IDENTIFIER = 6 - NUMBER = 7 - STRING = 8 - OPEN_PAREN = 9 - CLOSE_PAREN = 10 - OPEN_BRACKET = 11 - CLOSE_BRACKET = 12 - LESS_THAN = 13 - GREATER_THAN = 14 - COMMA = 15 - SEMICOLON = 16 - COLON = 17 - DOUBLECOLON = 18 - AND = 19 - OR = 20 - DIV = 21 - MUL = 22 - MINUS = 23 - EQUALS = 24 - QUESTION_MARK = 25 - EXCLAIMATION_MARK = 26 - ARROW = 27 + UNKNOWN = 1 + EOF = 2 + WHITESPACE = 3 + SINGLE_LINE_COMMENT = 4 + MULTI_LINE_COMMENT = 5 + IDENTIFIER = 6 + NUMBER = 7 + STRING = 8 + OPEN_PAREN = 9 + CLOSE_PAREN = 10 + OPEN_BRACKET = 11 + CLOSE_BRACKET = 12 + LESS_THAN = 13 + GREATER_THAN = 14 + COMMA = 15 + SEMICOLON = 16 + COLON = 17 + DOUBLECOLON = 18 + AND = 19 + OR = 20 + DIV = 21 + MUL = 22 + MINUS = 23 + EQUALS = 24 + QUESTION_MARK = 25 + EXCLAIMATION_MARK = 26 + ARROW = 27 + class Token(object): - def __init__(self, - location: Union[SourceLocation, Tuple[int, int, int]], - kind: TokenKind, - value: str, - leading_trivia: Optional[List["Token"]] = None, - trailing_trivia: Optional[List["Token"]] = None): - if isinstance(location, tuple) or isinstance(location, list): - location = SourceLocation(location[0], location[1], location[2]) - - self.location = location - self.kind = kind - self.value = value - self.leading_trivia = leading_trivia - self.trailing_trivia = trailing_trivia - - def is_trivia(self) -> bool: - return ( - self.kind == TokenKind.WHITESPACE or - self.kind == TokenKind.SINGLE_LINE_COMMENT or - self.kind == TokenKind.MULTI_LINE_COMMENT) - - def has_trailing_trivia(self, trivia_kind: TokenKind) -> bool: - if not self.trailing_trivia: - return False - for trivia in self.trailing_trivia: - if trivia.kind == trivia_kind: - return True - return False - - def __str__(self) -> str: - return f"{self.location}: [{self.kind}] '{self.value}'" - - def __repr__(self) -> str: - rep = f"Token({repr(self.location)},{self.kind}" - if self.value: - rep += ",\"" + self.value + "\"" - if self.leading_trivia: - rep += f",leading_trivia={self.leading_trivia}" - if self.trailing_trivia: - rep += f",trailing_trivia={self.trailing_trivia}" - return rep + ")" - - def __eq__(self, other) -> bool: - return ( - self.__class__ == other.__class__ and - self.location == other.location and - self.kind == other.kind and - self.value == other.value and - self.leading_trivia == other.leading_trivia and - self.trailing_trivia == other.trailing_trivia) + def __init__( + self, + location: Union[SourceLocation, Tuple[int, int, int]], + kind: TokenKind, + value: str, + leading_trivia: Optional[List["Token"]] = None, + trailing_trivia: Optional[List["Token"]] = None, + ): + if isinstance(location, tuple) or isinstance(location, list): + location = SourceLocation(location[0], location[1], location[2]) + + self.location = location + self.kind = kind + self.value = value + self.leading_trivia = leading_trivia + self.trailing_trivia = trailing_trivia + + def is_trivia(self) -> bool: + return ( + self.kind == TokenKind.WHITESPACE + or self.kind == TokenKind.SINGLE_LINE_COMMENT + or self.kind == TokenKind.MULTI_LINE_COMMENT + ) + + def has_trailing_trivia(self, trivia_kind: TokenKind) -> bool: + if not self.trailing_trivia: + return False + for trivia in self.trailing_trivia: + if trivia.kind == trivia_kind: + return True + return False + + def __str__(self) -> str: + return f"{self.location}: [{self.kind}] '{self.value}'" + + def __repr__(self) -> str: + rep = f"Token({repr(self.location)},{self.kind}" + if self.value: + rep += ',"' + self.value + '"' + if self.leading_trivia: + rep += f",leading_trivia={self.leading_trivia}" + if self.trailing_trivia: + rep += f",trailing_trivia={self.trailing_trivia}" + return rep + ")" + + def __eq__(self, other) -> bool: + return ( + self.__class__ == other.__class__ + and self.location == other.location + and self.kind == other.kind + and self.value == other.value + and self.leading_trivia == other.leading_trivia + and self.trailing_trivia == other.trailing_trivia + ) + class Reader(ABC): - def open(self): - pass + def open(self): + pass + + def close(self): + pass - def close(self): - pass + def read_char(self) -> str: + return None - def read_char(self) -> str: - return None class FileReader(Reader): - def __init__(self, path: str): - self.path = path + def __init__(self, path: str): + self.path = path + + def open(self): + self.fp = open(self.path) - def open(self): - self.fp = open(self.path) + def close(self): + self.fp.close() - def close(self): - self.fp.close() + def read_char(self) -> str: + return self.fp.read(1) - def read_char(self) -> str: - return self.fp.read(1) class StringReader(Reader): - def __init__(self, buffer: str): - self.buffer = buffer - self.position = 0 + def __init__(self, buffer: str): + self.buffer = buffer + self.position = 0 + + def read_char(self) -> str: + if self.position < len(self.buffer): + c = self.buffer[self.position] + self.position += 1 + return c + return None - def read_char(self) -> str: - if self.position < len(self.buffer): - c = self.buffer[self.position] - self.position += 1 - return c - return None class Lexer(object): - _peek: str - _next_token: Token - _first_token_leading_trivia: List[Token] - - char_to_token_kind = { - '(': TokenKind.OPEN_PAREN, - ')': TokenKind.CLOSE_PAREN, - '<': TokenKind.LESS_THAN, - '>': TokenKind.GREATER_THAN, - '[': TokenKind.OPEN_BRACKET, - ']': TokenKind.CLOSE_BRACKET, - ',': TokenKind.COMMA, - ';': TokenKind.SEMICOLON, - '&': TokenKind.AND, - '|': TokenKind.OR, - '=': TokenKind.EQUALS, - '?': TokenKind.QUESTION_MARK, - '!': TokenKind.EXCLAIMATION_MARK, - '*': TokenKind.MUL - } - - def __init__(self, reader: Reader): - self._reader = reader - self._peek = None - self._current_token_location = SourceLocation() - self._next_token_location = SourceLocation() - self._next_token = None - self._first_token_leading_trivia = [] - - def __enter__(self): - self._reader.open() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._reader.close() - - def _make_token(self, kind: TokenKind, value: str) -> Token: - return Token(self._next_token_location, kind, value) - - def _peek_char(self) -> str: - if self._peek: - return self._peek - self._peek = self._reader.read_char() - return self._peek - - def _read_char(self) -> str: - if self._peek: - c = self._peek - self._peek = None - else: - c = self._reader.read_char() - if c: - self._current_token_location = \ - self._current_token_location.increment_line() \ - if c == '\n' \ - else self._current_token_location.increment_column() - return c - - def set_source_location(self, origin: SourceLocation): - self._current_token_location = origin - - def lex(self) -> Token: - """ - Lex a single semantic token from the source, gathering into it - any trailing whitespace or comment trivia that may follow. The - first non-trivia token in the buffer may also have leading trivia - attached to it. - """ - token: Token - leading_trivia: Optional[List[Token]] = None - trailing_trivia: Optional[List[Token]] = None - - while True: - token = self._lex_core() - if token.is_trivia(): - if not leading_trivia: - leading_trivia = [token] - else: - leading_trivia.append(token) - else: - break - - while True: - trailing = self._lex_core() - if trailing.is_trivia(): - if not trailing_trivia: - trailing_trivia = [trailing] + _peek: str + _next_token: Token + _first_token_leading_trivia: List[Token] + + char_to_token_kind = { + "(": TokenKind.OPEN_PAREN, + ")": TokenKind.CLOSE_PAREN, + "<": TokenKind.LESS_THAN, + ">": TokenKind.GREATER_THAN, + "[": TokenKind.OPEN_BRACKET, + "]": TokenKind.CLOSE_BRACKET, + ",": TokenKind.COMMA, + ";": TokenKind.SEMICOLON, + "&": TokenKind.AND, + "|": TokenKind.OR, + "=": TokenKind.EQUALS, + "?": TokenKind.QUESTION_MARK, + "!": TokenKind.EXCLAIMATION_MARK, + "*": TokenKind.MUL, + } + + def __init__(self, reader: Reader): + self._reader = reader + self._peek = None + self._current_token_location = SourceLocation() + self._next_token_location = SourceLocation() + self._next_token = None + self._first_token_leading_trivia = [] + + def __enter__(self): + self._reader.open() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._reader.close() + + def _make_token(self, kind: TokenKind, value: str) -> Token: + return Token(self._next_token_location, kind, value) + + def _peek_char(self) -> str: + if self._peek: + return self._peek + self._peek = self._reader.read_char() + return self._peek + + def _read_char(self) -> str: + if self._peek: + c = self._peek + self._peek = None else: - trailing_trivia.append(trailing) - else: - self._next_token = trailing - break - - token.leading_trivia = leading_trivia - token.trailing_trivia = trailing_trivia - - return token - - def _lex_core(self) -> Token: - """Lex a single token from the source including comments and whitespace.""" - if self._next_token: - token = self._next_token - self._next_token = None - return token - - self._next_token_location = self._current_token_location - - c = self._peek_char() - if not c: - return self._make_token(TokenKind.EOF, None) - - kind = Lexer.char_to_token_kind.get(c) - if kind: - return self._make_token(kind, self._read_char()) - - if c.isspace(): - return self._lex_sequence( - TokenKind.WHITESPACE, - lambda c: c.isspace()) - - if self._is_identifier_char(c, first_char = True): - return self._lex_sequence( - TokenKind.IDENTIFIER, - lambda c: self._is_identifier_char(c)) - - if c == ':': - self._read_char() - if self._peek_char() == ':': - return self._make_token(TokenKind.DOUBLECOLON, c + self._read_char()) - else: - return self._make_token(TokenKind.COLON, c) - - if c == '/': - self._read_char() - if self._peek_char() == '/': - return self._lex_sequence( - TokenKind.SINGLE_LINE_COMMENT, - lambda c: c != "\n", "/") - elif self._peek_char() == '*': - raise NotImplementedError("Multi-line comments not supported") - else: - return self._make_token(TokenKind.DIV, c) - - if c == '-': - self._read_char() - p = self._peek_char() - if p == '>': - return self._make_token(TokenKind.ARROW, c + self._read_char()) - elif p == '.' or p.isnumeric(): - return self._lex_number(c) - else: - return self._make_token(TokenKind.MINUS, c) - - if c == '.' or c.isnumeric(): - return self._lex_number() - - if c == '"' or c == '\'': - return self._lex_string() - - return self._make_token(TokenKind.UNKNOWN, c) - - def _lex_number(self, s: str = "") -> Token: - s += self._read_char() - - in_exponent = False - have_decimal_separator = '.' in s - - while True: - p = self._peek_char() - if not p: - break - if p.isnumeric(): - s += self._read_char() - elif not have_decimal_separator and not in_exponent and p == '.': - have_decimal_separator = True - s += self._read_char() - elif not in_exponent and (p == 'e' or p == 'E'): - in_exponent = True - s += self._read_char() - if self._peek_char() == '-': - s += self._read_char() - else: - break - - return self._make_token(TokenKind.NUMBER, s) - - def _lex_string(self) -> Token: - term = self._read_char() - s = "" - while True: - p = self._peek_char() - if p == '\\': - self._read_char() - s += self._read_char() - elif p == term: - self._read_char() - return self._make_token(TokenKind.STRING, s) - else: + c = self._reader.read_char() + if c: + self._current_token_location = ( + self._current_token_location.increment_line() + if c == "\n" + else self._current_token_location.increment_column() + ) + return c + + def set_source_location(self, origin: SourceLocation): + self._current_token_location = origin + + def lex(self) -> Token: + """ + Lex a single semantic token from the source, gathering into it + any trailing whitespace or comment trivia that may follow. The + first non-trivia token in the buffer may also have leading trivia + attached to it. + """ + token: Token + leading_trivia: Optional[List[Token]] = None + trailing_trivia: Optional[List[Token]] = None + + while True: + token = self._lex_core() + if token.is_trivia(): + if not leading_trivia: + leading_trivia = [token] + else: + leading_trivia.append(token) + else: + break + + while True: + trailing = self._lex_core() + if trailing.is_trivia(): + if not trailing_trivia: + trailing_trivia = [trailing] + else: + trailing_trivia.append(trailing) + else: + self._next_token = trailing + break + + token.leading_trivia = leading_trivia + token.trailing_trivia = trailing_trivia + + return token + + def _lex_core(self) -> Token: + """Lex a single token from the source including comments and whitespace.""" + if self._next_token: + token = self._next_token + self._next_token = None + return token + + self._next_token_location = self._current_token_location + + c = self._peek_char() + if not c: + return self._make_token(TokenKind.EOF, None) + + kind = Lexer.char_to_token_kind.get(c) + if kind: + return self._make_token(kind, self._read_char()) + + if c.isspace(): + return self._lex_sequence(TokenKind.WHITESPACE, lambda c: c.isspace()) + + if self._is_identifier_char(c, first_char=True): + return self._lex_sequence(TokenKind.IDENTIFIER, lambda c: self._is_identifier_char(c)) + + if c == ":": + self._read_char() + if self._peek_char() == ":": + return self._make_token(TokenKind.DOUBLECOLON, c + self._read_char()) + else: + return self._make_token(TokenKind.COLON, c) + + if c == "/": + self._read_char() + if self._peek_char() == "/": + return self._lex_sequence(TokenKind.SINGLE_LINE_COMMENT, lambda c: c != "\n", "/") + elif self._peek_char() == "*": + raise NotImplementedError("Multi-line comments not supported") + else: + return self._make_token(TokenKind.DIV, c) + + if c == "-": + self._read_char() + p = self._peek_char() + if p == ">": + return self._make_token(TokenKind.ARROW, c + self._read_char()) + elif p == "." or p.isnumeric(): + return self._lex_number(c) + else: + return self._make_token(TokenKind.MINUS, c) + + if c == "." or c.isnumeric(): + return self._lex_number() + + if c == '"' or c == "'": + return self._lex_string() + + return self._make_token(TokenKind.UNKNOWN, c) + + def _lex_number(self, s: str = "") -> Token: s += self._read_char() - def _is_identifier_char(self, c: str, first_char = False) -> bool: - if c == '_' or c.isalpha(): - return True - return c in [':', '.'] or c.isnumeric() if not first_char else False - - def _lex_sequence(self, kind: TokenKind, predicate, s: str = ""): - while True: - c = self._read_char() - if c: - s += c - p = self._peek_char() - if not p or not predicate(p): - return self._make_token(kind, s) \ No newline at end of file + in_exponent = False + have_decimal_separator = "." in s + + while True: + p = self._peek_char() + if not p: + break + if p.isnumeric(): + s += self._read_char() + elif not have_decimal_separator and not in_exponent and p == ".": + have_decimal_separator = True + s += self._read_char() + elif not in_exponent and (p == "e" or p == "E"): + in_exponent = True + s += self._read_char() + if self._peek_char() == "-": + s += self._read_char() + else: + break + + return self._make_token(TokenKind.NUMBER, s) + + def _lex_string(self) -> Token: + term = self._read_char() + s = "" + while True: + p = self._peek_char() + if p == "\\": + self._read_char() + s += self._read_char() + elif p == term: + self._read_char() + return self._make_token(TokenKind.STRING, s) + else: + s += self._read_char() + + def _is_identifier_char(self, c: str, first_char=False) -> bool: + if c == "_" or c.isalpha(): + return True + return c in [":", "."] or c.isnumeric() if not first_char else False + + def _lex_sequence(self, kind: TokenKind, predicate, s: str = ""): + while True: + c = self._read_char() + if c: + s += c + p = self._peek_char() + if not p or not predicate(p): + return self._make_token(kind, s) diff --git a/orttraining/orttraining/eager/opgen/opgen/onnxops.py b/orttraining/orttraining/eager/opgen/opgen/onnxops.py index 97a83cf3be794..98a2dd4d5997e 100644 --- a/orttraining/orttraining/eager/opgen/opgen/onnxops.py +++ b/orttraining/orttraining/eager/opgen/opgen/onnxops.py @@ -3,4981 +3,6498 @@ from opgen.generator import ONNXAttr, ONNXOp, AttrType + class Abs(ONNXOp): - """ - Absolute takes one input data (Tensor) and produces one output data - (Tensor) where the absolute is, y = abs(x), is applied to - the tensor elementwise. - """ + """ + Absolute takes one input data (Tensor) and produces one output data + (Tensor) where the absolute is, y = abs(x), is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__( + "Abs", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + X, + ) - def __init__(self, X): - super().__init__('Abs', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - X) class Acos(ONNXOp): - """ - Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise. - """ + """ + Calculates the arccosine (inverse of cosine) of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Acos", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Acos', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Acosh(ONNXOp): - """ - Calculates the hyperbolic arccosine of the given input tensor element-wise. - """ + """ + Calculates the hyperbolic arccosine of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Acosh", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Acosh', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Adagrad(ONNXOp): - """ - Compute one iteration of ADAGRAD, a stochastic gradient based optimization - algorithm. This operator can conduct the optimization of multiple tensor variables. - - Let's define the behavior of this operator. As you can imagine, ADAGRAD requires - some parameters: - - - The initial learning-rate "R". - - The update count "T". That is, the number of training iterations conducted. - - A L2-norm regularization coefficient "norm_coefficient". - - A learning-rate decay factor "decay_factor". - - A small constant "epsilon" to avoid dividing-by-zero. - - At each ADAGRAD iteration, the optimized tensors are moved along a direction - computed based on their estimated gradient and accumulated squared gradient. Assume - that only a single tensor "X" is updated by this operator. We need the value of "X", - its gradient "G", and its accumulated squared gradient "H". Therefore, variables in - this operator's input list are sequentially "R", "T", "X", "G", and "H". Other - parameters are given as attributes because they are usually constants. Also, the - corresponding output tensors are the new value of "X" (called "X_new"), and then - the new accumulated squared gradient (called "H_new"). Those outputs are computed - from the given inputs following the pseudo code below. - - Let "+", "-", "*", and "/" are all element-wise arithmetic operations with - numpy-style broadcasting support. The pseudo code to compute those outputs is: - - // Compute a scalar learning-rate factor. At the first update of X, T is generally - // 0 (0-based update index) or 1 (1-based update index). - r = R / (1 + T * decay_factor); - - // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm. - G_regularized = norm_coefficient * X + G; - - // Compute new accumulated squared gradient. - H_new = H + G_regularized * G_regularized; - - // Compute the adaptive part of per-coordinate learning rate. Note that Sqrt(...) - // computes element-wise square-root. - H_adaptive = Sqrt(H_new) + epsilon - - // Compute the new value of "X". - X_new = X - r * G_regularized / H_adaptive; - - If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2", the same - pseudo code may be extended to handle all tensors jointly. More specifically, we can view "X" as a - concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should - be concatenated too) and then just reuse the entire pseudo code. - - Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf. - In that reference paper, this operator is a special case of the Figure 1's composite mirror - descent update. - """ - - def __init__(self, R, T, inputs, - decay_factor=None, - epsilon=None, - norm_coefficient=None): - super().__init__('Adagrad', 1, - [{'at::kDouble', 'at::kFloat'}, {'at::kLong'}, {'at::kDouble', 'at::kFloat'}], - R,T,inputs, - decay_factor=ONNXAttr(decay_factor, AttrType.FLOAT), - epsilon=ONNXAttr(epsilon, AttrType.FLOAT), - norm_coefficient=ONNXAttr(norm_coefficient, AttrType.FLOAT)) + """ + Compute one iteration of ADAGRAD, a stochastic gradient based optimization + algorithm. This operator can conduct the optimization of multiple tensor variables. + + Let's define the behavior of this operator. As you can imagine, ADAGRAD requires + some parameters: + + - The initial learning-rate "R". + - The update count "T". That is, the number of training iterations conducted. + - A L2-norm regularization coefficient "norm_coefficient". + - A learning-rate decay factor "decay_factor". + - A small constant "epsilon" to avoid dividing-by-zero. + + At each ADAGRAD iteration, the optimized tensors are moved along a direction + computed based on their estimated gradient and accumulated squared gradient. Assume + that only a single tensor "X" is updated by this operator. We need the value of "X", + its gradient "G", and its accumulated squared gradient "H". Therefore, variables in + this operator's input list are sequentially "R", "T", "X", "G", and "H". Other + parameters are given as attributes because they are usually constants. Also, the + corresponding output tensors are the new value of "X" (called "X_new"), and then + the new accumulated squared gradient (called "H_new"). Those outputs are computed + from the given inputs following the pseudo code below. + + Let "+", "-", "*", and "/" are all element-wise arithmetic operations with + numpy-style broadcasting support. The pseudo code to compute those outputs is: + + // Compute a scalar learning-rate factor. At the first update of X, T is generally + // 0 (0-based update index) or 1 (1-based update index). + r = R / (1 + T * decay_factor); + + // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm. + G_regularized = norm_coefficient * X + G; + + // Compute new accumulated squared gradient. + H_new = H + G_regularized * G_regularized; + + // Compute the adaptive part of per-coordinate learning rate. Note that Sqrt(...) + // computes element-wise square-root. + H_adaptive = Sqrt(H_new) + epsilon + + // Compute the new value of "X". + X_new = X - r * G_regularized / H_adaptive; + + If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2", the same + pseudo code may be extended to handle all tensors jointly. More specifically, we can view "X" as a + concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should + be concatenated too) and then just reuse the entire pseudo code. + + Note that ADAGRAD was first proposed in http://jmlr.org/papers/volume12/duchi11a/duchi11a.pdf. + In that reference paper, this operator is a special case of the Figure 1's composite mirror + descent update. + """ + + def __init__(self, R, T, inputs, decay_factor=None, epsilon=None, norm_coefficient=None): + super().__init__( + "Adagrad", + 1, + [{"at::kDouble", "at::kFloat"}, {"at::kLong"}, {"at::kDouble", "at::kFloat"}], + R, + T, + inputs, + decay_factor=ONNXAttr(decay_factor, AttrType.FLOAT), + epsilon=ONNXAttr(epsilon, AttrType.FLOAT), + norm_coefficient=ONNXAttr(norm_coefficient, AttrType.FLOAT), + ) + class Adam(ONNXOp): - """ - Compute one iteration of Adam, a stochastic gradient based optimization - algorithm. This operator can conduct the optimization of multiple tensor variables. - - Let's define the behavior of this operator. First of all, Adam requires - some parameters: - - - The learning-rate "R". - - The update count "T". That is, the number of training iterations conducted. - - A L2-norm regularization coefficient "norm_coefficient". - - A small constant "epsilon" to avoid dividing-by-zero. - - Two coefficients, "alpha" and "beta". - - At each Adam iteration, the optimized tensors are moved along a direction - computed based on their exponentially-averaged historical gradient and - exponentially-averaged historical squared gradient. Assume that only a tensor - "X" is being optimized. The rest of required information is - - - the value of "X", - - "X"'s gradient (denoted by "G"), - - "X"'s exponentially-averaged historical gradient (denoted by "V"), and - - "X"'s exponentially-averaged historical squared gradient (denoted by "H"). - - Some of those parameters are passed into this operator as input tensors and others - are stored as this operator's attributes. Specifically, this operator's input tensor - list is ["R", "T", "X", "G", "V", "H"]. That is, "R" is the first input, "T" is - the second input, and so on. Other parameters are given as attributes because they - are constants. Moreover, the corresponding output tensors are - - - the new value of "X" (called "X_new"), - - the new exponentially-averaged historical gradient (denoted by "V_new"), and - - the new exponentially-averaged historical squared gradient (denoted by "H_new"). - - Those outputs are computed following the pseudo code below. - - Let "+", "-", "*", and "/" are all element-wise arithmetic operations with - numpy-style broadcasting support. The pseudo code to compute those outputs is: - - // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm. - G_regularized = norm_coefficient * X + G - - // Update exponentially-averaged historical gradient. - V_new = alpha * V + (1 - alpha) * G_regularized - - // Update exponentially-averaged historical squared gradient. - H_new = beta * H + (1 - beta) * G_regularized * G_regularized - - // Compute the element-wise square-root of H_new. V_new will be element-wisely - // divided by H_sqrt for a better update direction. - H_sqrt = Sqrt(H_new) + epsilon - - // Compute learning-rate. Note that "alpha**T"/"beta**T" is alpha's/beta's T-th power. - R_adjusted = T > 0 ? R * Sqrt(1 - beta**T) / (1 - alpha**T) : R - - // Compute new value of "X". - X_new = X - R_adjusted * V_new / H_sqrt - - // Post-update regularization. - X_final = (1 - norm_coefficient_post) * X_new - - If there are multiple inputs to be optimized, the pseudo code will be applied - independently to each of them. - """ - - def __init__(self, R, T, inputs, - alpha=None, - beta=None, - epsilon=None, - norm_coefficient=None, - norm_coefficient_post=None): - super().__init__('Adam', 1, - [{'at::kDouble', 'at::kFloat'}, {'at::kLong'}, {'at::kDouble', 'at::kFloat'}], - R,T,inputs, - alpha=ONNXAttr(alpha, AttrType.FLOAT), - beta=ONNXAttr(beta, AttrType.FLOAT), - epsilon=ONNXAttr(epsilon, AttrType.FLOAT), - norm_coefficient=ONNXAttr(norm_coefficient, AttrType.FLOAT), - norm_coefficient_post=ONNXAttr(norm_coefficient_post, AttrType.FLOAT)) + """ + Compute one iteration of Adam, a stochastic gradient based optimization + algorithm. This operator can conduct the optimization of multiple tensor variables. + + Let's define the behavior of this operator. First of all, Adam requires + some parameters: + + - The learning-rate "R". + - The update count "T". That is, the number of training iterations conducted. + - A L2-norm regularization coefficient "norm_coefficient". + - A small constant "epsilon" to avoid dividing-by-zero. + - Two coefficients, "alpha" and "beta". + + At each Adam iteration, the optimized tensors are moved along a direction + computed based on their exponentially-averaged historical gradient and + exponentially-averaged historical squared gradient. Assume that only a tensor + "X" is being optimized. The rest of required information is + + - the value of "X", + - "X"'s gradient (denoted by "G"), + - "X"'s exponentially-averaged historical gradient (denoted by "V"), and + - "X"'s exponentially-averaged historical squared gradient (denoted by "H"). + + Some of those parameters are passed into this operator as input tensors and others + are stored as this operator's attributes. Specifically, this operator's input tensor + list is ["R", "T", "X", "G", "V", "H"]. That is, "R" is the first input, "T" is + the second input, and so on. Other parameters are given as attributes because they + are constants. Moreover, the corresponding output tensors are + + - the new value of "X" (called "X_new"), + - the new exponentially-averaged historical gradient (denoted by "V_new"), and + - the new exponentially-averaged historical squared gradient (denoted by "H_new"). + + Those outputs are computed following the pseudo code below. + + Let "+", "-", "*", and "/" are all element-wise arithmetic operations with + numpy-style broadcasting support. The pseudo code to compute those outputs is: + + // Add gradient of 0.5 * norm_coefficient * ||X||_2^2, where ||X||_2 is the 2-norm. + G_regularized = norm_coefficient * X + G + + // Update exponentially-averaged historical gradient. + V_new = alpha * V + (1 - alpha) * G_regularized + + // Update exponentially-averaged historical squared gradient. + H_new = beta * H + (1 - beta) * G_regularized * G_regularized + + // Compute the element-wise square-root of H_new. V_new will be element-wisely + // divided by H_sqrt for a better update direction. + H_sqrt = Sqrt(H_new) + epsilon + + // Compute learning-rate. Note that "alpha**T"/"beta**T" is alpha's/beta's T-th power. + R_adjusted = T > 0 ? R * Sqrt(1 - beta**T) / (1 - alpha**T) : R + + // Compute new value of "X". + X_new = X - R_adjusted * V_new / H_sqrt + + // Post-update regularization. + X_final = (1 - norm_coefficient_post) * X_new + + If there are multiple inputs to be optimized, the pseudo code will be applied + independently to each of them. + """ + + def __init__( + self, R, T, inputs, alpha=None, beta=None, epsilon=None, norm_coefficient=None, norm_coefficient_post=None + ): + super().__init__( + "Adam", + 1, + [{"at::kDouble", "at::kFloat"}, {"at::kLong"}, {"at::kDouble", "at::kFloat"}], + R, + T, + inputs, + alpha=ONNXAttr(alpha, AttrType.FLOAT), + beta=ONNXAttr(beta, AttrType.FLOAT), + epsilon=ONNXAttr(epsilon, AttrType.FLOAT), + norm_coefficient=ONNXAttr(norm_coefficient, AttrType.FLOAT), + norm_coefficient_post=ONNXAttr(norm_coefficient_post, AttrType.FLOAT), + ) + class Add(ONNXOp): - """ - Performs element-wise binary addition (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - - (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. - """ - - def __init__(self, A, B): - super().__init__('Add', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Performs element-wise binary addition (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + + (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. + """ + + def __init__(self, A, B): + super().__init__( + "Add", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class And(ONNXOp): - """ - Returns the tensor resulted from performing the `and` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('And', 1, - [{'at::kBool'}, {'at::kBool'}], - A,B) + """ + Returns the tensor resulted from performing the `and` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__("And", 1, [{"at::kBool"}, {"at::kBool"}], A, B) + class ArgMax(ONNXOp): - """ - Computes the indices of the max elements of the input tensor's element along the - provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. - If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. - If select_last_index is True (default False), the index of the last occurrence of the max - is selected if the max appears more than once in the input. Otherwise the index of the - first occurrence is selected. - The type of the output tensor is integer. - """ - - def __init__(self, data, - axis=None, - keepdims=None, - select_last_index=None): - super().__init__('ArgMax', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - data, - axis=ONNXAttr(axis, AttrType.INT), - keepdims=ONNXAttr(keepdims, AttrType.INT), - select_last_index=ONNXAttr(select_last_index, AttrType.INT)) + """ + Computes the indices of the max elements of the input tensor's element along the + provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. + If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. + If select_last_index is True (default False), the index of the last occurrence of the max + is selected if the max appears more than once in the input. Otherwise the index of the + first occurrence is selected. + The type of the output tensor is integer. + """ + + def __init__(self, data, axis=None, keepdims=None, select_last_index=None): + super().__init__( + "ArgMax", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + data, + axis=ONNXAttr(axis, AttrType.INT), + keepdims=ONNXAttr(keepdims, AttrType.INT), + select_last_index=ONNXAttr(select_last_index, AttrType.INT), + ) + class ArgMin(ONNXOp): - """ - Computes the indices of the min elements of the input tensor's element along the - provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. - If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. - If select_last_index is True (default False), the index of the last occurrence of the min - is selected if the min appears more than once in the input. Otherwise the index of the - first occurrence is selected. - The type of the output tensor is integer. - """ - - def __init__(self, data, - axis=None, - keepdims=None, - select_last_index=None): - super().__init__('ArgMin', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - data, - axis=ONNXAttr(axis, AttrType.INT), - keepdims=ONNXAttr(keepdims, AttrType.INT), - select_last_index=ONNXAttr(select_last_index, AttrType.INT)) + """ + Computes the indices of the min elements of the input tensor's element along the + provided axis. The resulting tensor has the same rank as the input if keepdims equal 1. + If keepdims equal 0, then the resulting tensor have the reduced dimension pruned. + If select_last_index is True (default False), the index of the last occurrence of the min + is selected if the min appears more than once in the input. Otherwise the index of the + first occurrence is selected. + The type of the output tensor is integer. + """ + + def __init__(self, data, axis=None, keepdims=None, select_last_index=None): + super().__init__( + "ArgMin", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + data, + axis=ONNXAttr(axis, AttrType.INT), + keepdims=ONNXAttr(keepdims, AttrType.INT), + select_last_index=ONNXAttr(select_last_index, AttrType.INT), + ) + class ArrayFeatureExtractor(ONNXOp): - """ - Select elements of the input tensor based on the indices passed.
- The indices are applied to the last axes of the tensor. - """ + """ + Select elements of the input tensor based on the indices passed.
+ The indices are applied to the last axes of the tensor. + """ + + def __init__(self, X, Y): + super().__init__( + "ArrayFeatureExtractor", 1, [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}, {"at::kLong"}], X, Y + ) - def __init__(self, X, Y): - super().__init__('ArrayFeatureExtractor', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}, {'at::kLong'}], - X,Y) class Asin(ONNXOp): - """ - Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. - """ + """ + Calculates the arcsine (inverse of sine) of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Asin", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Asin', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Asinh(ONNXOp): - """ - Calculates the hyperbolic arcsine of the given input tensor element-wise. - """ + """ + Calculates the hyperbolic arcsine of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Asinh", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Asinh', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Atan(ONNXOp): - """ - Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. - """ + """ + Calculates the arctangent (inverse of tangent) of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Atan", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Atan', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Atanh(ONNXOp): - """ - Calculates the hyperbolic arctangent of the given input tensor element-wise. - """ + """ + Calculates the hyperbolic arctangent of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Atanh", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Atanh', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class AveragePool(ONNXOp): - """ - AveragePool consumes an input tensor X and applies average pooling across - the tensor according to kernel sizes, stride sizes, and pad lengths. - average pooling consisting of computing the average on all values of a - subset of the input tensor according to the kernel size and downsampling the - data into the output tensor Y for further processing. The output spatial shape will be following: - ``` - output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) - ``` - or - ``` - output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) - ``` - if ceil_mode is enabled - - ``` - * pad_shape[i] is sum of pads along axis i - ``` - - `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: - ``` - VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - kernel_spatial_shape[i] + 1) / strides_spatial_shape[i]) - SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) - ``` - And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: - ``` - pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i] - ``` - The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). - - """ - - def __init__(self, X, - auto_pad=None, - ceil_mode=None, - count_include_pad=None, - kernel_shape=None, - pads=None, - strides=None): - super().__init__('AveragePool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - ceil_mode=ONNXAttr(ceil_mode, AttrType.INT), - count_include_pad=ONNXAttr(count_include_pad, AttrType.INT), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + AveragePool consumes an input tensor X and applies average pooling across + the tensor according to kernel sizes, stride sizes, and pad lengths. + average pooling consisting of computing the average on all values of a + subset of the input tensor according to the kernel size and downsampling the + data into the output tensor Y for further processing. The output spatial shape will be following: + ``` + output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) + ``` + or + ``` + output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - kernel_spatial_shape[i]) / strides_spatial_shape[i] + 1) + ``` + if ceil_mode is enabled + + ``` + * pad_shape[i] is sum of pads along axis i + ``` + + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: + ``` + VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - kernel_spatial_shape[i] + 1) / strides_spatial_shape[i]) + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) + ``` + And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: + ``` + pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + kernel_spatial_shape[i] - input_spatial_shape[i] + ``` + The output of each pooling window is divided by the number of elements (exclude pad when attribute count_include_pad is zero). + + """ + + def __init__( + self, X, auto_pad=None, ceil_mode=None, count_include_pad=None, kernel_shape=None, pads=None, strides=None + ): + super().__init__( + "AveragePool", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}], + X, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + ceil_mode=ONNXAttr(ceil_mode, AttrType.INT), + count_include_pad=ONNXAttr(count_include_pad, AttrType.INT), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class BatchNormalization(ONNXOp): - """ - Carries out batch normalization as described in the paper - https://arxiv.org/abs/1502.03167. Depending on the mode it is being run, - There are five required inputs 'X', 'scale', 'B', 'input_mean' and - 'input_var'. - Note that 'input_mean' and 'input_var' are expected to be the estimated - statistics in inference mode (training_mode=False, default), - and the running statistics in training mode (training_mode=True). - There are multiple cases for the number of outputs, which we list below: - - Output case #1: Y, running_mean, running_var (training_mode=True) - Output case #2: Y (training_mode=False) - - When training_mode=False, extra outputs are invalid. - The outputs are updated as follows when training_mode=True: - ``` - running_mean = input_mean * momentum + current_mean * (1 - momentum) - running_var = input_var * momentum + current_var * (1 - momentum) - - Y = (X - current_mean) / sqrt(current_var + epsilon) * scale + B - - where: - - current_mean = ReduceMean(X, axis=all_except_channel_index) - current_var = ReduceVar(X, axis=all_except_channel_index) - - Notice that ReduceVar refers to the population variance, and it equals to - sum(sqrd(x_i - x_avg)) / N - where N is the population size (this formula does not use sample size N - 1). - - ``` - - When training_mode=False: - ``` - Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B - ``` - - For previous (depreciated) non-spatial cases, implementors are suggested - to flatten the input shape to (N x C * D1 * D2 * ... * Dn) before a BatchNormalization Op. - This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. - """ - - def __init__(self, X, scale, B, input_mean, input_var, - epsilon=None, - momentum=None, - training_mode=None): - super().__init__('BatchNormalization', 3, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X,scale,B,input_mean,input_var, - epsilon=ONNXAttr(epsilon, AttrType.FLOAT), - momentum=ONNXAttr(momentum, AttrType.FLOAT), - training_mode=ONNXAttr(training_mode, AttrType.INT)) + """ + Carries out batch normalization as described in the paper + https://arxiv.org/abs/1502.03167. Depending on the mode it is being run, + There are five required inputs 'X', 'scale', 'B', 'input_mean' and + 'input_var'. + Note that 'input_mean' and 'input_var' are expected to be the estimated + statistics in inference mode (training_mode=False, default), + and the running statistics in training mode (training_mode=True). + There are multiple cases for the number of outputs, which we list below: + + Output case #1: Y, running_mean, running_var (training_mode=True) + Output case #2: Y (training_mode=False) + + When training_mode=False, extra outputs are invalid. + The outputs are updated as follows when training_mode=True: + ``` + running_mean = input_mean * momentum + current_mean * (1 - momentum) + running_var = input_var * momentum + current_var * (1 - momentum) + + Y = (X - current_mean) / sqrt(current_var + epsilon) * scale + B + + where: + + current_mean = ReduceMean(X, axis=all_except_channel_index) + current_var = ReduceVar(X, axis=all_except_channel_index) + + Notice that ReduceVar refers to the population variance, and it equals to + sum(sqrd(x_i - x_avg)) / N + where N is the population size (this formula does not use sample size N - 1). + + ``` + + When training_mode=False: + ``` + Y = (X - input_mean) / sqrt(input_var + epsilon) * scale + B + ``` + + For previous (depreciated) non-spatial cases, implementors are suggested + to flatten the input shape to (N x C * D1 * D2 * ... * Dn) before a BatchNormalization Op. + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + """ + + def __init__(self, X, scale, B, input_mean, input_var, epsilon=None, momentum=None, training_mode=None): + super().__init__( + "BatchNormalization", + 3, + [ + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + ], + X, + scale, + B, + input_mean, + input_var, + epsilon=ONNXAttr(epsilon, AttrType.FLOAT), + momentum=ONNXAttr(momentum, AttrType.FLOAT), + training_mode=ONNXAttr(training_mode, AttrType.INT), + ) + class Binarizer(ONNXOp): - """ - Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value. - """ + """ + Maps the values of the input tensor to either 0 or 1, element-wise, based on the outcome of a comparison against a threshold value. + """ + + def __init__(self, X, threshold=None): + super().__init__( + "Binarizer", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + threshold=ONNXAttr(threshold, AttrType.FLOAT), + ) - def __init__(self, X, - threshold=None): - super().__init__('Binarizer', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - threshold=ONNXAttr(threshold, AttrType.FLOAT)) class BitShift(ONNXOp): - """ - Bitwise shift operator performs element-wise operation. For each input element, if the - attribute "direction" is "RIGHT", this operator moves its binary representation toward - the right side so that the input value is effectively decreased. If the attribute "direction" - is "LEFT", bits of binary representation moves toward the left side, which results the - increase of its actual value. The input X is the tensor to be shifted and another input - Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4], - and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with - X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]. - - Because this operator supports Numpy-style broadcasting, X's and Y's shapes are - not necessarily identical. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, X, Y, - direction=None): - super().__init__('BitShift', 1, - [set(), set()], - X,Y, - direction=ONNXAttr(direction, AttrType.STRING)) + """ + Bitwise shift operator performs element-wise operation. For each input element, if the + attribute "direction" is "RIGHT", this operator moves its binary representation toward + the right side so that the input value is effectively decreased. If the attribute "direction" + is "LEFT", bits of binary representation moves toward the left side, which results the + increase of its actual value. The input X is the tensor to be shifted and another input + Y specifies the amounts of shifting. For example, if "direction" is "Right", X is [1, 4], + and S is [1, 1], the corresponding output Z would be [0, 2]. If "direction" is "LEFT" with + X=[1, 2] and S=[1, 2], the corresponding output Y would be [2, 8]. + + Because this operator supports Numpy-style broadcasting, X's and Y's shapes are + not necessarily identical. + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, X, Y, direction=None): + super().__init__("BitShift", 1, [set(), set()], X, Y, direction=ONNXAttr(direction, AttrType.STRING)) + class Cast(ONNXOp): - """ - The operator casts the elements of a given input tensor to a data type - specified by the 'to' argument and returns an output tensor of the same size in - the converted type. The 'to' argument must be one of the data types specified - in the 'DataType' enum field in the TensorProto message. - - Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations - (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may - result 100. There are some string literals reserved for special floating-point values; - "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively. - Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly, - this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors - to string tensors, plain floating-point representation (such as "314.15926") would be used. - Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases - of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior. - - Conversion from a numerical type to any numerical type is always allowed. - User must be aware of precision loss and value change caused by range difference between two types. - For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting - an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type. - """ - - def __init__(self, input, - to=None): - super().__init__('Cast', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - input, - to=ONNXAttr(to, AttrType.INT)) + """ + The operator casts the elements of a given input tensor to a data type + specified by the 'to' argument and returns an output tensor of the same size in + the converted type. The 'to' argument must be one of the data types specified + in the 'DataType' enum field in the TensorProto message. + + Casting from string tensor in plain (e.g., "3.14" and "1000") and scientific numeric representations + (e.g., "1e-5" and "1E8") to float types is supported. For example, converting string "100.5" to an integer may + result 100. There are some string literals reserved for special floating-point values; + "+INF" (and "INF"), "-INF", and "NaN" are positive infinity, negative infinity, and not-a-number, respectively. + Any string which can exactly match "+INF" in a case-insensitive way would be mapped to positive infinite. Similarly, + this case-insensitive rule is applied to "INF" and "NaN". When casting from numeric tensors + to string tensors, plain floating-point representation (such as "314.15926") would be used. + Converting non-numerical-literal string such as "Hello World!" is an undefined behavior. Cases + of converting string representing floating-point arithmetic value, such as "2.718", to INT is an undefined behavior. + + Conversion from a numerical type to any numerical type is always allowed. + User must be aware of precision loss and value change caused by range difference between two types. + For example, a 64-bit float 3.1415926459 may be round to a 32-bit float 3.141592. Similarly, converting + an integer 36 to Boolean may produce 1 because we truncate bits which can't be stored in the targeted type. + """ + + def __init__(self, input, to=None): + super().__init__( + "Cast", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + to=ONNXAttr(to, AttrType.INT), + ) + class CastMap(ONNXOp): - """ - Converts a map to a tensor.
The map key must be an int64 and the values will be ordered - in ascending order based on this key.
The operator supports dense packing or sparse packing. - If using sparse packing, the key cannot exceed the max_map-1 value. - """ - - def __init__(self, X, - cast_to=None, - map_form=None, - max_map=None): - super().__init__('CastMap', 1, - [set()], - X, - cast_to=ONNXAttr(cast_to, AttrType.STRING), - map_form=ONNXAttr(map_form, AttrType.STRING), - max_map=ONNXAttr(max_map, AttrType.INT)) + """ + Converts a map to a tensor.
The map key must be an int64 and the values will be ordered + in ascending order based on this key.
The operator supports dense packing or sparse packing. + If using sparse packing, the key cannot exceed the max_map-1 value. + """ + + def __init__(self, X, cast_to=None, map_form=None, max_map=None): + super().__init__( + "CastMap", + 1, + [set()], + X, + cast_to=ONNXAttr(cast_to, AttrType.STRING), + map_form=ONNXAttr(map_form, AttrType.STRING), + max_map=ONNXAttr(max_map, AttrType.INT), + ) + class CategoryMapper(ONNXOp): - """ - Converts strings to integers and vice versa.
- Two sequences of equal length are used to map between integers and strings, - with strings and integers at the same index detailing the mapping.
- Each operator converts either integers to strings or strings to integers, depending - on which default value attribute is provided. Only one default value attribute - should be defined.
- If the string default value is set, it will convert integers to strings. - If the int default value is set, it will convert strings to integers. - """ - - def __init__(self, X, - cats_int64s=None, - cats_strings=None, - default_int64=None, - default_string=None): - super().__init__('CategoryMapper', 1, - [{'at::kLong'}], - X, - cats_int64s=ONNXAttr(cats_int64s, AttrType.INTS), - cats_strings=ONNXAttr(cats_strings, AttrType.STRINGS), - default_int64=ONNXAttr(default_int64, AttrType.INT), - default_string=ONNXAttr(default_string, AttrType.STRING)) + """ + Converts strings to integers and vice versa.
+ Two sequences of equal length are used to map between integers and strings, + with strings and integers at the same index detailing the mapping.
+ Each operator converts either integers to strings or strings to integers, depending + on which default value attribute is provided. Only one default value attribute + should be defined.
+ If the string default value is set, it will convert integers to strings. + If the int default value is set, it will convert strings to integers. + """ + + def __init__(self, X, cats_int64s=None, cats_strings=None, default_int64=None, default_string=None): + super().__init__( + "CategoryMapper", + 1, + [{"at::kLong"}], + X, + cats_int64s=ONNXAttr(cats_int64s, AttrType.INTS), + cats_strings=ONNXAttr(cats_strings, AttrType.STRINGS), + default_int64=ONNXAttr(default_int64, AttrType.INT), + default_string=ONNXAttr(default_string, AttrType.STRING), + ) + class Ceil(ONNXOp): - """ - Ceil takes one input data (Tensor) and produces one output data - (Tensor) where the ceil is, y = ceil(x), is applied to - the tensor elementwise. - """ + """ + Ceil takes one input data (Tensor) and produces one output data + (Tensor) where the ceil is, y = ceil(x), is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__("Ceil", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('Ceil', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X) class Celu(ONNXOp): - """ - Continuously Differentiable Exponential Linear Units: - Perform the linear unit element-wise on the input tensor X - using formula: - - ``` - max(0,x) + min(0,alpha*(exp(x/alpha)-1)) - ``` - """ - - def __init__(self, X, - alpha=None): - super().__init__('Celu', 1, - [{'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT)) + """ + Continuously Differentiable Exponential Linear Units: + Perform the linear unit element-wise on the input tensor X + using formula: + + ``` + max(0,x) + min(0,alpha*(exp(x/alpha)-1)) + ``` + """ + + def __init__(self, X, alpha=None): + super().__init__("Celu", 1, [{"at::kFloat"}], X, alpha=ONNXAttr(alpha, AttrType.FLOAT)) + class Clip(ONNXOp): - """ - Clip operator limits the given input within an interval. The interval is - specified by the inputs 'min' and 'max'. They default to - numeric_limits::lowest() and numeric_limits::max(), respectively. - """ + """ + Clip operator limits the given input within an interval. The interval is + specified by the inputs 'min' and 'max'. They default to + numeric_limits::lowest() and numeric_limits::max(), respectively. + """ + + def __init__(self, input, min, max): + super().__init__( + "Clip", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + input, + min, + max, + ) - def __init__(self, input, min, max): - super().__init__('Clip', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - input,min,max) class Compress(ONNXOp): - """ - Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index. - In case axis is not provided, input is flattened before elements are selected. - Compress behaves like numpy.compress: https://docs.scipy.org/doc/numpy/reference/generated/numpy.compress.html - - """ - - def __init__(self, input, condition, - axis=None): - super().__init__('Compress', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kBool'}], - input,condition, - axis=ONNXAttr(axis, AttrType.INT)) + """ + Selects slices from an input tensor along a given axis where condition evaluates to True for each axis index. + In case axis is not provided, input is flattened before elements are selected. + Compress behaves like numpy.compress: https://docs.scipy.org/doc/numpy/reference/generated/numpy.compress.html + + """ + + def __init__(self, input, condition, axis=None): + super().__init__( + "Compress", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + {"at::kBool"}, + ], + input, + condition, + axis=ONNXAttr(axis, AttrType.INT), + ) + class Concat(ONNXOp): - """ - Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. - """ + """ + Concatenate a list of tensors into a single tensor. All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. + """ + + def __init__(self, inputs, axis=None): + super().__init__( + "Concat", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + inputs, + axis=ONNXAttr(axis, AttrType.INT), + ) - def __init__(self, inputs, - axis=None): - super().__init__('Concat', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - inputs, - axis=ONNXAttr(axis, AttrType.INT)) class ConcatFromSequence(ONNXOp): - """ - Concatenate a sequence of tensors into a single tensor. - All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. - By default 'new_axis' is 0, the behavior is similar to numpy.concatenate. - When 'new_axis' is 1, the behavior is similar to numpy.stack. - """ - - def __init__(self, input_sequence, - axis=None, - new_axis=None): - super().__init__('ConcatFromSequence', 1, - [set()], - input_sequence, - axis=ONNXAttr(axis, AttrType.INT), - new_axis=ONNXAttr(new_axis, AttrType.INT)) + """ + Concatenate a sequence of tensors into a single tensor. + All input tensors must have the same shape, except for the dimension size of the axis to concatenate on. + By default 'new_axis' is 0, the behavior is similar to numpy.concatenate. + When 'new_axis' is 1, the behavior is similar to numpy.stack. + """ + + def __init__(self, input_sequence, axis=None, new_axis=None): + super().__init__( + "ConcatFromSequence", + 1, + [set()], + input_sequence, + axis=ONNXAttr(axis, AttrType.INT), + new_axis=ONNXAttr(new_axis, AttrType.INT), + ) + class Constant(ONNXOp): - """ - This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value, - or value_* must be specified. - """ - - def __init__(self, - sparse_value=None, - value=None, - value_float=None, - value_floats=None, - value_int=None, - value_ints=None, - value_string=None, - value_strings=None): - super().__init__('Constant', 1, - [], - sparse_value=ONNXAttr(sparse_value, AttrType.SPARSE_TENSOR), - value=ONNXAttr(value, AttrType.TENSOR), - value_float=ONNXAttr(value_float, AttrType.FLOAT), - value_floats=ONNXAttr(value_floats, AttrType.FLOATS), - value_int=ONNXAttr(value_int, AttrType.INT), - value_ints=ONNXAttr(value_ints, AttrType.INTS), - value_string=ONNXAttr(value_string, AttrType.STRING), - value_strings=ONNXAttr(value_strings, AttrType.STRINGS)) + """ + This operator produces a constant tensor. Exactly one of the provided attributes, either value, sparse_value, + or value_* must be specified. + """ + + def __init__( + self, + sparse_value=None, + value=None, + value_float=None, + value_floats=None, + value_int=None, + value_ints=None, + value_string=None, + value_strings=None, + ): + super().__init__( + "Constant", + 1, + [], + sparse_value=ONNXAttr(sparse_value, AttrType.SPARSE_TENSOR), + value=ONNXAttr(value, AttrType.TENSOR), + value_float=ONNXAttr(value_float, AttrType.FLOAT), + value_floats=ONNXAttr(value_floats, AttrType.FLOATS), + value_int=ONNXAttr(value_int, AttrType.INT), + value_ints=ONNXAttr(value_ints, AttrType.INTS), + value_string=ONNXAttr(value_string, AttrType.STRING), + value_strings=ONNXAttr(value_strings, AttrType.STRINGS), + ) + class ConstantOfShape(ONNXOp): - """ - Generate a tensor with given value and shape. - """ + """ + Generate a tensor with given value and shape. + """ + + def __init__(self, input, value=None): + super().__init__("ConstantOfShape", 1, [{"at::kLong"}], input, value=ONNXAttr(value, AttrType.TENSOR)) - def __init__(self, input, - value=None): - super().__init__('ConstantOfShape', 1, - [{'at::kLong'}], - input, - value=ONNXAttr(value, AttrType.TENSOR)) class Conv(ONNXOp): - """ - The convolution operator consumes an input tensor and a filter, and - computes the output. - """ - - def __init__(self, X, W, B, - auto_pad=None, - dilations=None, - group=None, - kernel_shape=None, - pads=None, - strides=None): - super().__init__('Conv', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X,W,B, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - dilations=ONNXAttr(dilations, AttrType.INTS), - group=ONNXAttr(group, AttrType.INT), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + The convolution operator consumes an input tensor and a filter, and + computes the output. + """ + + def __init__(self, X, W, B, auto_pad=None, dilations=None, group=None, kernel_shape=None, pads=None, strides=None): + super().__init__( + "Conv", + 1, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + X, + W, + B, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + dilations=ONNXAttr(dilations, AttrType.INTS), + group=ONNXAttr(group, AttrType.INT), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class ConvInteger(ONNXOp): - """ - The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point, - and computes the output. The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. - """ - - def __init__(self, x, w, x_zero_point, w_zero_point, - auto_pad=None, - dilations=None, - group=None, - kernel_shape=None, - pads=None, - strides=None): - super().__init__('ConvInteger', 1, - [{'at::kByte'}, {'at::kByte'}, {'at::kByte'}, {'at::kByte'}], - x,w,x_zero_point,w_zero_point, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - dilations=ONNXAttr(dilations, AttrType.INTS), - group=ONNXAttr(group, AttrType.INT), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + The integer convolution operator consumes an input tensor, its zero-point, a filter, and its zero-point, + and computes the output. The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. + """ + + def __init__( + self, + x, + w, + x_zero_point, + w_zero_point, + auto_pad=None, + dilations=None, + group=None, + kernel_shape=None, + pads=None, + strides=None, + ): + super().__init__( + "ConvInteger", + 1, + [{"at::kByte"}, {"at::kByte"}, {"at::kByte"}, {"at::kByte"}], + x, + w, + x_zero_point, + w_zero_point, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + dilations=ONNXAttr(dilations, AttrType.INTS), + group=ONNXAttr(group, AttrType.INT), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class ConvTranspose(ONNXOp): - """ - The convolution transpose operator consumes an input tensor and a filter, - and computes the output. - - If the pads parameter is provided the shape of the output is calculated via the following equation: - - output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i] - - output_shape can also be explicitly specified in which case pads values are auto generated using these equations: - - total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i] - If (auto_pads != SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2) - Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2). - - - """ - - def __init__(self, X, W, B, - auto_pad=None, - dilations=None, - group=None, - kernel_shape=None, - output_padding=None, - output_shape=None, - pads=None, - strides=None): - super().__init__('ConvTranspose', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X,W,B, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - dilations=ONNXAttr(dilations, AttrType.INTS), - group=ONNXAttr(group, AttrType.INT), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - output_padding=ONNXAttr(output_padding, AttrType.INTS), - output_shape=ONNXAttr(output_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + The convolution transpose operator consumes an input tensor and a filter, + and computes the output. + + If the pads parameter is provided the shape of the output is calculated via the following equation: + + output_shape[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - pads[start_i] - pads[end_i] + + output_shape can also be explicitly specified in which case pads values are auto generated using these equations: + + total_padding[i] = stride[i] * (input_size[i] - 1) + output_padding[i] + ((kernel_shape[i] - 1) * dilations[i] + 1) - output_shape[i] + If (auto_pads != SAME_UPPER): pads[start_i] = total_padding[i]/2; pads[end_i] = total_padding[i] - (total_padding[i]/2) + Else: pads[start_i] = total_padding[i] - (total_padding[i]/2); pads[end_i] = (total_padding[i]/2). + + + """ + + def __init__( + self, + X, + W, + B, + auto_pad=None, + dilations=None, + group=None, + kernel_shape=None, + output_padding=None, + output_shape=None, + pads=None, + strides=None, + ): + super().__init__( + "ConvTranspose", + 1, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + X, + W, + B, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + dilations=ONNXAttr(dilations, AttrType.INTS), + group=ONNXAttr(group, AttrType.INT), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + output_padding=ONNXAttr(output_padding, AttrType.INTS), + output_shape=ONNXAttr(output_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class Cos(ONNXOp): - """ - Calculates the cosine of the given input tensor, element-wise. - """ + """ + Calculates the cosine of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Cos", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Cos', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Cosh(ONNXOp): - """ - Calculates the hyperbolic cosine of the given input tensor element-wise. - """ + """ + Calculates the hyperbolic cosine of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Cosh", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Cosh', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class CumSum(ONNXOp): - """ - Performs cumulative sum of the input elements along the given axis. - By default, it will do the sum inclusively meaning the first element is copied as is. - Through an `exclusive` attribute, this behavior can change to exclude the first element. - It can also perform summation in the opposite direction of the axis. For that, set `reverse` attribute to 1. - - Example: - ``` - input_x = [1, 2, 3] - axis=0 - output = [1, 3, 6] - exclusive=1 - output = [0, 1, 3] - exclusive=0 - reverse=1 - output = [6, 5, 3] - exclusive=1 - reverse=1 - output = [5, 3, 0] - ``` - - """ - - def __init__(self, x, axis, - exclusive=None, - reverse=None): - super().__init__('CumSum', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong', 'at::kInt'}], - x,axis, - exclusive=ONNXAttr(exclusive, AttrType.INT), - reverse=ONNXAttr(reverse, AttrType.INT)) + """ + Performs cumulative sum of the input elements along the given axis. + By default, it will do the sum inclusively meaning the first element is copied as is. + Through an `exclusive` attribute, this behavior can change to exclude the first element. + It can also perform summation in the opposite direction of the axis. For that, set `reverse` attribute to 1. + + Example: + ``` + input_x = [1, 2, 3] + axis=0 + output = [1, 3, 6] + exclusive=1 + output = [0, 1, 3] + exclusive=0 + reverse=1 + output = [6, 5, 3] + exclusive=1 + reverse=1 + output = [5, 3, 0] + ``` + + """ + + def __init__(self, x, axis, exclusive=None, reverse=None): + super().__init__( + "CumSum", + 1, + [ + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + {"at::kLong", "at::kInt"}, + ], + x, + axis, + exclusive=ONNXAttr(exclusive, AttrType.INT), + reverse=ONNXAttr(reverse, AttrType.INT), + ) + class DepthToSpace(ONNXOp): - """ - DepthToSpace rearranges (permutes) data from depth into blocks of spatial data. - This is the reverse transformation of SpaceToDepth. More specifically, this op outputs a copy of - the input tensor where values from the depth dimension are moved in spatial blocks to the height - and width dimensions. By default, `mode` = `DCR`. - In the DCR mode, elements along the depth dimension from the input tensor are rearranged in the - following order: depth, column, and then row. The output y is computed from the input x as below: - - b, c, h, w = x.shape - - tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) - - tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) - - y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) - - - In the CRD mode, elements along the depth dimension from the input tensor are rearranged in the - following order: column, row, and the depth. The output y is computed from the input x as below: - - b, c, h, w = x.shape - - tmp = np.reshape(x, [b, c // (blocksize ** 2), blocksize, blocksize, h, w]) - - tmp = np.transpose(tmp, [0, 1, 4, 2, 5, 3]) - - y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize]) - """ - - def __init__(self, input, - blocksize=None, - mode=None): - super().__init__('DepthToSpace', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - input, - blocksize=ONNXAttr(blocksize, AttrType.INT), - mode=ONNXAttr(mode, AttrType.STRING)) + """ + DepthToSpace rearranges (permutes) data from depth into blocks of spatial data. + This is the reverse transformation of SpaceToDepth. More specifically, this op outputs a copy of + the input tensor where values from the depth dimension are moved in spatial blocks to the height + and width dimensions. By default, `mode` = `DCR`. + In the DCR mode, elements along the depth dimension from the input tensor are rearranged in the + following order: depth, column, and then row. The output y is computed from the input x as below: + + b, c, h, w = x.shape + + tmp = np.reshape(x, [b, blocksize, blocksize, c // (blocksize**2), h, w]) + + tmp = np.transpose(tmp, [0, 3, 4, 1, 5, 2]) + + y = np.reshape(tmp, [b, c // (blocksize**2), h * blocksize, w * blocksize]) + + + In the CRD mode, elements along the depth dimension from the input tensor are rearranged in the + following order: column, row, and the depth. The output y is computed from the input x as below: + + b, c, h, w = x.shape + + tmp = np.reshape(x, [b, c // (blocksize ** 2), blocksize, blocksize, h, w]) + + tmp = np.transpose(tmp, [0, 1, 4, 2, 5, 3]) + + y = np.reshape(tmp, [b, c // (blocksize ** 2), h * blocksize, w * blocksize]) + """ + + def __init__(self, input, blocksize=None, mode=None): + super().__init__( + "DepthToSpace", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + blocksize=ONNXAttr(blocksize, AttrType.INT), + mode=ONNXAttr(mode, AttrType.STRING), + ) + class DequantizeLinear(ONNXOp): - """ - The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor. - The dequantization formula is y = (x - x_zero_point) * x_scale. 'x_scale' and 'x_zero_point' must have same shape, and can be either a scalar - for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantizations. - 'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape. In the case of dequantizing int32, - there's no zero point (zero point is supposed to be 0). - """ - - def __init__(self, x, x_scale, x_zero_point, - axis=None): - super().__init__('DequantizeLinear', 1, - [{'at::kByte', 'at::kInt'}, {'at::kFloat'}, {'at::kByte', 'at::kInt'}], - x,x_scale,x_zero_point, - axis=ONNXAttr(axis, AttrType.INT)) + """ + The linear dequantization operator. It consumes a quantized tensor, a scale, and a zero point to compute the full precision tensor. + The dequantization formula is y = (x - x_zero_point) * x_scale. 'x_scale' and 'x_zero_point' must have same shape, and can be either a scalar + for per-tensor / per layer quantization, or a 1-D tensor for per-axis quantizations. + 'x_zero_point' and 'x' must have same type. 'x' and 'y' must have same shape. In the case of dequantizing int32, + there's no zero point (zero point is supposed to be 0). + """ + + def __init__(self, x, x_scale, x_zero_point, axis=None): + super().__init__( + "DequantizeLinear", + 1, + [{"at::kByte", "at::kInt"}, {"at::kFloat"}, {"at::kByte", "at::kInt"}], + x, + x_scale, + x_zero_point, + axis=ONNXAttr(axis, AttrType.INT), + ) + class Det(ONNXOp): - """ - Det calculates determinant of a square matrix or batches of square matrices. - Det takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions, - and the inner-most 2 dimensions form square matrices. - The output is a tensor of shape `[*]`, containing the determinants of all input submatrices. - e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`). - """ - - def __init__(self, X): - super().__init__('Det', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X) + """ + Det calculates determinant of a square matrix or batches of square matrices. + Det takes one input tensor of shape `[*, M, M]`, where `*` is zero or more batch dimensions, + and the inner-most 2 dimensions form square matrices. + The output is a tensor of shape `[*]`, containing the determinants of all input submatrices. + e.g., When the input is 2-D, the output is a scalar(shape is empty: `[]`). + """ + + def __init__(self, X): + super().__init__("Det", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X) + class DictVectorizer(ONNXOp): - """ - Uses an index mapping to convert a dictionary to an array.
- Given a dictionary, each key is looked up in the vocabulary attribute corresponding to - the key type. The index into the vocabulary array at which the key is found is then - used to index the output 1-D tensor 'Y' and insert into it the value found in the dictionary 'X'.
- The key type of the input map must correspond to the element type of the defined vocabulary attribute. - Therefore, the output array will be equal in length to the index mapping vector parameter. - All keys in the input dictionary must be present in the index mapping vector. - For each item in the input dictionary, insert its value in the output array. - Any keys not present in the input dictionary, will be zero in the output array.
- For example: if the ``string_vocabulary`` parameter is set to ``["a", "c", "b", "z"]``, - then an input of ``{"a": 4, "c": 8}`` will produce an output of ``[4, 8, 0, 0]``. - - """ - - def __init__(self, X, - int64_vocabulary=None, - string_vocabulary=None): - super().__init__('DictVectorizer', 1, - [set()], - X, - int64_vocabulary=ONNXAttr(int64_vocabulary, AttrType.INTS), - string_vocabulary=ONNXAttr(string_vocabulary, AttrType.STRINGS)) + """ + Uses an index mapping to convert a dictionary to an array.
+ Given a dictionary, each key is looked up in the vocabulary attribute corresponding to + the key type. The index into the vocabulary array at which the key is found is then + used to index the output 1-D tensor 'Y' and insert into it the value found in the dictionary 'X'.
+ The key type of the input map must correspond to the element type of the defined vocabulary attribute. + Therefore, the output array will be equal in length to the index mapping vector parameter. + All keys in the input dictionary must be present in the index mapping vector. + For each item in the input dictionary, insert its value in the output array. + Any keys not present in the input dictionary, will be zero in the output array.
+ For example: if the ``string_vocabulary`` parameter is set to ``["a", "c", "b", "z"]``, + then an input of ``{"a": 4, "c": 8}`` will produce an output of ``[4, 8, 0, 0]``. + + """ + + def __init__(self, X, int64_vocabulary=None, string_vocabulary=None): + super().__init__( + "DictVectorizer", + 1, + [set()], + X, + int64_vocabulary=ONNXAttr(int64_vocabulary, AttrType.INTS), + string_vocabulary=ONNXAttr(string_vocabulary, AttrType.STRINGS), + ) + class Div(ONNXOp): - """ - Performs element-wise binary division (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - - (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. - """ - - def __init__(self, A, B): - super().__init__('Div', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Performs element-wise binary division (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + + (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. + """ + + def __init__(self, A, B): + super().__init__( + "Div", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class Dropout(ONNXOp): - """ - Dropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). It produces two tensor outputs, - output (floating-point tensor) and mask (optional `Tensor`). If `training_mode` is true then the output Y will be a random dropout; - Note that this Dropout scales the masked input data by the following equation, so to convert the trained model into inference mode, - the user can simply not pass `training_mode` input or set it to false. - ``` - output = scale * data * mask, - ``` - where - ``` - scale = 1. / (1. - ratio). - ``` - This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. - """ - - def __init__(self, data, ratio, training_mode, - seed=None): - super().__init__('Dropout', 2, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kBool'}], - data,ratio,training_mode, - seed=ONNXAttr(seed, AttrType.INT)) + """ + Dropout takes an input floating-point tensor, an optional input ratio (floating-point scalar) and an optional input training_mode (boolean scalar). It produces two tensor outputs, + output (floating-point tensor) and mask (optional `Tensor`). If `training_mode` is true then the output Y will be a random dropout; + Note that this Dropout scales the masked input data by the following equation, so to convert the trained model into inference mode, + the user can simply not pass `training_mode` input or set it to false. + ``` + output = scale * data * mask, + ``` + where + ``` + scale = 1. / (1. - ratio). + ``` + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + """ + + def __init__(self, data, ratio, training_mode, seed=None): + super().__init__( + "Dropout", + 2, + [ + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kBool"}, + ], + data, + ratio, + training_mode, + seed=ONNXAttr(seed, AttrType.INT), + ) + class DynamicQuantizeLinear(ONNXOp): - """ - A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data. - Outputs Scale, ZeroPoint and Quantized Input for a given FP32 Input. - Scale is calculated as: - ``` - y_scale = (max(x) - min(x))/(qmax - qmin) - * where qmax and qmin are max and min values for quantization range .i.e [0, 255] in case of uint8 - * data range is adjusted to include 0. - ``` - Zero point is calculated as: - ``` - intermediate_zero_point = qmin - min(x)/y_scale - y_zero_point = cast(round(saturate(itermediate_zero_point))) - * where qmax and qmin are max and min values for quantization range .i.e [0, 255] in case of uint8 - * for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported. - * rounding to nearest ties to even. - ``` - Data quantization formula is: - ``` - y = saturate (round (x / y_scale) + y_zero_point) - * for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported. - * rounding to nearest ties to even. - ``` - """ - - def __init__(self, x): - super().__init__('DynamicQuantizeLinear', 3, - [{'at::kFloat'}], - x) + """ + A Function to fuse calculation for Scale, Zero Point and FP32->8Bit convertion of FP32 Input data. + Outputs Scale, ZeroPoint and Quantized Input for a given FP32 Input. + Scale is calculated as: + ``` + y_scale = (max(x) - min(x))/(qmax - qmin) + * where qmax and qmin are max and min values for quantization range .i.e [0, 255] in case of uint8 + * data range is adjusted to include 0. + ``` + Zero point is calculated as: + ``` + intermediate_zero_point = qmin - min(x)/y_scale + y_zero_point = cast(round(saturate(itermediate_zero_point))) + * where qmax and qmin are max and min values for quantization range .i.e [0, 255] in case of uint8 + * for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported. + * rounding to nearest ties to even. + ``` + Data quantization formula is: + ``` + y = saturate (round (x / y_scale) + y_zero_point) + * for saturation, it saturates to [0, 255] if it's uint8, or [-127, 127] if it's int8. Right now only uint8 is supported. + * rounding to nearest ties to even. + ``` + """ + + def __init__(self, x): + super().__init__("DynamicQuantizeLinear", 3, [{"at::kFloat"}], x) + class Einsum(ONNXOp): - """ - An einsum of the form ```term1, term2 -> output-term``` produces an output tensor using the following equation - - ```output[output-term] = reduce-sum( input1[term1] * input2[term] )``` - - where the reduce-sum performs a summation over all the indices occurring in the input terms (term1, term2) - that do not occur in the output-term. - - The Einsum operator evaluates algebraic tensor operations on a sequence of tensors, using the Einstein summation - convention. The equation string contains a comma-separated sequence of lower case letters. Each term corresponds to - an operand tensor, and the characters within the terms correspond to operands dimensions. - - This sequence may be followed by "->" to separate the left and right hand side of the equation. - If the equation contains "->" followed by the right-hand side, the explicit (not classical) form of the Einstein - summation is performed, and the right-hand side indices indicate output tensor dimensions. In other cases, - output indices are (implicitly) set to the alphabetically sorted sequence of indices appearing exactly once in the - equation. - - When a dimension character is repeated in the left-hand side, it represents summation along the dimension. - - The equation may contain ellipsis ("...") to enable broadcasting. Ellipsis must indicate a fixed number of dimensions. - Specifically, every occurrence of ellipsis in the equation must represent the same number of dimensions. - The right-hand side may contain exactly one ellipsis. In implicit mode, the ellipsis dimensions are set to the - beginning of the output. The equation string may contain space (U+0020) character. - """ - - def __init__(self, Inputs, - equation=None): - super().__init__('Einsum', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}], - Inputs, - equation=ONNXAttr(equation, AttrType.STRING)) + """ + An einsum of the form ```term1, term2 -> output-term``` produces an output tensor using the following equation + + ```output[output-term] = reduce-sum( input1[term1] * input2[term] )``` + + where the reduce-sum performs a summation over all the indices occurring in the input terms (term1, term2) + that do not occur in the output-term. + + The Einsum operator evaluates algebraic tensor operations on a sequence of tensors, using the Einstein summation + convention. The equation string contains a comma-separated sequence of lower case letters. Each term corresponds to + an operand tensor, and the characters within the terms correspond to operands dimensions. + + This sequence may be followed by "->" to separate the left and right hand side of the equation. + If the equation contains "->" followed by the right-hand side, the explicit (not classical) form of the Einstein + summation is performed, and the right-hand side indices indicate output tensor dimensions. In other cases, + output indices are (implicitly) set to the alphabetically sorted sequence of indices appearing exactly once in the + equation. + + When a dimension character is repeated in the left-hand side, it represents summation along the dimension. + + The equation may contain ellipsis ("...") to enable broadcasting. Ellipsis must indicate a fixed number of dimensions. + Specifically, every occurrence of ellipsis in the equation must represent the same number of dimensions. + The right-hand side may contain exactly one ellipsis. In implicit mode, the ellipsis dimensions are set to the + beginning of the output. The equation string may contain space (U+0020) character. + """ + + def __init__(self, Inputs, equation=None): + super().__init__( + "Einsum", + 1, + [{"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}], + Inputs, + equation=ONNXAttr(equation, AttrType.STRING), + ) + class Elu(ONNXOp): - """ - Elu takes one input data (Tensor) and produces one output data - (Tensor) where the function `f(x) = alpha * (exp(x) - 1.) for x < - 0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise. - """ - - def __init__(self, X, - alpha=None): - super().__init__('Elu', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT)) + """ + Elu takes one input data (Tensor) and produces one output data + (Tensor) where the function `f(x) = alpha * (exp(x) - 1.) for x < + 0`, `f(x) = x for x >= 0`., is applied to the tensor elementwise. + """ + + def __init__(self, X, alpha=None): + super().__init__( + "Elu", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X, alpha=ONNXAttr(alpha, AttrType.FLOAT) + ) + class Equal(ONNXOp): - """ - Returns the tensor resulted from performing the `equal` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('Equal', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Returns the tensor resulted from performing the `equal` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__( + "Equal", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class Erf(ONNXOp): - """ - Computes the error function of the given input tensor element-wise. - """ + """ + Computes the error function of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__( + "Erf", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + ) - def __init__(self, input): - super().__init__('Erf', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - input) class Exp(ONNXOp): - """ - Calculates the exponential of the given input tensor, element-wise. - """ + """ + Calculates the exponential of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Exp", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Exp', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - input) class Expand(ONNXOp): - """ - Broadcast the input tensor following the given shape and the broadcast rule. - The broadcast rule is similar to numpy.array(input) * numpy.ones(shape): - Dimensions are right alignment; - Two corresponding dimension must have the same value, or one of them is equal to 1. - Also, this operator is similar to numpy.broadcast_to(input, shape), - but the major difference is numpy.broadcast_to() does not allow shape to be smaller than input.size(). - It is possible that the output.shape is not equal to shape, when some dimensions in shape is equal to 1, - or the shape.ndim < input.shape.ndim. - """ - - def __init__(self, input, shape): - super().__init__('Expand', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - input,shape) + """ + Broadcast the input tensor following the given shape and the broadcast rule. + The broadcast rule is similar to numpy.array(input) * numpy.ones(shape): + Dimensions are right alignment; + Two corresponding dimension must have the same value, or one of them is equal to 1. + Also, this operator is similar to numpy.broadcast_to(input, shape), + but the major difference is numpy.broadcast_to() does not allow shape to be smaller than input.size(). + It is possible that the output.shape is not equal to shape, when some dimensions in shape is equal to 1, + or the shape.ndim < input.shape.ndim. + """ + + def __init__(self, input, shape): + super().__init__( + "Expand", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + input, + shape, + ) + class EyeLike(ONNXOp): - """ - Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D - tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the - same as the input tensor. The data type can be specified by the 'dtype' argument. If - 'dtype' is not specified, then the type of input tensor is used. By default, the main diagonal - is populated with ones, but attribute 'k' can be used to populate upper or lower diagonals. - The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the - TensorProto message and be valid as an output type. - """ - - def __init__(self, input, - dtype=None, - k=None): - super().__init__('EyeLike', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - input, - dtype=ONNXAttr(dtype, AttrType.INT), - k=ONNXAttr(k, AttrType.INT)) + """ + Generate a 2D tensor (matrix) with ones on the diagonal and zeros everywhere else. Only 2D + tensors are supported, i.e. input T1 must be of rank 2. The shape of the output tensor is the + same as the input tensor. The data type can be specified by the 'dtype' argument. If + 'dtype' is not specified, then the type of input tensor is used. By default, the main diagonal + is populated with ones, but attribute 'k' can be used to populate upper or lower diagonals. + The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the + TensorProto message and be valid as an output type. + """ + + def __init__(self, input, dtype=None, k=None): + super().__init__( + "EyeLike", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + input, + dtype=ONNXAttr(dtype, AttrType.INT), + k=ONNXAttr(k, AttrType.INT), + ) + class FeatureVectorizer(ONNXOp): - """ - Concatenates input tensors into one continuous output.
- All input shapes are 2-D and are concatenated along the second dimention. 1-D tensors are treated as [1,C]. - Inputs are copied to the output maintaining the order of the input arguments.
- All inputs must be integers or floats, while the output will be all floating point values. - """ - - def __init__(self, X, - inputdimensions=None): - super().__init__('FeatureVectorizer', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - inputdimensions=ONNXAttr(inputdimensions, AttrType.INTS)) + """ + Concatenates input tensors into one continuous output.
+ All input shapes are 2-D and are concatenated along the second dimention. 1-D tensors are treated as [1,C]. + Inputs are copied to the output maintaining the order of the input arguments.
+ All inputs must be integers or floats, while the output will be all floating point values. + """ + + def __init__(self, X, inputdimensions=None): + super().__init__( + "FeatureVectorizer", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + inputdimensions=ONNXAttr(inputdimensions, AttrType.INTS), + ) + class Flatten(ONNXOp): - """ - Flattens the input tensor into a 2D matrix. If input tensor has shape - (d_0, d_1, ... d_n) then the output will have shape - (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn). - """ - - def __init__(self, input, - axis=None): - super().__init__('Flatten', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - input, - axis=ONNXAttr(axis, AttrType.INT)) + """ + Flattens the input tensor into a 2D matrix. If input tensor has shape + (d_0, d_1, ... d_n) then the output will have shape + (d_0 X d_1 ... d_(axis-1), d_axis X d_(axis+1) ... X dn). + """ + + def __init__(self, input, axis=None): + super().__init__( + "Flatten", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + axis=ONNXAttr(axis, AttrType.INT), + ) + class Floor(ONNXOp): - """ - Floor takes one input data (Tensor) and produces one output data - (Tensor) where the floor is, y = floor(x), is applied to - the tensor elementwise. - """ + """ + Floor takes one input data (Tensor) and produces one output data + (Tensor) where the floor is, y = floor(x), is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__("Floor", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('Floor', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X) class Gather(ONNXOp): - """ - Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather - entries of the axis dimension of `data` (by default outer-most one as axis=0) indexed by `indices`, and concatenates - them in an output tensor of rank q + (r - 1). - - axis = 0 : - - Let - k = indices[i_{0}, ..., i_{q-1}] - Then - output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k , j_{0}, ..., j_{r-2}] - - ``` - data = [ - [1.0, 1.2], - [2.3, 3.4], - [4.5, 5.7], - ] - indices = [ - [0, 1], - [1, 2], - ] - output = [ - [ - [1.0, 1.2], - [2.3, 3.4], - ], - [ - [2.3, 3.4], - [4.5, 5.7], - ], - ] - ``` - axis = 1 : - - Let - k = indices[i_{0}, ..., i_{q-1}] - Then - output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[j_{0}, k, j_{1}, ..., j_{r-2}] - - ``` - data = [ - [1.0, 1.2, 1.9], - [2.3, 3.4, 3.9], - [4.5, 5.7, 5.9], - ] - indices = [ - [0, 2], - ] - axis = 1, - output = [ - [[1.0, 1.9]], - [[2.3, 3.9]], - [[4.5, 5.9]], - ] - ``` - """ - - def __init__(self, data, indices, - axis=None): - super().__init__('Gather', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong', 'at::kInt'}], - data,indices, - axis=ONNXAttr(axis, AttrType.INT)) + """ + Given `data` tensor of rank r >= 1, and `indices` tensor of rank q, gather + entries of the axis dimension of `data` (by default outer-most one as axis=0) indexed by `indices`, and concatenates + them in an output tensor of rank q + (r - 1). + + axis = 0 : + + Let + k = indices[i_{0}, ..., i_{q-1}] + Then + output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[k , j_{0}, ..., j_{r-2}] + + ``` + data = [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + indices = [ + [0, 1], + [1, 2], + ] + output = [ + [ + [1.0, 1.2], + [2.3, 3.4], + ], + [ + [2.3, 3.4], + [4.5, 5.7], + ], + ] + ``` + axis = 1 : + + Let + k = indices[i_{0}, ..., i_{q-1}] + Then + output[i_{0}, ..., i_{q-1}, j_{0}, ..., j_{r-2}] = input[j_{0}, k, j_{1}, ..., j_{r-2}] + + ``` + data = [ + [1.0, 1.2, 1.9], + [2.3, 3.4, 3.9], + [4.5, 5.7, 5.9], + ] + indices = [ + [0, 2], + ] + axis = 1, + output = [ + [[1.0, 1.9]], + [[2.3, 3.9]], + [[4.5, 5.9]], + ] + ``` + """ + + def __init__(self, data, indices, axis=None): + super().__init__( + "Gather", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong", "at::kInt"}, + ], + data, + indices, + axis=ONNXAttr(axis, AttrType.INT), + ) + class GatherElements(ONNXOp): - """ - GatherElements takes two inputs `data` and `indices` of the same rank r >= 1 - and an optional attribute `axis` that identifies an axis of `data` - (by default, the outer-most axis, that is axis 0). It is an indexing operation - that produces its output by indexing into the input data tensor at index - positions determined by elements of the `indices` tensor. - Its output shape is the same as the shape of `indices` and consists of one value - (gathered from the `data`) for each element in `indices`. - - For instance, in the 3-D case (r = 3), the output produced is determined - by the following equations: - ``` - out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0, - out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1, - out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2, - ``` - - This operator is also the inverse of ScatterElements. It is similar to Torch's gather operation. - - Example 1: - ``` - data = [ - [1, 2], - [3, 4], - ] - indices = [ - [0, 0], - [1, 0], - ] - axis = 1 - output = [ - [ - [1, 1], - [4, 3], - ], - ] - ``` - Example 2: - ``` - data = [ - [1, 2, 3], - [4, 5, 6], - [7, 8, 9], - ] - indices = [ - [1, 2, 0], - [2, 0, 0], - ] - axis = 0 - output = [ - [ - [4, 8, 3], - [7, 2, 3], - ], - ] - ``` - """ - - def __init__(self, data, indices, - axis=None): - super().__init__('GatherElements', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong', 'at::kInt'}], - data,indices, - axis=ONNXAttr(axis, AttrType.INT)) + """ + GatherElements takes two inputs `data` and `indices` of the same rank r >= 1 + and an optional attribute `axis` that identifies an axis of `data` + (by default, the outer-most axis, that is axis 0). It is an indexing operation + that produces its output by indexing into the input data tensor at index + positions determined by elements of the `indices` tensor. + Its output shape is the same as the shape of `indices` and consists of one value + (gathered from the `data`) for each element in `indices`. + + For instance, in the 3-D case (r = 3), the output produced is determined + by the following equations: + ``` + out[i][j][k] = input[index[i][j][k]][j][k] if axis = 0, + out[i][j][k] = input[i][index[i][j][k]][k] if axis = 1, + out[i][j][k] = input[i][j][index[i][j][k]] if axis = 2, + ``` + + This operator is also the inverse of ScatterElements. It is similar to Torch's gather operation. + + Example 1: + ``` + data = [ + [1, 2], + [3, 4], + ] + indices = [ + [0, 0], + [1, 0], + ] + axis = 1 + output = [ + [ + [1, 1], + [4, 3], + ], + ] + ``` + Example 2: + ``` + data = [ + [1, 2, 3], + [4, 5, 6], + [7, 8, 9], + ] + indices = [ + [1, 2, 0], + [2, 0, 0], + ] + axis = 0 + output = [ + [ + [4, 8, 3], + [7, 2, 3], + ], + ] + ``` + """ + + def __init__(self, data, indices, axis=None): + super().__init__( + "GatherElements", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong", "at::kInt"}, + ], + data, + indices, + axis=ONNXAttr(axis, AttrType.INT), + ) + class GatherND(ONNXOp): - """ - Given `data` tensor of rank `r` >= 1, `indices` tensor of rank `q` >= 1, and `batch_dims` integer `b`, this operator gathers - slices of `data` into an output tensor of rank `q + r - indices_shape[-1] - 1 - b`. - - `indices` is an q-dimensional integer tensor, best thought of as a `(q-1)`-dimensional tensor of index-tuples into `data`, - where each element defines a slice of `data` - - `batch_dims` (denoted as `b`) is an integer indicating the number of batch dimensions, i.e the leading `b` number of dimensions of - `data` tensor and `indices` are representing the batches, and the gather starts from the `b+1` dimension. - - Some salient points about the inputs' rank and shape: - - 1) r >= 1 and q >= 1 are to be honored. There is no dependency condition to be met between ranks `r` and `q` - - 2) The first `b` dimensions of the shape of `indices` tensor and `data` tensor must be equal. - - 3) b < min(q, r) is to be honored. - - 4) The `indices_shape[-1]` should have a value between 1 (inclusive) and rank `r-b` (inclusive) - - 5) All values in `indices` are expected to be within bounds [-s, s-1] along axis of size `s` (i.e.) `-data_shape[i] <= indices[...,i] <= data_shape[i] - 1`. - It is an error if any of the index values are out of bounds. - - The output is computed as follows: - - The output tensor is obtained by mapping each index-tuple in the `indices` tensor to the corresponding slice of the input `data`. - - 1) If `indices_shape[-1] > r-b` => error condition - - 2) If `indices_shape[-1] == r-b`, since the rank of `indices` is `q`, `indices` can be thought of as `N` `(q-b-1)`-dimensional tensors - containing 1-D tensors of dimension `r-b`, where `N` is an integer equals to the product of 1 and all the elements in the batch dimensions - of the indices_shape. Let us think of each such `r-b` ranked tensor as `indices_slice`. Each *scalar value* corresponding to `data[0:b-1,indices_slice]` - is filled into the corresponding location of the `(q-b-1)`-dimensional tensor to form the `output` tensor (Example 1 below) - - 3) If `indices_shape[-1] < r-b`, since the rank of `indices` is `q`, `indices` can be thought of as `N` `(q-b-1)`-dimensional tensor - containing 1-D tensors of dimension `< r-b`. Let us think of each such tensors as `indices_slice`. Each *tensor slice* corresponding - to `data[0:b-1, indices_slice , :]` is filled into the corresponding location of the `(q-b-1)`-dimensional tensor - to form the `output` tensor (Examples 2, 3, 4 and 5 below) - - This operator is the inverse of `ScatterND`. - - `Example 1` - - batch_dims = 0 - - data = [[0,1],[2,3]] # data_shape = [2, 2] - - indices = [[0,0],[1,1]] # indices_shape = [2, 2] - - output = [0,3] # output_shape = [2] - - `Example 2` - - batch_dims = 0 - - data = [[0,1],[2,3]] # data_shape = [2, 2] - - indices = [[1],[0]] # indices_shape = [2, 1] - - output = [[2,3],[0,1]] # output_shape = [2, 2] - - `Example 3` - - batch_dims = 0 - - data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] - - indices = [[0,1],[1,0]] # indices_shape = [2, 2] - - output = [[2,3],[4,5]] # output_shape = [2, 2] - - `Example 4` - - batch_dims = 0 - - data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] - - indices = [[[0,1]],[[1,0]]] # indices_shape = [2, 1, 2] - - output = [[[2,3]],[[4,5]]] # output_shape = [2, 1, 2] - - `Example 5` - - batch_dims = 1 - - data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] - - indices = [[1],[0]] # indices_shape = [2, 1] - - output = [[2,3],[4,5]] # output_shape = [2, 2] - """ - - def __init__(self, data, indices, - batch_dims=None): - super().__init__('GatherND', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - data,indices, - batch_dims=ONNXAttr(batch_dims, AttrType.INT)) + """ + Given `data` tensor of rank `r` >= 1, `indices` tensor of rank `q` >= 1, and `batch_dims` integer `b`, this operator gathers + slices of `data` into an output tensor of rank `q + r - indices_shape[-1] - 1 - b`. + + `indices` is an q-dimensional integer tensor, best thought of as a `(q-1)`-dimensional tensor of index-tuples into `data`, + where each element defines a slice of `data` + + `batch_dims` (denoted as `b`) is an integer indicating the number of batch dimensions, i.e the leading `b` number of dimensions of + `data` tensor and `indices` are representing the batches, and the gather starts from the `b+1` dimension. + + Some salient points about the inputs' rank and shape: + + 1) r >= 1 and q >= 1 are to be honored. There is no dependency condition to be met between ranks `r` and `q` + + 2) The first `b` dimensions of the shape of `indices` tensor and `data` tensor must be equal. + + 3) b < min(q, r) is to be honored. + + 4) The `indices_shape[-1]` should have a value between 1 (inclusive) and rank `r-b` (inclusive) + + 5) All values in `indices` are expected to be within bounds [-s, s-1] along axis of size `s` (i.e.) `-data_shape[i] <= indices[...,i] <= data_shape[i] - 1`. + It is an error if any of the index values are out of bounds. + + The output is computed as follows: + + The output tensor is obtained by mapping each index-tuple in the `indices` tensor to the corresponding slice of the input `data`. + + 1) If `indices_shape[-1] > r-b` => error condition + + 2) If `indices_shape[-1] == r-b`, since the rank of `indices` is `q`, `indices` can be thought of as `N` `(q-b-1)`-dimensional tensors + containing 1-D tensors of dimension `r-b`, where `N` is an integer equals to the product of 1 and all the elements in the batch dimensions + of the indices_shape. Let us think of each such `r-b` ranked tensor as `indices_slice`. Each *scalar value* corresponding to `data[0:b-1,indices_slice]` + is filled into the corresponding location of the `(q-b-1)`-dimensional tensor to form the `output` tensor (Example 1 below) + + 3) If `indices_shape[-1] < r-b`, since the rank of `indices` is `q`, `indices` can be thought of as `N` `(q-b-1)`-dimensional tensor + containing 1-D tensors of dimension `< r-b`. Let us think of each such tensors as `indices_slice`. Each *tensor slice* corresponding + to `data[0:b-1, indices_slice , :]` is filled into the corresponding location of the `(q-b-1)`-dimensional tensor + to form the `output` tensor (Examples 2, 3, 4 and 5 below) + + This operator is the inverse of `ScatterND`. + + `Example 1` + + batch_dims = 0 + + data = [[0,1],[2,3]] # data_shape = [2, 2] + + indices = [[0,0],[1,1]] # indices_shape = [2, 2] + + output = [0,3] # output_shape = [2] + + `Example 2` + + batch_dims = 0 + + data = [[0,1],[2,3]] # data_shape = [2, 2] + + indices = [[1],[0]] # indices_shape = [2, 1] + + output = [[2,3],[0,1]] # output_shape = [2, 2] + + `Example 3` + + batch_dims = 0 + + data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] + + indices = [[0,1],[1,0]] # indices_shape = [2, 2] + + output = [[2,3],[4,5]] # output_shape = [2, 2] + + `Example 4` + + batch_dims = 0 + + data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] + + indices = [[[0,1]],[[1,0]]] # indices_shape = [2, 1, 2] + + output = [[[2,3]],[[4,5]]] # output_shape = [2, 1, 2] + + `Example 5` + + batch_dims = 1 + + data = [[[0,1],[2,3]],[[4,5],[6,7]]] # data_shape = [2, 2, 2] + + indices = [[1],[0]] # indices_shape = [2, 1] + + output = [[2,3],[4,5]] # output_shape = [2, 2] + """ + + def __init__(self, data, indices, batch_dims=None): + super().__init__( + "GatherND", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + data, + indices, + batch_dims=ONNXAttr(batch_dims, AttrType.INT), + ) + class Gemm(ONNXOp): - """ - General Matrix multiplication: - https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3 - - A' = transpose(A) if transA else A - - B' = transpose(B) if transB else B - - Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), - input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), - and output tensor Y has shape (M, N). A will be transposed before doing the - computation if attribute transA is non-zero, same for B and transB. - This operator supports **unidirectional broadcasting** (tensor C should be unidirectional broadcastable to tensor A * B); for more details please check [the doc](Broadcasting.md). - This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. - """ - - def __init__(self, A, B, C, - alpha=None, - beta=None, - transA=None, - transB=None): - super().__init__('Gemm', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - A,B,C, - alpha=ONNXAttr(alpha, AttrType.FLOAT), - beta=ONNXAttr(beta, AttrType.FLOAT), - transA=ONNXAttr(transA, AttrType.INT), - transB=ONNXAttr(transB, AttrType.INT)) + """ + General Matrix multiplication: + https://en.wikipedia.org/wiki/Basic_Linear_Algebra_Subprograms#Level_3 + + A' = transpose(A) if transA else A + + B' = transpose(B) if transB else B + + Compute Y = alpha * A' * B' + beta * C, where input tensor A has shape (M, K) or (K, M), + input tensor B has shape (K, N) or (N, K), input tensor C is broadcastable to shape (M, N), + and output tensor Y has shape (M, N). A will be transposed before doing the + computation if attribute transA is non-zero, same for B and transB. + This operator supports **unidirectional broadcasting** (tensor C should be unidirectional broadcastable to tensor A * B); for more details please check [the doc](Broadcasting.md). + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + """ + + def __init__(self, A, B, C, alpha=None, beta=None, transA=None, transB=None): + super().__init__( + "Gemm", + 1, + [ + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + ], + A, + B, + C, + alpha=ONNXAttr(alpha, AttrType.FLOAT), + beta=ONNXAttr(beta, AttrType.FLOAT), + transA=ONNXAttr(transA, AttrType.INT), + transB=ONNXAttr(transB, AttrType.INT), + ) + class GlobalAveragePool(ONNXOp): - """ - GlobalAveragePool consumes an input tensor X and applies average pooling across - the values in the same channel. This is equivalent to AveragePool with kernel size - equal to the spatial dimension of input tensor. - """ + """ + GlobalAveragePool consumes an input tensor X and applies average pooling across + the values in the same channel. This is equivalent to AveragePool with kernel size + equal to the spatial dimension of input tensor. + """ + + def __init__(self, X): + super().__init__("GlobalAveragePool", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('GlobalAveragePool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X) class GlobalLpPool(ONNXOp): - """ - GlobalLpPool consumes an input tensor X and applies lp pool pooling across - the values in the same channel. This is equivalent to LpPool with kernel size - equal to the spatial dimension of input tensor. - """ - - def __init__(self, X, - p=None): - super().__init__('GlobalLpPool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - p=ONNXAttr(p, AttrType.INT)) + """ + GlobalLpPool consumes an input tensor X and applies lp pool pooling across + the values in the same channel. This is equivalent to LpPool with kernel size + equal to the spatial dimension of input tensor. + """ + + def __init__(self, X, p=None): + super().__init__( + "GlobalLpPool", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X, p=ONNXAttr(p, AttrType.INT) + ) + class GlobalMaxPool(ONNXOp): - """ - GlobalMaxPool consumes an input tensor X and applies max pooling across - the values in the same channel. This is equivalent to MaxPool with kernel size - equal to the spatial dimension of input tensor. - """ + """ + GlobalMaxPool consumes an input tensor X and applies max pooling across + the values in the same channel. This is equivalent to MaxPool with kernel size + equal to the spatial dimension of input tensor. + """ + + def __init__(self, X): + super().__init__("GlobalMaxPool", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('GlobalMaxPool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X) class Gradient(ONNXOp): - """ - Gradient operator computes the partial derivatives of a specific tensor w.r.t. - some other tensors. This operator is widely used in gradient-based training - algorithms. To illustrate its use, let's consider a computation graph, - - ``` - X -----. - | - v - W --> Conv --> H --> Gemm --> Y - ^ - | - Z - ``` - - , where W and Z are trainable tensors. Note that operators' attributes are - omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of - Y with respect to W (Z). The user can compute gradient by inserting Gradient - operator to form another graph shown below. - - ``` - W --> Conv --> H --> Gemm --> Y - | ^ ^ - | | | - | X Z - | | | - | | .----------' - | | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in - | | | "xs" followed by "zs") - | v v - '---> Gradient(xs=["W", "Z"], zs=["X"], y="Y") - | | - | '-----------------------------------> dY/dW (1st output of Gradient) - | - '---------------------------------------> dY/dZ (2nd output of Gradient) - ``` - - By definition, the tensor "y" is a function of independent variables in "xs" - and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable - variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H" - cannot appear in "xs" and "zs". The reason is that "H" can be determined by - tensors "W" and "X" and therefore "H" is not an independent variable. - - All outputs are optional. If needed, for example, user can assign an empty - string to the 1st output name of that Gradient to skip the generation of dY/dW. - Note that the concept of optional outputs can also be found in ONNX's RNN, GRU, - and LSTM. - - Gradient operator can compute derivative against intermediate tensors. For - example, the gradient of Y with respect to H can be done via - - ``` - W --> Conv --> H --> Gemm --> Y - ^ | ^ - | | | - X | Z - .-------' | - | .----------' - | | (H/Z is the 1st/2nd input of Gradient as shown in "xs") - v v - Gradient(xs=["H", "Z"], y="Y") - | | - | '-----------------------------------> dY/dH (1st output of Gradient) - | - '---------------------------------------> dY/dZ (2nd output of Gradient) - ``` - - It is possible to represent high-order differentiation using Gradient operators. - For example, given the following linear model: - - ``` - W --> Gemm --> Y --> Loss --> O - ^ ^ - | | - X L - ``` - - To compute the 2nd order derivative of O with respect to W (denoted by - d^2O/dW^2), one can do - - ``` - W --> Gemm --> Y --> Loss --> O - | ^ ^ - | | | - | X .------------L - | | | | - | | | v - +------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient) - | | | | - | | | '---> dO/dW (2nd output of Gradient) - | v v - '---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)dX (1st output of - | Gradient) - | - | - '---> d^2O/dW^2 (2nd output of Gradient) - ``` - - The tensors named in attributes "xs", "zs", and "y" define the differentiated - computation graph, and the inputs to Gradient node define the values at - which the gradient is computed. We can feed different tensors to the identified - graph. For example, one can compute the gradient of Y with respect to H at - a specific value of H, H_1, by providing that value as an input to the Gradient - node. - - ``` - W --> Conv --> H --> Gemm --> Y - ^ ^ - | | - X Z - - Z_1 (2nd input of Gradient) - | - v - H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1. - | - '------------------------------> dY/dZ (2nd output of Gradient) - ``` - - When the inputs of Gradient are the tensors named in "xs" and "zs", the - computation can be optimized. More specifically, intermediate variables in - forward pass can be reused if the gradient is computed via reverse-mode - auto-differentiation. - """ - - def __init__(self, Inputs, - xs=None, - y=None, - zs=None): - super().__init__('Gradient', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - Inputs, - xs=ONNXAttr(xs, AttrType.STRINGS), - y=ONNXAttr(y, AttrType.STRING), - zs=ONNXAttr(zs, AttrType.STRINGS)) + """ + Gradient operator computes the partial derivatives of a specific tensor w.r.t. + some other tensors. This operator is widely used in gradient-based training + algorithms. To illustrate its use, let's consider a computation graph, + + ``` + X -----. + | + v + W --> Conv --> H --> Gemm --> Y + ^ + | + Z + ``` + + , where W and Z are trainable tensors. Note that operators' attributes are + omitted for the sake of simplicity. Let dY/dW (dY/dZ) be the gradient of + Y with respect to W (Z). The user can compute gradient by inserting Gradient + operator to form another graph shown below. + + ``` + W --> Conv --> H --> Gemm --> Y + | ^ ^ + | | | + | X Z + | | | + | | .----------' + | | | (W/Z/X is the 1st/2nd/3rd input of Gradient as shown in + | | | "xs" followed by "zs") + | v v + '---> Gradient(xs=["W", "Z"], zs=["X"], y="Y") + | | + | '-----------------------------------> dY/dW (1st output of Gradient) + | + '---------------------------------------> dY/dZ (2nd output of Gradient) + ``` + + By definition, the tensor "y" is a function of independent variables in "xs" + and "zs". Since we only compute the gradient of "y" w.r.t. the differentiable + variables in "xs", this Gradient only outputs dY/dW and dY/dZ. Note that "H" + cannot appear in "xs" and "zs". The reason is that "H" can be determined by + tensors "W" and "X" and therefore "H" is not an independent variable. + + All outputs are optional. If needed, for example, user can assign an empty + string to the 1st output name of that Gradient to skip the generation of dY/dW. + Note that the concept of optional outputs can also be found in ONNX's RNN, GRU, + and LSTM. + + Gradient operator can compute derivative against intermediate tensors. For + example, the gradient of Y with respect to H can be done via + + ``` + W --> Conv --> H --> Gemm --> Y + ^ | ^ + | | | + X | Z + .-------' | + | .----------' + | | (H/Z is the 1st/2nd input of Gradient as shown in "xs") + v v + Gradient(xs=["H", "Z"], y="Y") + | | + | '-----------------------------------> dY/dH (1st output of Gradient) + | + '---------------------------------------> dY/dZ (2nd output of Gradient) + ``` + + It is possible to represent high-order differentiation using Gradient operators. + For example, given the following linear model: + + ``` + W --> Gemm --> Y --> Loss --> O + ^ ^ + | | + X L + ``` + + To compute the 2nd order derivative of O with respect to W (denoted by + d^2O/dW^2), one can do + + ``` + W --> Gemm --> Y --> Loss --> O + | ^ ^ + | | | + | X .------------L + | | | | + | | | v + +------+-+> Gradient(xs=["X", "W"], zs=["L"], y="O") ---> dO/dX (1st output of Gradient) + | | | | + | | | '---> dO/dW (2nd output of Gradient) + | v v + '---> Gradient(xs=["X", "W"], zs=["L"], y="dO/dW") ---> d(dO/dW)dX (1st output of + | Gradient) + | + | + '---> d^2O/dW^2 (2nd output of Gradient) + ``` + + The tensors named in attributes "xs", "zs", and "y" define the differentiated + computation graph, and the inputs to Gradient node define the values at + which the gradient is computed. We can feed different tensors to the identified + graph. For example, one can compute the gradient of Y with respect to H at + a specific value of H, H_1, by providing that value as an input to the Gradient + node. + + ``` + W --> Conv --> H --> Gemm --> Y + ^ ^ + | | + X Z + + Z_1 (2nd input of Gradient) + | + v + H_1 --> Gradient(xs=["H", "Z"], y="Y") ---> dY/dH when H = H_1 and Y = Y_1. + | + '------------------------------> dY/dZ (2nd output of Gradient) + ``` + + When the inputs of Gradient are the tensors named in "xs" and "zs", the + computation can be optimized. More specifically, intermediate variables in + forward pass can be reused if the gradient is computed via reverse-mode + auto-differentiation. + """ + + def __init__(self, Inputs, xs=None, y=None, zs=None): + super().__init__( + "Gradient", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + Inputs, + xs=ONNXAttr(xs, AttrType.STRINGS), + y=ONNXAttr(y, AttrType.STRING), + zs=ONNXAttr(zs, AttrType.STRINGS), + ) + class Greater(ONNXOp): - """ - Returns the tensor resulted from performing the `greater` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('Greater', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Returns the tensor resulted from performing the `greater` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__( + "Greater", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class GreaterOrEqual(ONNXOp): - """ - Returns the tensor resulted from performing the `greater_equal` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('GreaterOrEqual', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}], - A,B) + """ + Returns the tensor resulted from performing the `greater_equal` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__( + "GreaterOrEqual", + 1, + [ + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + ], + A, + B, + ) + class GRU(ONNXOp): - """ - Computes an one-layer GRU. This operator is usually supported via some custom - implementation such as CuDNN. - - Notations: - - `X` - input tensor - - `z` - update gate - - `r` - reset gate - - `h` - hidden gate - - `t` - time step (t-1 means previous time step) - - `W[zrh]` - W parameter weight matrix for update, reset, and hidden gates - - `R[zrh]` - R recurrence weight matrix for update, reset, and hidden gates - - `Wb[zrh]` - W bias vectors for update, reset, and hidden gates - - `Rb[zrh]` - R bias vectors for update, reset, and hidden gates - - `WB[zrh]` - W parameter weight matrix for backward update, reset, and hidden gates - - `RB[zrh]` - R recurrence weight matrix for backward update, reset, and hidden gates - - `WBb[zrh]` - W bias vectors for backward update, reset, and hidden gates - - `RBb[zrh]` - R bias vectors for backward update, reset, and hidden gates - - `H` - Hidden state - - `num_directions` - 2 if direction == bidirectional else 1 - - Activation functions: - - Relu(x) - max(0, x) - - Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) - - Sigmoid(x) - 1/(1 + e^{-x}) - - (NOTE: Below are optional) - - Affine(x) - alpha*x + beta - - LeakyRelu(x) - x if x >= 0 else alpha * x - - ThresholdedRelu(x) - x if x >= alpha else 0 - - ScaledTanh(x) - alpha*Tanh(beta*x) - - HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) - - Elu(x) - x if x >= 0 else alpha*(e^x - 1) - - Softsign(x) - x/(1 + |x|) - - Softplus(x) - log(1 + e^x) - - Equations (Default: f=Sigmoid, g=Tanh): - - - zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) - - - rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) - - - ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0 - - - ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0 - - - Ht = (1 - zt) (.) ht + zt (.) Ht-1 - This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. - """ - - def __init__(self, X, W, R, B, sequence_lens, initial_h, - activation_alpha=None, - activation_beta=None, - activations=None, - clip=None, - direction=None, - hidden_size=None, - layout=None, - linear_before_reset=None): - super().__init__('GRU', 2, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kInt'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X,W,R,B,sequence_lens,initial_h, - activation_alpha=ONNXAttr(activation_alpha, AttrType.FLOATS), - activation_beta=ONNXAttr(activation_beta, AttrType.FLOATS), - activations=ONNXAttr(activations, AttrType.STRINGS), - clip=ONNXAttr(clip, AttrType.FLOAT), - direction=ONNXAttr(direction, AttrType.STRING), - hidden_size=ONNXAttr(hidden_size, AttrType.INT), - layout=ONNXAttr(layout, AttrType.INT), - linear_before_reset=ONNXAttr(linear_before_reset, AttrType.INT)) + """ + Computes an one-layer GRU. This operator is usually supported via some custom + implementation such as CuDNN. + + Notations: + + `X` - input tensor + + `z` - update gate + + `r` - reset gate + + `h` - hidden gate + + `t` - time step (t-1 means previous time step) + + `W[zrh]` - W parameter weight matrix for update, reset, and hidden gates + + `R[zrh]` - R recurrence weight matrix for update, reset, and hidden gates + + `Wb[zrh]` - W bias vectors for update, reset, and hidden gates + + `Rb[zrh]` - R bias vectors for update, reset, and hidden gates + + `WB[zrh]` - W parameter weight matrix for backward update, reset, and hidden gates + + `RB[zrh]` - R recurrence weight matrix for backward update, reset, and hidden gates + + `WBb[zrh]` - W bias vectors for backward update, reset, and hidden gates + + `RBb[zrh]` - R bias vectors for backward update, reset, and hidden gates + + `H` - Hidden state + + `num_directions` - 2 if direction == bidirectional else 1 + + Activation functions: + + Relu(x) - max(0, x) + + Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) + + Sigmoid(x) - 1/(1 + e^{-x}) + + (NOTE: Below are optional) + + Affine(x) - alpha*x + beta + + LeakyRelu(x) - x if x >= 0 else alpha * x + + ThresholdedRelu(x) - x if x >= alpha else 0 + + ScaledTanh(x) - alpha*Tanh(beta*x) + + HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) + + Elu(x) - x if x >= 0 else alpha*(e^x - 1) + + Softsign(x) - x/(1 + |x|) + + Softplus(x) - log(1 + e^x) + + Equations (Default: f=Sigmoid, g=Tanh): + + - zt = f(Xt*(Wz^T) + Ht-1*(Rz^T) + Wbz + Rbz) + + - rt = f(Xt*(Wr^T) + Ht-1*(Rr^T) + Wbr + Rbr) + + - ht = g(Xt*(Wh^T) + (rt (.) Ht-1)*(Rh^T) + Rbh + Wbh) # default, when linear_before_reset = 0 + + - ht = g(Xt*(Wh^T) + (rt (.) (Ht-1*(Rh^T) + Rbh)) + Wbh) # when linear_before_reset != 0 + + - Ht = (1 - zt) (.) ht + zt (.) Ht-1 + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + """ + + def __init__( + self, + X, + W, + R, + B, + sequence_lens, + initial_h, + activation_alpha=None, + activation_beta=None, + activations=None, + clip=None, + direction=None, + hidden_size=None, + layout=None, + linear_before_reset=None, + ): + super().__init__( + "GRU", + 2, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kInt"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + X, + W, + R, + B, + sequence_lens, + initial_h, + activation_alpha=ONNXAttr(activation_alpha, AttrType.FLOATS), + activation_beta=ONNXAttr(activation_beta, AttrType.FLOATS), + activations=ONNXAttr(activations, AttrType.STRINGS), + clip=ONNXAttr(clip, AttrType.FLOAT), + direction=ONNXAttr(direction, AttrType.STRING), + hidden_size=ONNXAttr(hidden_size, AttrType.INT), + layout=ONNXAttr(layout, AttrType.INT), + linear_before_reset=ONNXAttr(linear_before_reset, AttrType.INT), + ) + class Hardmax(ONNXOp): - """ - The operator computes the hardmax values for the given input: - - Hardmax(element in input, axis) = 1 if the element is the first maximum value along the specified axis, 0 otherwise - - The input does not need to explicitly be a 2D vector. The "axis" attribute - indicates the dimension along which Hardmax will be performed. - The output tensor has the same shape - and contains the Hardmax values of the corresponding input. - """ - - def __init__(self, input, - axis=None): - super().__init__('Hardmax', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - input, - axis=ONNXAttr(axis, AttrType.INT)) + """ + The operator computes the hardmax values for the given input: + + Hardmax(element in input, axis) = 1 if the element is the first maximum value along the specified axis, 0 otherwise + + The input does not need to explicitly be a 2D vector. The "axis" attribute + indicates the dimension along which Hardmax will be performed. + The output tensor has the same shape + and contains the Hardmax values of the corresponding input. + """ + + def __init__(self, input, axis=None): + super().__init__( + "Hardmax", + 1, + [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], + input, + axis=ONNXAttr(axis, AttrType.INT), + ) + class HardSigmoid(ONNXOp): - """ - HardSigmoid takes one input data (Tensor) and produces one output data - (Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)), - is applied to the tensor elementwise. - """ - - def __init__(self, X, - alpha=None, - beta=None): - super().__init__('HardSigmoid', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT), - beta=ONNXAttr(beta, AttrType.FLOAT)) + """ + HardSigmoid takes one input data (Tensor) and produces one output data + (Tensor) where the HardSigmoid function, y = max(0, min(1, alpha * x + beta)), + is applied to the tensor elementwise. + """ + + def __init__(self, X, alpha=None, beta=None): + super().__init__( + "HardSigmoid", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}], + X, + alpha=ONNXAttr(alpha, AttrType.FLOAT), + beta=ONNXAttr(beta, AttrType.FLOAT), + ) + class HardSwish(ONNXOp): - """ - HardSwish takes one input data (Tensor) and produces one output data (Tensor) where - the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid(x), - where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise. - """ + """ + HardSwish takes one input data (Tensor) and produces one output data (Tensor) where + the HardSwish function, y = x * max(0, min(1, alpha * x + beta)) = x * HardSigmoid(x), + where alpha = 1/6 and beta = 0.5, is applied to the tensor elementwise. + """ + + def __init__(self, X): + super().__init__("HardSwish", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('HardSwish', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X) class Identity(ONNXOp): - """ - Identity operator - """ + """ + Identity operator + """ + + def __init__(self, input): + super().__init__( + "Identity", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + ) - def __init__(self, input): - super().__init__('Identity', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - input) class If(ONNXOp): - """ - If conditional - """ - - def __init__(self, cond, - else_branch=None, - then_branch=None): - super().__init__('If', 1, - [{'at::kBool'}], - cond, - else_branch=ONNXAttr(else_branch, AttrType.GRAPH), - then_branch=ONNXAttr(then_branch, AttrType.GRAPH)) + """ + If conditional + """ + + def __init__(self, cond, else_branch=None, then_branch=None): + super().__init__( + "If", + 1, + [{"at::kBool"}], + cond, + else_branch=ONNXAttr(else_branch, AttrType.GRAPH), + then_branch=ONNXAttr(then_branch, AttrType.GRAPH), + ) + class Imputer(ONNXOp): - """ - Replaces inputs that equal one value with another, leaving all other elements alone.
- This operator is typically used to replace missing values in situations where they have a canonical - representation, such as -1, 0, NaN, or some extreme value.
- One and only one of imputed_value_floats or imputed_value_int64s should be defined -- floats if the input tensor - holds floats, integers if the input tensor holds integers. The imputed values must all fit within the - width of the tensor element type. One and only one of the replaced_value_float or replaced_value_int64 should be defined, - which one depends on whether floats or integers are being processed.
- The imputed_value attribute length can be 1 element, or it can have one element per input feature.
In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature. - """ - - def __init__(self, X, - imputed_value_floats=None, - imputed_value_int64s=None, - replaced_value_float=None, - replaced_value_int64=None): - super().__init__('Imputer', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - imputed_value_floats=ONNXAttr(imputed_value_floats, AttrType.FLOATS), - imputed_value_int64s=ONNXAttr(imputed_value_int64s, AttrType.INTS), - replaced_value_float=ONNXAttr(replaced_value_float, AttrType.FLOAT), - replaced_value_int64=ONNXAttr(replaced_value_int64, AttrType.INT)) + """ + Replaces inputs that equal one value with another, leaving all other elements alone.
+ This operator is typically used to replace missing values in situations where they have a canonical + representation, such as -1, 0, NaN, or some extreme value.
+ One and only one of imputed_value_floats or imputed_value_int64s should be defined -- floats if the input tensor + holds floats, integers if the input tensor holds integers. The imputed values must all fit within the + width of the tensor element type. One and only one of the replaced_value_float or replaced_value_int64 should be defined, + which one depends on whether floats or integers are being processed.
+ The imputed_value attribute length can be 1 element, or it can have one element per input feature.
In other words, if the input tensor has the shape [*,F], then the length of the attribute array may be 1 or F. If it is 1, then it is broadcast along the last dimension and applied to each feature. + """ + + def __init__( + self, + X, + imputed_value_floats=None, + imputed_value_int64s=None, + replaced_value_float=None, + replaced_value_int64=None, + ): + super().__init__( + "Imputer", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + imputed_value_floats=ONNXAttr(imputed_value_floats, AttrType.FLOATS), + imputed_value_int64s=ONNXAttr(imputed_value_int64s, AttrType.INTS), + replaced_value_float=ONNXAttr(replaced_value_float, AttrType.FLOAT), + replaced_value_int64=ONNXAttr(replaced_value_int64, AttrType.INT), + ) + class InstanceNormalization(ONNXOp): - """ - Carries out instance normalization as described in the paper - https://arxiv.org/abs/1607.08022. - - y = scale * (x - mean) / sqrt(variance + epsilon) + B, - where mean and variance are computed per instance per channel. - """ - - def __init__(self, input, scale, B, - epsilon=None): - super().__init__('InstanceNormalization', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input,scale,B, - epsilon=ONNXAttr(epsilon, AttrType.FLOAT)) + """ + Carries out instance normalization as described in the paper + https://arxiv.org/abs/1607.08022. + + y = scale * (x - mean) / sqrt(variance + epsilon) + B, + where mean and variance are computed per instance per channel. + """ + + def __init__(self, input, scale, B, epsilon=None): + super().__init__( + "InstanceNormalization", + 1, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + input, + scale, + B, + epsilon=ONNXAttr(epsilon, AttrType.FLOAT), + ) + class IsInf(ONNXOp): - """ - Map infinity to true and other values to false. - """ - - def __init__(self, X, - detect_negative=None, - detect_positive=None): - super().__init__('IsInf', 1, - [{'at::kDouble', 'at::kFloat'}], - X, - detect_negative=ONNXAttr(detect_negative, AttrType.INT), - detect_positive=ONNXAttr(detect_positive, AttrType.INT)) + """ + Map infinity to true and other values to false. + """ + + def __init__(self, X, detect_negative=None, detect_positive=None): + super().__init__( + "IsInf", + 1, + [{"at::kDouble", "at::kFloat"}], + X, + detect_negative=ONNXAttr(detect_negative, AttrType.INT), + detect_positive=ONNXAttr(detect_positive, AttrType.INT), + ) + class IsNaN(ONNXOp): - """ - Returns which elements of the input are NaN. - """ + """ + Returns which elements of the input are NaN. + """ + + def __init__(self, X): + super().__init__("IsNaN", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('IsNaN', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X) class LabelEncoder(ONNXOp): - """ - Maps each element in the input tensor to another value.
- The mapping is determined by the two parallel attributes, 'keys_*' and - 'values_*' attribute. The i-th value in the specified 'keys_*' attribute - would be mapped to the i-th value in the specified 'values_*' attribute. It - implies that input's element type and the element type of the specified - 'keys_*' should be identical while the output type is identical to the - specified 'values_*' attribute. If an input element can not be found in the - specified 'keys_*' attribute, the 'default_*' that matches the specified - 'values_*' attribute may be used as its output value.
- Let's consider an example which maps a string tensor to an integer tensor. - Assume and 'keys_strings' is ["Amy", "Sally"], 'values_int64s' is [5, 6], - and 'default_int64' is '-1'. The input ["Dori", "Amy", "Amy", "Sally", - "Sally"] would be mapped to [-1, 5, 5, 6, 6].
- Since this operator is an one-to-one mapping, its input and output shapes - are the same. Notice that only one of 'keys_*'/'values_*' can be set.
- For key look-up, bit-wise comparison is used so even a float NaN can be - mapped to a value in 'values_*' attribute.
- """ - - def __init__(self, X, - default_float=None, - default_int64=None, - default_string=None, - keys_floats=None, - keys_int64s=None, - keys_strings=None, - values_floats=None, - values_int64s=None, - values_strings=None): - super().__init__('LabelEncoder', 1, - [{'at::kLong', 'at::kFloat'}], - X, - default_float=ONNXAttr(default_float, AttrType.FLOAT), - default_int64=ONNXAttr(default_int64, AttrType.INT), - default_string=ONNXAttr(default_string, AttrType.STRING), - keys_floats=ONNXAttr(keys_floats, AttrType.FLOATS), - keys_int64s=ONNXAttr(keys_int64s, AttrType.INTS), - keys_strings=ONNXAttr(keys_strings, AttrType.STRINGS), - values_floats=ONNXAttr(values_floats, AttrType.FLOATS), - values_int64s=ONNXAttr(values_int64s, AttrType.INTS), - values_strings=ONNXAttr(values_strings, AttrType.STRINGS)) + """ + Maps each element in the input tensor to another value.
+ The mapping is determined by the two parallel attributes, 'keys_*' and + 'values_*' attribute. The i-th value in the specified 'keys_*' attribute + would be mapped to the i-th value in the specified 'values_*' attribute. It + implies that input's element type and the element type of the specified + 'keys_*' should be identical while the output type is identical to the + specified 'values_*' attribute. If an input element can not be found in the + specified 'keys_*' attribute, the 'default_*' that matches the specified + 'values_*' attribute may be used as its output value.
+ Let's consider an example which maps a string tensor to an integer tensor. + Assume and 'keys_strings' is ["Amy", "Sally"], 'values_int64s' is [5, 6], + and 'default_int64' is '-1'. The input ["Dori", "Amy", "Amy", "Sally", + "Sally"] would be mapped to [-1, 5, 5, 6, 6].
+ Since this operator is an one-to-one mapping, its input and output shapes + are the same. Notice that only one of 'keys_*'/'values_*' can be set.
+ For key look-up, bit-wise comparison is used so even a float NaN can be + mapped to a value in 'values_*' attribute.
+ """ + + def __init__( + self, + X, + default_float=None, + default_int64=None, + default_string=None, + keys_floats=None, + keys_int64s=None, + keys_strings=None, + values_floats=None, + values_int64s=None, + values_strings=None, + ): + super().__init__( + "LabelEncoder", + 1, + [{"at::kLong", "at::kFloat"}], + X, + default_float=ONNXAttr(default_float, AttrType.FLOAT), + default_int64=ONNXAttr(default_int64, AttrType.INT), + default_string=ONNXAttr(default_string, AttrType.STRING), + keys_floats=ONNXAttr(keys_floats, AttrType.FLOATS), + keys_int64s=ONNXAttr(keys_int64s, AttrType.INTS), + keys_strings=ONNXAttr(keys_strings, AttrType.STRINGS), + values_floats=ONNXAttr(values_floats, AttrType.FLOATS), + values_int64s=ONNXAttr(values_int64s, AttrType.INTS), + values_strings=ONNXAttr(values_strings, AttrType.STRINGS), + ) + class LeakyRelu(ONNXOp): - """ - LeakyRelu takes input data (Tensor) and an argument alpha, and produces one - output data (Tensor) where the function `f(x) = alpha * x for x < 0`, - `f(x) = x for x >= 0`, is applied to the data tensor elementwise. - """ - - def __init__(self, X, - alpha=None): - super().__init__('LeakyRelu', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT)) + """ + LeakyRelu takes input data (Tensor) and an argument alpha, and produces one + output data (Tensor) where the function `f(x) = alpha * x for x < 0`, + `f(x) = x for x >= 0`, is applied to the data tensor elementwise. + """ + + def __init__(self, X, alpha=None): + super().__init__( + "LeakyRelu", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X, alpha=ONNXAttr(alpha, AttrType.FLOAT) + ) + class Less(ONNXOp): - """ - Returns the tensor resulted from performing the `less` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('Less', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Returns the tensor resulted from performing the `less` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__( + "Less", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class LessOrEqual(ONNXOp): - """ - Returns the tensor resulted from performing the `less_equal` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('LessOrEqual', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}], - A,B) + """ + Returns the tensor resulted from performing the `less_equal` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__( + "LessOrEqual", + 1, + [ + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + ], + A, + B, + ) + class LinearClassifier(ONNXOp): - """ - Linear classifier - """ - - def __init__(self, X, - classlabels_ints=None, - classlabels_strings=None, - coefficients=None, - intercepts=None, - multi_class=None, - post_transform=None): - super().__init__('LinearClassifier', 2, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - classlabels_ints=ONNXAttr(classlabels_ints, AttrType.INTS), - classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), - coefficients=ONNXAttr(coefficients, AttrType.FLOATS), - intercepts=ONNXAttr(intercepts, AttrType.FLOATS), - multi_class=ONNXAttr(multi_class, AttrType.INT), - post_transform=ONNXAttr(post_transform, AttrType.STRING)) + """ + Linear classifier + """ + + def __init__( + self, + X, + classlabels_ints=None, + classlabels_strings=None, + coefficients=None, + intercepts=None, + multi_class=None, + post_transform=None, + ): + super().__init__( + "LinearClassifier", + 2, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + classlabels_ints=ONNXAttr(classlabels_ints, AttrType.INTS), + classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), + coefficients=ONNXAttr(coefficients, AttrType.FLOATS), + intercepts=ONNXAttr(intercepts, AttrType.FLOATS), + multi_class=ONNXAttr(multi_class, AttrType.INT), + post_transform=ONNXAttr(post_transform, AttrType.STRING), + ) + class LinearRegressor(ONNXOp): - """ - Generalized linear regression evaluation.
- If targets is set to 1 (default) then univariate regression is performed.
- If targets is set to M then M sets of coefficients must be passed in as a sequence - and M results will be output for each input n in N.
- The coefficients array is of length n, and the coefficients for each target are contiguous. - Intercepts are optional but if provided must match the number of targets. - """ - - def __init__(self, X, - coefficients=None, - intercepts=None, - post_transform=None, - targets=None): - super().__init__('LinearRegressor', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - coefficients=ONNXAttr(coefficients, AttrType.FLOATS), - intercepts=ONNXAttr(intercepts, AttrType.FLOATS), - post_transform=ONNXAttr(post_transform, AttrType.STRING), - targets=ONNXAttr(targets, AttrType.INT)) + """ + Generalized linear regression evaluation.
+ If targets is set to 1 (default) then univariate regression is performed.
+ If targets is set to M then M sets of coefficients must be passed in as a sequence + and M results will be output for each input n in N.
+ The coefficients array is of length n, and the coefficients for each target are contiguous. + Intercepts are optional but if provided must match the number of targets. + """ + + def __init__(self, X, coefficients=None, intercepts=None, post_transform=None, targets=None): + super().__init__( + "LinearRegressor", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + coefficients=ONNXAttr(coefficients, AttrType.FLOATS), + intercepts=ONNXAttr(intercepts, AttrType.FLOATS), + post_transform=ONNXAttr(post_transform, AttrType.STRING), + targets=ONNXAttr(targets, AttrType.INT), + ) + class Log(ONNXOp): - """ - Calculates the natural log of the given input tensor, element-wise. - """ + """ + Calculates the natural log of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Log", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Log', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - input) class LogSoftmax(ONNXOp): - """ - The operator computes the log of softmax values for the given input: - - LogSoftmax(input, axis) = Log(Softmax(input, axis=axis)) - - The input does not need to explicitly be a 2D vector. The "axis" attribute - indicates the dimension along which LogSoftmax will be performed. - The output tensor has the same shape - and contains the LogSoftmax values of the corresponding input. - """ - - def __init__(self, input, - axis=None): - super().__init__('LogSoftmax', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - input, - axis=ONNXAttr(axis, AttrType.INT)) + """ + The operator computes the log of softmax values for the given input: + + LogSoftmax(input, axis) = Log(Softmax(input, axis=axis)) + + The input does not need to explicitly be a 2D vector. The "axis" attribute + indicates the dimension along which LogSoftmax will be performed. + The output tensor has the same shape + and contains the LogSoftmax values of the corresponding input. + """ + + def __init__(self, input, axis=None): + super().__init__( + "LogSoftmax", + 1, + [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], + input, + axis=ONNXAttr(axis, AttrType.INT), + ) + class Loop(ONNXOp): - """ - Generic Looping construct. This loop has multiple termination conditions: - - 1) Trip count. Iteration count specified at runtime. Set by - specifying the input M. Optional. Set to empty string to omit. - Note that a static trip count (specified at graph construction time) can be - specified by passing in a constant node for input M. - 2) Loop termination condition. This is an input to the op that determines - whether to run the first iteration and also a loop-carried dependency for - the body graph. The body graph must yield a value for the condition variable, - whether this input is provided or not. - - This table summarizes the operating modes of this operator with equivalent - C-style code: - - Operator inputs defined as (max_trip_count, condition_var). - - input ("", ""): - for (int i=0; ; ++i) { - cond = ... // Note this value is ignored, but is required in the body - } - - input ("", cond) // Note this is analogous to a while loop - bool cond = ...; - for (int i=0; cond; ++i) { - cond = ...; - } - - input ("", 1) // Note this is analogous to a do-while loop - bool cond = true - for (int i=0; cond; ++i) { - cond = ...; - } - - input (trip_count, "") // Note this is analogous to a for loop - int trip_count = ... - for (int i=0; i < trip_count; ++i) { - cond = ...; // ignored - } - - input (trip_count, cond) - int trip_count = ...; - bool cond = ...; - for (int i=0; i < trip_count && cond; ++i) { - cond = ...; - } - - - *Sample usage - cond as well as trip count* - - graph predict-net { - %a = Constant[value = ]() - %b = Constant[value = ]() - %keepgoing = Constant[value = ]() - %max_trip_count = Constant[value = ]() - %keepgoing_out, %b_out, %user_defined_vals = Loop[body = ](%max_trip_count, %keepgoing, %b) - return - } - - graph body-net ( - %i[INT32, scalar] // iteration number - %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used - %b_in[INT32, scalar] // incoming value of loop-carried-dependency b - ) { - %my_local = Add(%a, %b_in) - %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b - %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition - %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated - return %keepgoing_out, %b_out, %user_defined_val - } - - *Sample equivalent C code* - - { - /* User-defined code (enclosing scope) */ - int a = 3, b = 6; - bool keepgoing = true; // Analogous to input cond - /* End user-defined code */ - - /* Implicitly-defined code */ - const int max_trip_count = 10; // Analogous to input M - int user_defined_vals[]; // Imagine this is resizable - /* End implicitly-defined code */ - /* initialize loop-carried variables and scan-output variables */ - bool keepgoing_out = keepgoing - int b_out = b - - for (int i=0; i < max_trip_count && keepgoing_out; ++i) { - /* Implicitly-defined code: bind actual parameter values - to formal parameter variables of loop-body */ - bool keepgoing_in = keepgoing_out; - bool b_in = b_out; - - /* User-defined code (loop body) */ - int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine - b_out = a - b_in; - keepgoing_out = my_local > b_out; - user_defined_val = b_in + b_in; // b_in and b_out are different variables + """ + Generic Looping construct. This loop has multiple termination conditions: + + 1) Trip count. Iteration count specified at runtime. Set by + specifying the input M. Optional. Set to empty string to omit. + Note that a static trip count (specified at graph construction time) can be + specified by passing in a constant node for input M. + 2) Loop termination condition. This is an input to the op that determines + whether to run the first iteration and also a loop-carried dependency for + the body graph. The body graph must yield a value for the condition variable, + whether this input is provided or not. + + This table summarizes the operating modes of this operator with equivalent + C-style code: + + Operator inputs defined as (max_trip_count, condition_var). + + input ("", ""): + for (int i=0; ; ++i) { + cond = ... // Note this value is ignored, but is required in the body + } + + input ("", cond) // Note this is analogous to a while loop + bool cond = ...; + for (int i=0; cond; ++i) { + cond = ...; + } + + input ("", 1) // Note this is analogous to a do-while loop + bool cond = true + for (int i=0; cond; ++i) { + cond = ...; + } + + input (trip_count, "") // Note this is analogous to a for loop + int trip_count = ... + for (int i=0; i < trip_count; ++i) { + cond = ...; // ignored + } + + input (trip_count, cond) + int trip_count = ...; + bool cond = ...; + for (int i=0; i < trip_count && cond; ++i) { + cond = ...; + } + + + *Sample usage - cond as well as trip count* + + graph predict-net { + %a = Constant[value = ]() + %b = Constant[value = ]() + %keepgoing = Constant[value = ]() + %max_trip_count = Constant[value = ]() + %keepgoing_out, %b_out, %user_defined_vals = Loop[body = ](%max_trip_count, %keepgoing, %b) + return + } + + graph body-net ( + %i[INT32, scalar] // iteration number + %keepgoing_in[BOOL, scalar] // incoming loop-termination-condition; not used + %b_in[INT32, scalar] // incoming value of loop-carried-dependency b + ) { + %my_local = Add(%a, %b_in) + %b_out = Sub(%a, %b_in) // outgoing value of loop-carried-dependency b + %keepgoing_out = Greater(%my_local, %b_out) // outgoing loop-termination-condition + %user_defined_val = Add(%b_in, %b_in) // scan-output value to be accumulated + return %keepgoing_out, %b_out, %user_defined_val + } + + *Sample equivalent C code* + + { + /* User-defined code (enclosing scope) */ + int a = 3, b = 6; + bool keepgoing = true; // Analogous to input cond /* End user-defined code */ - - /* Implicitly defined-code */ - user_defined_vals[i] = user_defined_val // accumulate scan-output values + + /* Implicitly-defined code */ + const int max_trip_count = 10; // Analogous to input M + int user_defined_vals[]; // Imagine this is resizable + /* End implicitly-defined code */ + /* initialize loop-carried variables and scan-output variables */ + bool keepgoing_out = keepgoing + int b_out = b + + for (int i=0; i < max_trip_count && keepgoing_out; ++i) { + /* Implicitly-defined code: bind actual parameter values + to formal parameter variables of loop-body */ + bool keepgoing_in = keepgoing_out; + bool b_in = b_out; + + /* User-defined code (loop body) */ + int my_local = a + b_in; // Reading value "a" from the enclosing scope is fine + b_out = a - b_in; + keepgoing_out = my_local > b_out; + user_defined_val = b_in + b_in; // b_in and b_out are different variables + /* End user-defined code */ + + /* Implicitly defined-code */ + user_defined_vals[i] = user_defined_val // accumulate scan-output values + } + // int t = my_local; // Can't do this. my_local is not accessible here. + + // The values below are bound to the output variables of the loop and therefore accessible + // b_out; user_defined_vals; keepgoing_out; } - // int t = my_local; // Can't do this. my_local is not accessible here. - - // The values below are bound to the output variables of the loop and therefore accessible - // b_out; user_defined_vals; keepgoing_out; - } - - There are several things of note in this code snippet: - - 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can - be referenced in the inputs of the loop. - 2) Any values computed in the loop body that needs to be used in a subsequent - iteration or after the loop are modelled using a pair of variables in the loop-body, - consisting of an input variable (eg., b_in) and an output variable (eg., b_out). - These are referred to as loop-carried dependences. The loop operation node - supplies the input value of the input variable for the first iteration, and - returns the output value of the output variable produced by the final - iteration. - 3) Scan_output variables are used to implicitly concatenate values computed across - all the iterations. In the above example, the value of user_defined_val computed - over all iterations are concatenated and returned as the value of user_defined_vals - after the loop. - 4) Values created in the body cannot be accessed in the enclosing scope, - except using the mechanism described above. - - Note that the semantics of this op support "diagonal" or "wavefront" execution. - (See Step 3 here for an example: - https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/). - Frontends should emit multi-layer RNNs as a series of While operators (with - time being the inner looping dimension), with each successive layer consuming - the scan_outputs from the previous layer, possibly going through several - point-wise operators (e.g. dropout, residual connections, linear layer). - - The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order. - """ - - def __init__(self, M, cond, v_initial, - body=None): - super().__init__('Loop', 1, - [{'at::kLong'}, {'at::kBool'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - M,cond,v_initial, - body=ONNXAttr(body, AttrType.GRAPH)) + + There are several things of note in this code snippet: + + 1) Values from the enclosing scope (i.e. variable "a" here) are in scope and can + be referenced in the inputs of the loop. + 2) Any values computed in the loop body that needs to be used in a subsequent + iteration or after the loop are modelled using a pair of variables in the loop-body, + consisting of an input variable (eg., b_in) and an output variable (eg., b_out). + These are referred to as loop-carried dependences. The loop operation node + supplies the input value of the input variable for the first iteration, and + returns the output value of the output variable produced by the final + iteration. + 3) Scan_output variables are used to implicitly concatenate values computed across + all the iterations. In the above example, the value of user_defined_val computed + over all iterations are concatenated and returned as the value of user_defined_vals + after the loop. + 4) Values created in the body cannot be accessed in the enclosing scope, + except using the mechanism described above. + + Note that the semantics of this op support "diagonal" or "wavefront" execution. + (See Step 3 here for an example: + https://devblogs.nvidia.com/optimizing-recurrent-neural-networks-cudnn-5/). + Frontends should emit multi-layer RNNs as a series of While operators (with + time being the inner looping dimension), with each successive layer consuming + the scan_outputs from the previous layer, possibly going through several + point-wise operators (e.g. dropout, residual connections, linear layer). + + The input/output of subgraph (produced by loop node) matching is based on order instead of name. The implementation will figure out the names based on this order. + """ + + def __init__(self, M, cond, v_initial, body=None): + super().__init__( + "Loop", + 1, + [ + {"at::kLong"}, + {"at::kBool"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + ], + M, + cond, + v_initial, + body=ONNXAttr(body, AttrType.GRAPH), + ) + class LpNormalization(ONNXOp): - """ - Given a matrix, apply Lp-normalization along the provided axis. - """ - - def __init__(self, input, - axis=None, - p=None): - super().__init__('LpNormalization', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input, - axis=ONNXAttr(axis, AttrType.INT), - p=ONNXAttr(p, AttrType.INT)) + """ + Given a matrix, apply Lp-normalization along the provided axis. + """ + + def __init__(self, input, axis=None, p=None): + super().__init__( + "LpNormalization", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}], + input, + axis=ONNXAttr(axis, AttrType.INT), + p=ONNXAttr(p, AttrType.INT), + ) + class LpPool(ONNXOp): - """ - LpPool consumes an input tensor X and applies Lp pooling across - the tensor according to kernel sizes, stride sizes, and pad lengths. - Lp pooling consisting of computing the Lp norm on all values of a subset - of the input tensor according to the kernel size and downsampling the - data into the output tensor Y for further processing. - """ - - def __init__(self, X, - auto_pad=None, - kernel_shape=None, - p=None, - pads=None, - strides=None): - super().__init__('LpPool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - p=ONNXAttr(p, AttrType.INT), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + LpPool consumes an input tensor X and applies Lp pooling across + the tensor according to kernel sizes, stride sizes, and pad lengths. + Lp pooling consisting of computing the Lp norm on all values of a subset + of the input tensor according to the kernel size and downsampling the + data into the output tensor Y for further processing. + """ + + def __init__(self, X, auto_pad=None, kernel_shape=None, p=None, pads=None, strides=None): + super().__init__( + "LpPool", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}], + X, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + p=ONNXAttr(p, AttrType.INT), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class LRN(ONNXOp): - """ - Local Response Normalization proposed in the [AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). - It normalizes over local input regions. - The local region is defined across the channels. For an element X[n, c, d1, ..., dk] in a tensor - of shape (N x C x D1 x D2, ..., Dk), its region is - {X[n, i, d1, ..., dk] | max(0, c - floor((size - 1) / 2)) <= i <= min(C - 1, c + ceil((size - 1) / 2))}. - - square_sum[n, c, d1, ..., dk] = sum(X[n, i, d1, ..., dk] ^ 2), - where max(0, c - floor((size - 1) / 2)) <= i <= min(C - 1, c + ceil((size - 1) / 2)). - - Y[n, c, d1, ..., dk] = X[n, c, d1, ..., dk] / (bias + alpha / size * square_sum[n, c, d1, ..., dk] ) ^ beta - """ - - def __init__(self, X, - alpha=None, - beta=None, - bias=None, - size=None): - super().__init__('LRN', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT), - beta=ONNXAttr(beta, AttrType.FLOAT), - bias=ONNXAttr(bias, AttrType.FLOAT), - size=ONNXAttr(size, AttrType.INT)) + """ + Local Response Normalization proposed in the [AlexNet paper](https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf). + It normalizes over local input regions. + The local region is defined across the channels. For an element X[n, c, d1, ..., dk] in a tensor + of shape (N x C x D1 x D2, ..., Dk), its region is + {X[n, i, d1, ..., dk] | max(0, c - floor((size - 1) / 2)) <= i <= min(C - 1, c + ceil((size - 1) / 2))}. + + square_sum[n, c, d1, ..., dk] = sum(X[n, i, d1, ..., dk] ^ 2), + where max(0, c - floor((size - 1) / 2)) <= i <= min(C - 1, c + ceil((size - 1) / 2)). + + Y[n, c, d1, ..., dk] = X[n, c, d1, ..., dk] / (bias + alpha / size * square_sum[n, c, d1, ..., dk] ) ^ beta + """ + + def __init__(self, X, alpha=None, beta=None, bias=None, size=None): + super().__init__( + "LRN", + 1, + [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], + X, + alpha=ONNXAttr(alpha, AttrType.FLOAT), + beta=ONNXAttr(beta, AttrType.FLOAT), + bias=ONNXAttr(bias, AttrType.FLOAT), + size=ONNXAttr(size, AttrType.INT), + ) + class LSTM(ONNXOp): - """ - Computes an one-layer LSTM. This operator is usually supported via some - custom implementation such as CuDNN. - - Notations: - - `X` - input tensor - - `i` - input gate - - `o` - output gate - - `f` - forget gate - - `c` - cell gate - - `t` - time step (t-1 means previous time step) - - `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates - - `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates - - `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates - - `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates - - `P[iof]` - P peephole weight vector for input, output, and forget gates - - `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates - - `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates - - `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates - - `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates - - `PB[iof]` - P peephole weight vector for backward input, output, and forget gates - - `H` - Hidden state - - `num_directions` - 2 if direction == bidirectional else 1 - - Activation functions: - - Relu(x) - max(0, x) - - Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) - - Sigmoid(x) - 1/(1 + e^{-x}) - - (NOTE: Below are optional) - - Affine(x) - alpha*x + beta - - LeakyRelu(x) - x if x >= 0 else alpha * x - - ThresholdedRelu(x) - x if x >= alpha else 0 - - ScaledTanh(x) - alpha*Tanh(beta*x) - - HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) - - Elu(x) - x if x >= 0 else alpha*(e^x - 1) - - Softsign(x) - x/(1 + |x|) - - Softplus(x) - log(1 + e^x) - - Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): - - - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) - - - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) - - - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) - - - Ct = ft (.) Ct-1 + it (.) ct - - - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) - - - Ht = ot (.) h(Ct) - This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. - """ - - def __init__(self, X, W, R, B, sequence_lens, initial_h, initial_c, P, - activation_alpha=None, - activation_beta=None, - activations=None, - clip=None, - direction=None, - hidden_size=None, - input_forget=None, - layout=None): - super().__init__('LSTM', 3, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kInt'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X,W,R,B,sequence_lens,initial_h,initial_c,P, - activation_alpha=ONNXAttr(activation_alpha, AttrType.FLOATS), - activation_beta=ONNXAttr(activation_beta, AttrType.FLOATS), - activations=ONNXAttr(activations, AttrType.STRINGS), - clip=ONNXAttr(clip, AttrType.FLOAT), - direction=ONNXAttr(direction, AttrType.STRING), - hidden_size=ONNXAttr(hidden_size, AttrType.INT), - input_forget=ONNXAttr(input_forget, AttrType.INT), - layout=ONNXAttr(layout, AttrType.INT)) + """ + Computes an one-layer LSTM. This operator is usually supported via some + custom implementation such as CuDNN. + + Notations: + + `X` - input tensor + + `i` - input gate + + `o` - output gate + + `f` - forget gate + + `c` - cell gate + + `t` - time step (t-1 means previous time step) + + `W[iofc]` - W parameter weight matrix for input, output, forget, and cell gates + + `R[iofc]` - R recurrence weight matrix for input, output, forget, and cell gates + + `Wb[iofc]` - W bias vectors for input, output, forget, and cell gates + + `Rb[iofc]` - R bias vectors for input, output, forget, and cell gates + + `P[iof]` - P peephole weight vector for input, output, and forget gates + + `WB[iofc]` - W parameter weight matrix for backward input, output, forget, and cell gates + + `RB[iofc]` - R recurrence weight matrix for backward input, output, forget, and cell gates + + `WBb[iofc]` - W bias vectors for backward input, output, forget, and cell gates + + `RBb[iofc]` - R bias vectors for backward input, output, forget, and cell gates + + `PB[iof]` - P peephole weight vector for backward input, output, and forget gates + + `H` - Hidden state + + `num_directions` - 2 if direction == bidirectional else 1 + + Activation functions: + + Relu(x) - max(0, x) + + Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) + + Sigmoid(x) - 1/(1 + e^{-x}) + + (NOTE: Below are optional) + + Affine(x) - alpha*x + beta + + LeakyRelu(x) - x if x >= 0 else alpha * x + + ThresholdedRelu(x) - x if x >= alpha else 0 + + ScaledTanh(x) - alpha*Tanh(beta*x) + + HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) + + Elu(x) - x if x >= 0 else alpha*(e^x - 1) + + Softsign(x) - x/(1 + |x|) + + Softplus(x) - log(1 + e^x) + + Equations (Default: f=Sigmoid, g=Tanh, h=Tanh): + + - it = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Pi (.) Ct-1 + Wbi + Rbi) + + - ft = f(Xt*(Wf^T) + Ht-1*(Rf^T) + Pf (.) Ct-1 + Wbf + Rbf) + + - ct = g(Xt*(Wc^T) + Ht-1*(Rc^T) + Wbc + Rbc) + + - Ct = ft (.) Ct-1 + it (.) ct + + - ot = f(Xt*(Wo^T) + Ht-1*(Ro^T) + Po (.) Ct + Wbo + Rbo) + + - Ht = ot (.) h(Ct) + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + """ + + def __init__( + self, + X, + W, + R, + B, + sequence_lens, + initial_h, + initial_c, + P, + activation_alpha=None, + activation_beta=None, + activations=None, + clip=None, + direction=None, + hidden_size=None, + input_forget=None, + layout=None, + ): + super().__init__( + "LSTM", + 3, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kInt"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + X, + W, + R, + B, + sequence_lens, + initial_h, + initial_c, + P, + activation_alpha=ONNXAttr(activation_alpha, AttrType.FLOATS), + activation_beta=ONNXAttr(activation_beta, AttrType.FLOATS), + activations=ONNXAttr(activations, AttrType.STRINGS), + clip=ONNXAttr(clip, AttrType.FLOAT), + direction=ONNXAttr(direction, AttrType.STRING), + hidden_size=ONNXAttr(hidden_size, AttrType.INT), + input_forget=ONNXAttr(input_forget, AttrType.INT), + layout=ONNXAttr(layout, AttrType.INT), + ) + class MatMul(ONNXOp): - """ - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html - """ + """ + Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html + """ + + def __init__(self, A, B): + super().__init__( + "MatMul", + 1, + [ + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + ], + A, + B, + ) - def __init__(self, A, B): - super().__init__('MatMul', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - A,B) class MatMulInteger(ONNXOp): - """ - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. - The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. - """ + """ + Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + The production MUST never overflow. The accumulation may overflow if and only if in 32 bits. + """ + + def __init__(self, A, B, a_zero_point, b_zero_point): + super().__init__( + "MatMulInteger", + 1, + [{"at::kByte"}, {"at::kByte"}, {"at::kByte"}, {"at::kByte"}], + A, + B, + a_zero_point, + b_zero_point, + ) - def __init__(self, A, B, a_zero_point, b_zero_point): - super().__init__('MatMulInteger', 1, - [{'at::kByte'}, {'at::kByte'}, {'at::kByte'}, {'at::kByte'}], - A,B,a_zero_point,b_zero_point) class Max(ONNXOp): - """ - Element-wise max of each of the input tensors (with Numpy-style broadcasting support). - All inputs and outputs must have the same data type. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ + """ + Element-wise max of each of the input tensors (with Numpy-style broadcasting support). + All inputs and outputs must have the same data type. + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, data_0): + super().__init__( + "Max", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + data_0, + ) - def __init__(self, data_0): - super().__init__('Max', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - data_0) class MaxPool(ONNXOp): - """ - MaxPool consumes an input tensor X and applies max pooling across - the tensor according to kernel sizes, stride sizes, and pad lengths. - max pooling consisting of computing the max on all values of a - subset of the input tensor according to the kernel size and downsampling the - data into the output tensor Y for further processing. The output spatial shape will be following: - ``` - output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) - ``` - or - ``` - output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) - ``` - if ceil_mode is enabled - - ``` - * pad_shape[i] is sum of pads along axis i - ``` - - `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: - ``` - VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) + 1) / strides_spatial_shape[i]) - SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) - ``` - And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: - ``` - pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i] - ``` - The output of each pooling window is maximum number of elements exclude pad. - - """ - - def __init__(self, X, - auto_pad=None, - ceil_mode=None, - dilations=None, - kernel_shape=None, - pads=None, - storage_order=None, - strides=None): - super().__init__('MaxPool', 2, - [{'at::kDouble', 'at::kByte', 'at::kHalf', 'at::kFloat'}], - X, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - ceil_mode=ONNXAttr(ceil_mode, AttrType.INT), - dilations=ONNXAttr(dilations, AttrType.INTS), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - storage_order=ONNXAttr(storage_order, AttrType.INT), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + MaxPool consumes an input tensor X and applies max pooling across + the tensor according to kernel sizes, stride sizes, and pad lengths. + max pooling consisting of computing the max on all values of a + subset of the input tensor according to the kernel size and downsampling the + data into the output tensor Y for further processing. The output spatial shape will be following: + ``` + output_spatial_shape[i] = floor((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) + ``` + or + ``` + output_spatial_shape[i] = ceil((input_spatial_shape[i] + pad_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1)) / strides_spatial_shape[i] + 1) + ``` + if ceil_mode is enabled + + ``` + * pad_shape[i] is sum of pads along axis i + ``` + + `auto_pad` is a DEPRECATED attribute. If you are using them currently, the output spatial shape will be following: + ``` + VALID: output_spatial_shape[i] = ceil((input_spatial_shape[i] - ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) + 1) / strides_spatial_shape[i]) + SAME_UPPER or SAME_LOWER: output_spatial_shape[i] = ceil(input_spatial_shape[i] / strides_spatial_shape[i]) + ``` + And pad shape will be following if `SAME_UPPER` or `SAME_LOWER`: + ``` + pad_shape[i] = (output_spatial_shape[i] - 1) * strides_spatial_shape[i] + ((kernel_spatial_shape[i] - 1) * dilations[i] + 1) - input_spatial_shape[i] + ``` + The output of each pooling window is maximum number of elements exclude pad. + + """ + + def __init__( + self, + X, + auto_pad=None, + ceil_mode=None, + dilations=None, + kernel_shape=None, + pads=None, + storage_order=None, + strides=None, + ): + super().__init__( + "MaxPool", + 2, + [{"at::kDouble", "at::kByte", "at::kHalf", "at::kFloat"}], + X, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + ceil_mode=ONNXAttr(ceil_mode, AttrType.INT), + dilations=ONNXAttr(dilations, AttrType.INTS), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + storage_order=ONNXAttr(storage_order, AttrType.INT), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class MaxRoiPool(ONNXOp): - """ - ROI max pool consumes an input tensor X and region of interests (RoIs) to - apply max pooling across each RoI, to produce output 4-D tensor of shape - (num_rois, channels, pooled_shape[0], pooled_shape[1]). - """ - - def __init__(self, X, rois, - pooled_shape=None, - spatial_scale=None): - super().__init__('MaxRoiPool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X,rois, - pooled_shape=ONNXAttr(pooled_shape, AttrType.INTS), - spatial_scale=ONNXAttr(spatial_scale, AttrType.FLOAT)) + """ + ROI max pool consumes an input tensor X and region of interests (RoIs) to + apply max pooling across each RoI, to produce output 4-D tensor of shape + (num_rois, channels, pooled_shape[0], pooled_shape[1]). + """ + + def __init__(self, X, rois, pooled_shape=None, spatial_scale=None): + super().__init__( + "MaxRoiPool", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}, {"at::kDouble", "at::kHalf", "at::kFloat"}], + X, + rois, + pooled_shape=ONNXAttr(pooled_shape, AttrType.INTS), + spatial_scale=ONNXAttr(spatial_scale, AttrType.FLOAT), + ) + class MaxUnpool(ONNXOp): - """ - MaxUnpool essentially computes the partial inverse of the MaxPool op. - The input information to this op is typically the the output information from a MaxPool op. The first - input tensor X is the tensor that needs to be unpooled, which is typically the pooled tensor (first output) - from MaxPool. The second input tensor, I, contains the indices to the (locally maximal) elements corrsponding - to the elements in the first input tensor X. Input tensor I is typically the second output of the MaxPool op. - The third (optional) input is a tensor that specifies the output size of the unpooling operation. - - MaxUnpool is intended to do 'partial' inverse of the MaxPool op. 'Partial' because all the non-maximal - values from the original input to MaxPool are set to zero in the output of the MaxUnpool op. Pooling - the result of an unpooling operation should give back the original input to the unpooling op. - - MaxUnpool can produce the same output size for several input sizes, which makes unpooling op ambiguous. - The third input argument, output_size, is meant to disambiguate the op and produce output tensor of - known/predictable size. - - In addition to the inputs, MaxUnpool takes three attributes, namely kernel_shape, strides, and pads, - which define the exact unpooling op. The attributes typically have the same values as the corrsponding - pooling op that the unpooling op is trying to invert. - """ - - def __init__(self, X, I, output_shape, - kernel_shape=None, - pads=None, - strides=None): - super().__init__('MaxUnpool', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kLong'}, {'at::kLong'}], - X,I,output_shape, - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + MaxUnpool essentially computes the partial inverse of the MaxPool op. + The input information to this op is typically the the output information from a MaxPool op. The first + input tensor X is the tensor that needs to be unpooled, which is typically the pooled tensor (first output) + from MaxPool. The second input tensor, I, contains the indices to the (locally maximal) elements corrsponding + to the elements in the first input tensor X. Input tensor I is typically the second output of the MaxPool op. + The third (optional) input is a tensor that specifies the output size of the unpooling operation. + + MaxUnpool is intended to do 'partial' inverse of the MaxPool op. 'Partial' because all the non-maximal + values from the original input to MaxPool are set to zero in the output of the MaxUnpool op. Pooling + the result of an unpooling operation should give back the original input to the unpooling op. + + MaxUnpool can produce the same output size for several input sizes, which makes unpooling op ambiguous. + The third input argument, output_size, is meant to disambiguate the op and produce output tensor of + known/predictable size. + + In addition to the inputs, MaxUnpool takes three attributes, namely kernel_shape, strides, and pads, + which define the exact unpooling op. The attributes typically have the same values as the corrsponding + pooling op that the unpooling op is trying to invert. + """ + + def __init__(self, X, I, output_shape, kernel_shape=None, pads=None, strides=None): + super().__init__( + "MaxUnpool", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}, {"at::kLong"}, {"at::kLong"}], + X, + I, + output_shape, + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class Mean(ONNXOp): - """ - Element-wise mean of each of the input tensors (with Numpy-style broadcasting support). - All inputs and outputs must have the same data type. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ + """ + Element-wise mean of each of the input tensors (with Numpy-style broadcasting support). + All inputs and outputs must have the same data type. + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, data_0): + super().__init__("Mean", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], data_0) - def __init__(self, data_0): - super().__init__('Mean', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - data_0) class MeanVarianceNormalization(ONNXOp): - """ - A MeanVarianceNormalization Function: Perform mean variance normalization - on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ``` - """ - - def __init__(self, X, - axes=None): - super().__init__('MeanVarianceNormalization', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X, - axes=ONNXAttr(axes, AttrType.INTS)) + """ + A MeanVarianceNormalization Function: Perform mean variance normalization + on the input tensor X using formula:
``` (X-EX)/sqrt(E(X-EX)^2) ``` + """ + + def __init__(self, X, axes=None): + super().__init__( + "MeanVarianceNormalization", + 1, + [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], + X, + axes=ONNXAttr(axes, AttrType.INTS), + ) + class Min(ONNXOp): - """ - Element-wise min of each of the input tensors (with Numpy-style broadcasting support). - All inputs and outputs must have the same data type. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ + """ + Element-wise min of each of the input tensors (with Numpy-style broadcasting support). + All inputs and outputs must have the same data type. + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, data_0): + super().__init__( + "Min", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + data_0, + ) - def __init__(self, data_0): - super().__init__('Min', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - data_0) class Mod(ONNXOp): - """ + """ Performs element-wise binary modulus (with Numpy-style broadcasting support). The sign of the remainder is the same as that of the Divisor. - + Mod operator can also behave like C fmod() or numpy.fmod. In this case, the sign of the remainder however, will be the same as the Dividend (in contrast to integer mod). To force a behavior like numpy.fmod() an 'fmod' Attribute is provided. This attribute is set to 0 by default causing the behavior to be like integer mod. Setting this attribute to 1 causes the remainder to be calculated similar to that of numpy.fmod(). - + If the input type is floating point, then `fmod` attribute must be set to 1. - + In case of dividend being zero, the results will be platform dependent. - + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ + """ + + def __init__(self, A, B, fmod=None): + super().__init__( + "Mod", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + fmod=ONNXAttr(fmod, AttrType.INT), + ) - def __init__(self, A, B, - fmod=None): - super().__init__('Mod', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B, - fmod=ONNXAttr(fmod, AttrType.INT)) class Momentum(ONNXOp): - """ - Compute one iteration of stochastic gradient update with momentum. - This operator can conduct the optimization of multiple tensor variables. - - Let's define the behavior of this operator. As you can imagine, SG with momentum requires - several parameters: - - - The learning-rate "R". - - The update count "T". That is, the number of conducted training iterations. It should - be zero in the first training iteration. - - A L2-norm regularization coefficient "norm_coefficient". - - A decay coefficient of previous accumulated gradient (i.e., momentum) "alpha". - - The scaling coefficient of current gradient "beta". - - An attribute to choose either standard momentum or Nesterov's momentum "mode" should - be used. - - For the sake of simplicity, assume that there is only one tensor (called "X") to be optimized. - Other necessary inputs are "X"'s gradient (called "G") and "X"'s momentum (called "V"). This - Momentum operator maps all these inputs to the new value of "X" (called "X_new") and its new - momentum (called "V_new"). - - This operator supports two different momentum algorithms. Set the attribute "mode" to - "nesterov" if Nesterov's momentum is desired. Otherwise, set the attribute "model" to - "standard" to use standard momentum. Computation details are described subsequently. - - Let "+", "-", "*", and "/" are all element-wise operations with numpy-style broadcasting. - - Pseudo code for SG with standard momentum: - - // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared - // values of all elements in X. - G_regularized = norm_coefficient * X + G - - // In the first training iteration, beta should always be 1. - beta_adjusted = T > 0 ? beta : 1 - - // Compute the current momentum based on previous momentum and the current gradient. - V_new = alpha * V + beta_adjusted * G_regularized - - // Update X. - X_new = X - R * V_new - - Pseudo code for SG with Nesterov's momentum: - - // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared - // values of all elements in X. - G_regularized = norm_coefficient * X + G; - - // In the first training iteration, beta should always be 1. - beta_adjusted = T > 0 ? beta : 1 - - // Compute the current momentum based on previous momentum and the current gradient. - V_new = alpha * V + beta_adjusted * G_regularized; - - // Compute final update direction and then update X. - X_new = X - R * (G_regularized + alpha * V_new) - - If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2". The same - pseudo code would be extended to handle all tensors jointly. More specifically, we can view "X" as a - concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should - be concatenated too) and then our pseudo code becomes applicable. - """ - - def __init__(self, R, T, inputs, - alpha=None, - beta=None, - mode=None, - norm_coefficient=None): - super().__init__('Momentum', 1, - [{'at::kDouble', 'at::kFloat'}, {'at::kLong'}, {'at::kDouble', 'at::kFloat'}], - R,T,inputs, - alpha=ONNXAttr(alpha, AttrType.FLOAT), - beta=ONNXAttr(beta, AttrType.FLOAT), - mode=ONNXAttr(mode, AttrType.STRING), - norm_coefficient=ONNXAttr(norm_coefficient, AttrType.FLOAT)) + """ + Compute one iteration of stochastic gradient update with momentum. + This operator can conduct the optimization of multiple tensor variables. + + Let's define the behavior of this operator. As you can imagine, SG with momentum requires + several parameters: + + - The learning-rate "R". + - The update count "T". That is, the number of conducted training iterations. It should + be zero in the first training iteration. + - A L2-norm regularization coefficient "norm_coefficient". + - A decay coefficient of previous accumulated gradient (i.e., momentum) "alpha". + - The scaling coefficient of current gradient "beta". + - An attribute to choose either standard momentum or Nesterov's momentum "mode" should + be used. + + For the sake of simplicity, assume that there is only one tensor (called "X") to be optimized. + Other necessary inputs are "X"'s gradient (called "G") and "X"'s momentum (called "V"). This + Momentum operator maps all these inputs to the new value of "X" (called "X_new") and its new + momentum (called "V_new"). + + This operator supports two different momentum algorithms. Set the attribute "mode" to + "nesterov" if Nesterov's momentum is desired. Otherwise, set the attribute "model" to + "standard" to use standard momentum. Computation details are described subsequently. + + Let "+", "-", "*", and "/" are all element-wise operations with numpy-style broadcasting. + + Pseudo code for SG with standard momentum: + + // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared + // values of all elements in X. + G_regularized = norm_coefficient * X + G + + // In the first training iteration, beta should always be 1. + beta_adjusted = T > 0 ? beta : 1 + + // Compute the current momentum based on previous momentum and the current gradient. + V_new = alpha * V + beta_adjusted * G_regularized + + // Update X. + X_new = X - R * V_new + + Pseudo code for SG with Nesterov's momentum: + + // Add gradient of 0.5 * norm_coefficient * ||X||^2, where ||X|| is the sum of squared + // values of all elements in X. + G_regularized = norm_coefficient * X + G; + + // In the first training iteration, beta should always be 1. + beta_adjusted = T > 0 ? beta : 1 + + // Compute the current momentum based on previous momentum and the current gradient. + V_new = alpha * V + beta_adjusted * G_regularized; + + // Compute final update direction and then update X. + X_new = X - R * (G_regularized + alpha * V_new) + + If one assign this operators to optimize multiple inputs, for example, "X_1" and "X_2". The same + pseudo code would be extended to handle all tensors jointly. More specifically, we can view "X" as a + concatenation of "X_1" and "X_2" (of course, their gradient and accumulate gradient should + be concatenated too) and then our pseudo code becomes applicable. + """ + + def __init__(self, R, T, inputs, alpha=None, beta=None, mode=None, norm_coefficient=None): + super().__init__( + "Momentum", + 1, + [{"at::kDouble", "at::kFloat"}, {"at::kLong"}, {"at::kDouble", "at::kFloat"}], + R, + T, + inputs, + alpha=ONNXAttr(alpha, AttrType.FLOAT), + beta=ONNXAttr(beta, AttrType.FLOAT), + mode=ONNXAttr(mode, AttrType.STRING), + norm_coefficient=ONNXAttr(norm_coefficient, AttrType.FLOAT), + ) + class Mul(ONNXOp): - """ - Performs element-wise binary multiplication (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - - (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. - """ - - def __init__(self, A, B): - super().__init__('Mul', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Performs element-wise binary multiplication (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + + (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. + """ + + def __init__(self, A, B): + super().__init__( + "Mul", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class Multinomial(ONNXOp): - """ - Generate a tensor of samples from a multinomial distribution according to the probabilities - of each of the possible outcomes. - """ - - def __init__(self, input, - dtype=None, - sample_size=None, - seed=None): - super().__init__('Multinomial', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input, - dtype=ONNXAttr(dtype, AttrType.INT), - sample_size=ONNXAttr(sample_size, AttrType.INT), - seed=ONNXAttr(seed, AttrType.FLOAT)) + """ + Generate a tensor of samples from a multinomial distribution according to the probabilities + of each of the possible outcomes. + """ + + def __init__(self, input, dtype=None, sample_size=None, seed=None): + super().__init__( + "Multinomial", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}], + input, + dtype=ONNXAttr(dtype, AttrType.INT), + sample_size=ONNXAttr(sample_size, AttrType.INT), + seed=ONNXAttr(seed, AttrType.FLOAT), + ) + class Neg(ONNXOp): - """ - Neg takes one input data (Tensor) and produces one output data - (Tensor) where each element flipped sign, y = -x, is applied to - the tensor elementwise. - """ + """ + Neg takes one input data (Tensor) and produces one output data + (Tensor) where each element flipped sign, y = -x, is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__( + "Neg", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + X, + ) - def __init__(self, X): - super().__init__('Neg', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - X) class NegativeLogLikelihoodLoss(ONNXOp): - """ - A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss. - Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0. - The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C). - The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes) - or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples. - The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as: - - loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k]. - - When an optional "weight" is provided, the sample loss is calculated as: - - loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c]. - - loss is zero for the case when target-value equals ignore_index. - - loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index - - If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk). - If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged: - - mean(loss), if "weight" is not provided, - - or if weight is provided, - - sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]]), for all samples. - - If "reduction" attribute is set to "sum", the output is a scalar: - sum(loss). - - See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss. - - Example 1: - - // negative log likelihood loss, "none" reduction - N, C, d1 = 2, 3, 2 - input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], - [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] - target = [[2, 1], [0, 2]] - - loss = np.zeros((N, d1)) - for n in range(N): - for d_1 in range(d1): - c = target[n][d_1] - loss[n][d_1] = -input[n][c][d_1] - - // print(loss) - // [[-3. -2.] - // [-0. -2.]] - - Example 2: - - // weighted negative log likelihood loss, sum reduction - N, C, d1 = 2, 3, 2 - input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], - [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] - target = [[2, 1], [0, 2]] - weight = [0.2, 0.3, 0.1] - loss = np.zeros((N, d1)) - for n in range(N): - for d_1 in range(d1): - c = target[n][d_1] - loss[n][d_1] = -input[n][c][d_1] * weight[c] - - loss = np.sum(loss) - // print(loss) - // -1.1 - - Example 3: - - // weighted negative log likelihood loss, mean reduction - N, C, d1 = 2, 3, 2 - input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], - [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] - target = [[2, 1], [0, 2]] - weight = [0.2, 0.3, 0.1] - loss = np.zeros((N, d1)) - weight_total = 0 - for n in range(N): - for d_1 in range(d1): - c = target[n][d_1] - loss[n][d_1] = -input[n][c][d_1] * weight[c] - weight_total = weight_total + weight[c] - - loss = np.sum(loss) / weight_total - // print(loss) - // -1.57 - """ - - def __init__(self, input, target, weight, - ignore_index=None, - reduction=None): - super().__init__('NegativeLogLikelihoodLoss', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kLong', 'at::kInt'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input,target,weight, - ignore_index=ONNXAttr(ignore_index, AttrType.INT), - reduction=ONNXAttr(reduction, AttrType.STRING)) + """ + A NegativeLogLikelihoodLoss operator computes (weighted) negative log likelihood loss. + Its "input" tensor has the shape of (N, C, d1, d2, ..., dk) where k >= 0. + The "input" tensor contains log-probabilities for input[n, :, d_1, d_2,..., d_k] being in a class of [0, C). + The operator's "target" input tensor has the shape of (N, d1, d2, ..., dk). It encodes class labels (one of C classes) + or it may contain a special value (indicated by an attribute ignore_index) for N x d1 x d2 x ... x dk samples. + The loss value for input[n, :, d_1, d_2,...d_k] being classified as class c = target[n][d_1][d_2]...[d_k] is computed as: + + loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k]. + + When an optional "weight" is provided, the sample loss is calculated as: + + loss[n][d_1][d_2]...[d_k] = -input[n][c][d_1][d_2]...[d_k] * weight[c]. + + loss is zero for the case when target-value equals ignore_index. + + loss[n][d_1][d_2]...[d_k] = 0, when target[n][d_1][d_2]...[d_k] = ignore_index + + If "reduction" attribute is set to "none", the operator's output will be the above loss with shape (N, d1, d2, ..., dk). + If "reduction" attribute is set to "mean" (the default attribute value), the output loss is (weight) averaged: + + mean(loss), if "weight" is not provided, + + or if weight is provided, + + sum(loss) / sum(weight[target[n][d_1][d_2]...[d_k]]]), for all samples. + + If "reduction" attribute is set to "sum", the output is a scalar: + sum(loss). + + See also https://pytorch.org/docs/stable/nn.html#torch.nn.NLLLoss. + + Example 1: + + // negative log likelihood loss, "none" reduction + N, C, d1 = 2, 3, 2 + input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], + [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] + target = [[2, 1], [0, 2]] + + loss = np.zeros((N, d1)) + for n in range(N): + for d_1 in range(d1): + c = target[n][d_1] + loss[n][d_1] = -input[n][c][d_1] + + // print(loss) + // [[-3. -2.] + // [-0. -2.]] + + Example 2: + + // weighted negative log likelihood loss, sum reduction + N, C, d1 = 2, 3, 2 + input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], + [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] + target = [[2, 1], [0, 2]] + weight = [0.2, 0.3, 0.1] + loss = np.zeros((N, d1)) + for n in range(N): + for d_1 in range(d1): + c = target[n][d_1] + loss[n][d_1] = -input[n][c][d_1] * weight[c] + + loss = np.sum(loss) + // print(loss) + // -1.1 + + Example 3: + + // weighted negative log likelihood loss, mean reduction + N, C, d1 = 2, 3, 2 + input = [[[1.0, 2.0], [2.0, 2.0], [3.0, 2.0]], + [[0.0, 1.0], [2.0, 2.0], [1.0, 2]]] + target = [[2, 1], [0, 2]] + weight = [0.2, 0.3, 0.1] + loss = np.zeros((N, d1)) + weight_total = 0 + for n in range(N): + for d_1 in range(d1): + c = target[n][d_1] + loss[n][d_1] = -input[n][c][d_1] * weight[c] + weight_total = weight_total + weight[c] + + loss = np.sum(loss) / weight_total + // print(loss) + // -1.57 + """ + + def __init__(self, input, target, weight, ignore_index=None, reduction=None): + super().__init__( + "NegativeLogLikelihoodLoss", + 1, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kLong", "at::kInt"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + input, + target, + weight, + ignore_index=ONNXAttr(ignore_index, AttrType.INT), + reduction=ONNXAttr(reduction, AttrType.STRING), + ) + class NonMaxSuppression(ONNXOp): - """ - Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes. - Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box. - Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to - orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system - result in the same boxes being selected by the algorithm. - The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes. - The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation. - """ - - def __init__(self, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, - center_point_box=None): - super().__init__('NonMaxSuppression', 1, - [{'at::kFloat'}, {'at::kFloat'}, {'at::kLong'}, {'at::kFloat'}, {'at::kFloat'}], - boxes,scores,max_output_boxes_per_class,iou_threshold,score_threshold, - center_point_box=ONNXAttr(center_point_box, AttrType.INT)) + """ + Filter out boxes that have high intersection-over-union (IOU) overlap with previously selected boxes. + Bounding boxes with score less than score_threshold are removed. Bounding box format is indicated by attribute center_point_box. + Note that this algorithm is agnostic to where the origin is in the coordinate system and more generally is invariant to + orthogonal transformations and translations of the coordinate system; thus translating or reflections of the coordinate system + result in the same boxes being selected by the algorithm. + The selected_indices output is a set of integers indexing into the input collection of bounding boxes representing the selected boxes. + The bounding box coordinates corresponding to the selected indices can then be obtained using the Gather or GatherND operation. + """ + + def __init__( + self, boxes, scores, max_output_boxes_per_class, iou_threshold, score_threshold, center_point_box=None + ): + super().__init__( + "NonMaxSuppression", + 1, + [{"at::kFloat"}, {"at::kFloat"}, {"at::kLong"}, {"at::kFloat"}, {"at::kFloat"}], + boxes, + scores, + max_output_boxes_per_class, + iou_threshold, + score_threshold, + center_point_box=ONNXAttr(center_point_box, AttrType.INT), + ) + class NonZero(ONNXOp): - """ - Returns the indices of the elements that are non-zero - (in row-major order - by dimension). - NonZero behaves similar to numpy.nonzero: - https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html - """ - - def __init__(self, X): - super().__init__('NonZero', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - X) + """ + Returns the indices of the elements that are non-zero + (in row-major order - by dimension). + NonZero behaves similar to numpy.nonzero: + https://docs.scipy.org/doc/numpy/reference/generated/numpy.nonzero.html + """ + + def __init__(self, X): + super().__init__( + "NonZero", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + X, + ) + class Normalizer(ONNXOp): - """ - Normalize the input. There are three normalization modes, which have the corresponding formulas, - defined using element-wise infix operators '/' and '^' and tensor-wide functions 'max' and 'sum':
-
- Max: Y = X / max(X)
- L1: Y = X / sum(X)
- L2: Y = sqrt(X^2 / sum(X^2)}
- In all modes, if the divisor is zero, Y == X. -
- For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row - of the batch is normalized independently. - """ - - def __init__(self, X, - norm=None): - super().__init__('Normalizer', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - norm=ONNXAttr(norm, AttrType.STRING)) + """ + Normalize the input. There are three normalization modes, which have the corresponding formulas, + defined using element-wise infix operators '/' and '^' and tensor-wide functions 'max' and 'sum':
+
+ Max: Y = X / max(X)
+ L1: Y = X / sum(X)
+ L2: Y = sqrt(X^2 / sum(X^2)}
+ In all modes, if the divisor is zero, Y == X. +
+ For batches, that is, [N,C] tensors, normalization is done along the C axis. In other words, each row + of the batch is normalized independently. + """ + + def __init__(self, X, norm=None): + super().__init__( + "Normalizer", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + norm=ONNXAttr(norm, AttrType.STRING), + ) + class Not(ONNXOp): - """ - Returns the negation of the input tensor element-wise. - """ + """ + Returns the negation of the input tensor element-wise. + """ + + def __init__(self, X): + super().__init__("Not", 1, [{"at::kBool"}], X) - def __init__(self, X): - super().__init__('Not', 1, - [{'at::kBool'}], - X) class OneHot(ONNXOp): - """ - Produces a one-hot tensor based on inputs. - The locations represented by the index values in the 'indices' input tensor will have 'on_value' - and the other locations will have 'off_value' in the output tensor, where 'on_value' and 'off_value' - are specified as part of required input argument 'values', which is a two-element tensor of format - [off_value, on_value]. The rank of the output tensor will be one greater than the rank of the - input tensor. The additional dimension is for one-hot representation. The additional dimension will - be inserted at the position specified by 'axis'. If 'axis' is not specified then then additional - dimension will be inserted as the innermost dimension, i.e. axis=-1. The size of the additional - dimension is specified by required scalar input 'depth'. The type of the output tensor is the same - as the type of the 'values' input. Any entries in the 'indices' input tensor with values outside - the range [-depth, depth-1] will result in one-hot representation with all 'off_value' values in the - output tensor. - - when axis = 0: - output[input[i, j, k], i, j, k] = 1 for all i, j, k and 0 otherwise. - - when axis = -1: - output[i, j, k, input[i, j, k]] = 1 for all i, j, k and 0 otherwise. - """ - - def __init__(self, indices, depth, values, - axis=None): - super().__init__('OneHot', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - indices,depth,values, - axis=ONNXAttr(axis, AttrType.INT)) + """ + Produces a one-hot tensor based on inputs. + The locations represented by the index values in the 'indices' input tensor will have 'on_value' + and the other locations will have 'off_value' in the output tensor, where 'on_value' and 'off_value' + are specified as part of required input argument 'values', which is a two-element tensor of format + [off_value, on_value]. The rank of the output tensor will be one greater than the rank of the + input tensor. The additional dimension is for one-hot representation. The additional dimension will + be inserted at the position specified by 'axis'. If 'axis' is not specified then then additional + dimension will be inserted as the innermost dimension, i.e. axis=-1. The size of the additional + dimension is specified by required scalar input 'depth'. The type of the output tensor is the same + as the type of the 'values' input. Any entries in the 'indices' input tensor with values outside + the range [-depth, depth-1] will result in one-hot representation with all 'off_value' values in the + output tensor. + + when axis = 0: + output[input[i, j, k], i, j, k] = 1 for all i, j, k and 0 otherwise. + + when axis = -1: + output[i, j, k, input[i, j, k]] = 1 for all i, j, k and 0 otherwise. + """ + + def __init__(self, indices, depth, values, axis=None): + super().__init__( + "OneHot", + 1, + [ + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + ], + indices, + depth, + values, + axis=ONNXAttr(axis, AttrType.INT), + ) + class OneHotEncoder(ONNXOp): - """ - Replace each input element with an array of ones and zeros, where a single - one is placed at the index of the category that was passed in. The total category count - will determine the size of the extra dimension of the output array Y.
- For example, if we pass a tensor with a single value of 4, and a category count of 8, - the output will be a tensor with ``[0,0,0,0,1,0,0,0]``.
- This operator assumes every input feature is from the same set of categories.
- If the input is a tensor of float, int32, or double, the data will be cast - to integers and the cats_int64s category list will be used for the lookups. - """ - - def __init__(self, X, - cats_int64s=None, - cats_strings=None, - zeros=None): - super().__init__('OneHotEncoder', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - cats_int64s=ONNXAttr(cats_int64s, AttrType.INTS), - cats_strings=ONNXAttr(cats_strings, AttrType.STRINGS), - zeros=ONNXAttr(zeros, AttrType.INT)) + """ + Replace each input element with an array of ones and zeros, where a single + one is placed at the index of the category that was passed in. The total category count + will determine the size of the extra dimension of the output array Y.
+ For example, if we pass a tensor with a single value of 4, and a category count of 8, + the output will be a tensor with ``[0,0,0,0,1,0,0,0]``.
+ This operator assumes every input feature is from the same set of categories.
+ If the input is a tensor of float, int32, or double, the data will be cast + to integers and the cats_int64s category list will be used for the lookups. + """ + + def __init__(self, X, cats_int64s=None, cats_strings=None, zeros=None): + super().__init__( + "OneHotEncoder", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + cats_int64s=ONNXAttr(cats_int64s, AttrType.INTS), + cats_strings=ONNXAttr(cats_strings, AttrType.STRINGS), + zeros=ONNXAttr(zeros, AttrType.INT), + ) + class Or(ONNXOp): - """ - Returns the tensor resulted from performing the `or` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('Or', 1, - [{'at::kBool'}, {'at::kBool'}], - A,B) + """ + Returns the tensor resulted from performing the `or` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__("Or", 1, [{"at::kBool"}, {"at::kBool"}], A, B) + class Pad(ONNXOp): - """ - Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, - a padded tensor (`output`) is generated. - - The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`): - - 1) `constant`(default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False) - - 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis - - 3) `edge` - pads with the edge values of array - - - Example 1 (`constant` mode): - Insert 0 pads to the beginning of the second dimension. - - data = - [ - [1.0, 1.2], - [2.3, 3.4], - [4.5, 5.7], - ] - - pads = [0, 2, 0, 0] - - mode = 'constant' - - constant_value = 0.0 - - output = - [ - [0.0, 0.0, 1.0, 1.2], - [0.0, 0.0, 2.3, 3.4], - [0.0, 0.0, 4.5, 5.7], - ] - - - Example 2 (`reflect` mode): - data = - [ - [1.0, 1.2], - [2.3, 3.4], - [4.5, 5.7], - ] - - pads = [0, 2, 0, 0] - - mode = 'reflect' - - output = - [ - [1.0, 1.2, 1.0, 1.2], - [2.3, 3.4, 2.3, 3.4], - [4.5, 5.7, 4.5, 5.7], - ] - - - Example 3 (`edge` mode): - data = - [ - [1.0, 1.2], - [2.3, 3.4], - [4.5, 5.7], - ] - - pads = [0, 2, 0, 0] - - mode = 'edge' - - output = - [ - [1.0, 1.0, 1.0, 1.2], - [2.3, 2.3, 2.3, 3.4], - [4.5, 4.5, 4.5, 5.7], - ] - """ - - def __init__(self, data, pads, constant_value, - mode=None): - super().__init__('Pad', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - data,pads,constant_value, - mode=ONNXAttr(mode, AttrType.STRING)) + """ + Given a tensor containing the data to be padded (`data`), a tensor containing the number of start and end pad values for axis (`pads`), (optionally) a `mode`, and (optionally) `constant_value`, + a padded tensor (`output`) is generated. + + The three supported `modes` are (similar to corresponding modes supported by `numpy.pad`): + + 1) `constant`(default) - pads with a given constant value as specified by `constant_value` (which defaults to 0, empty string, or False) + + 2) `reflect` - pads with the reflection of the vector mirrored on the first and last values of the vector along each axis + + 3) `edge` - pads with the edge values of array + + + Example 1 (`constant` mode): + Insert 0 pads to the beginning of the second dimension. + + data = + [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'constant' + + constant_value = 0.0 + + output = + [ + [0.0, 0.0, 1.0, 1.2], + [0.0, 0.0, 2.3, 3.4], + [0.0, 0.0, 4.5, 5.7], + ] + + + Example 2 (`reflect` mode): + data = + [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'reflect' + + output = + [ + [1.0, 1.2, 1.0, 1.2], + [2.3, 3.4, 2.3, 3.4], + [4.5, 5.7, 4.5, 5.7], + ] + + + Example 3 (`edge` mode): + data = + [ + [1.0, 1.2], + [2.3, 3.4], + [4.5, 5.7], + ] + + pads = [0, 2, 0, 0] + + mode = 'edge' + + output = + [ + [1.0, 1.0, 1.0, 1.2], + [2.3, 2.3, 2.3, 3.4], + [4.5, 4.5, 4.5, 5.7], + ] + """ + + def __init__(self, data, pads, constant_value, mode=None): + super().__init__( + "Pad", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + ], + data, + pads, + constant_value, + mode=ONNXAttr(mode, AttrType.STRING), + ) + class Pow(ONNXOp): - """ - Pow takes input data (Tensor) and exponent Tensor, and - produces one output data (Tensor) where the function `f(x) = x^exponent`, - is applied to the data tensor elementwise. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, X, Y): - super().__init__('Pow', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}], - X,Y) + """ + Pow takes input data (Tensor) and exponent Tensor, and + produces one output data (Tensor) where the function `f(x) = x^exponent`, + is applied to the data tensor elementwise. + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, X, Y): + super().__init__( + "Pow", + 1, + [ + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + ], + X, + Y, + ) + class PRelu(ONNXOp): - """ - PRelu takes input data (Tensor) and slope tensor as input, and produces one - output data (Tensor) where the function `f(x) = slope * x for x < 0`, - `f(x) = x for x >= 0`., is applied to the data tensor elementwise. - This operator supports **unidirectional broadcasting** (tensor slope should be unidirectional broadcastable to input tensor X); for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, X, slope): - super().__init__('PRelu', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat'}], - X,slope) + """ + PRelu takes input data (Tensor) and slope tensor as input, and produces one + output data (Tensor) where the function `f(x) = slope * x for x < 0`, + `f(x) = x for x >= 0`., is applied to the data tensor elementwise. + This operator supports **unidirectional broadcasting** (tensor slope should be unidirectional broadcastable to input tensor X); for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, X, slope): + super().__init__( + "PRelu", + 1, + [ + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat"}, + ], + X, + slope, + ) + class QLinearConv(ONNXOp): - """ - The convolution operator consumes a quantized input tensor, its scale and zero point, - a quantized filter, its scale and zero point, and output's scale and zero point, - and computes the quantized output. Each scale and zero-point pair must have same shape. - It means they must be either scalars (per tensor) or 1-D tensors (per output channel). - Each input or output and its related zero point must have same type. - When bias is present it must be quantized using scale = input scale * weight scale and - zero point as 0. - """ - - def __init__(self, x, x_scale, x_zero_point, w, w_scale, w_zero_point, y_scale, y_zero_point, B, - auto_pad=None, - dilations=None, - group=None, - kernel_shape=None, - pads=None, - strides=None): - super().__init__('QLinearConv', 1, - [{'at::kByte'}, {'at::kFloat'}, {'at::kByte'}, {'at::kByte'}, {'at::kFloat'}, {'at::kByte'}, {'at::kFloat'}, {'at::kByte'}, {'at::kInt'}], - x,x_scale,x_zero_point,w,w_scale,w_zero_point,y_scale,y_zero_point,B, - auto_pad=ONNXAttr(auto_pad, AttrType.STRING), - dilations=ONNXAttr(dilations, AttrType.INTS), - group=ONNXAttr(group, AttrType.INT), - kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), - pads=ONNXAttr(pads, AttrType.INTS), - strides=ONNXAttr(strides, AttrType.INTS)) + """ + The convolution operator consumes a quantized input tensor, its scale and zero point, + a quantized filter, its scale and zero point, and output's scale and zero point, + and computes the quantized output. Each scale and zero-point pair must have same shape. + It means they must be either scalars (per tensor) or 1-D tensors (per output channel). + Each input or output and its related zero point must have same type. + When bias is present it must be quantized using scale = input scale * weight scale and + zero point as 0. + """ + + def __init__( + self, + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + B, + auto_pad=None, + dilations=None, + group=None, + kernel_shape=None, + pads=None, + strides=None, + ): + super().__init__( + "QLinearConv", + 1, + [ + {"at::kByte"}, + {"at::kFloat"}, + {"at::kByte"}, + {"at::kByte"}, + {"at::kFloat"}, + {"at::kByte"}, + {"at::kFloat"}, + {"at::kByte"}, + {"at::kInt"}, + ], + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + B, + auto_pad=ONNXAttr(auto_pad, AttrType.STRING), + dilations=ONNXAttr(dilations, AttrType.INTS), + group=ONNXAttr(group, AttrType.INT), + kernel_shape=ONNXAttr(kernel_shape, AttrType.INTS), + pads=ONNXAttr(pads, AttrType.INTS), + strides=ONNXAttr(strides, AttrType.INTS), + ) + class QLinearMatMul(ONNXOp): - """ - Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. - It consumes two quantized input tensors, their scales and zero points, scale and zero point of output, and computes the quantized output. - The quantization formula is y = saturate((x / y_scale) + y_zero_point). For (x / y_scale), it is rounding to nearest ties to even. - Refer to https://en.wikipedia.org/wiki/Rounding for details. Scale and zero point must have same shape. - They must be either scalar (per tensor) or 1-D tensor (per row for 'a' and per column for 'b'). If scale and zero point are 1-D tensor, - the number of elements of scale and zero point tensor of input 'a' and output 'y' should be equal to the number of rows of input 'a', - and the number of elements of scale and zero point tensor of input 'b' should be equal to the number of columns of input 'b'. - Production must never overflow, and accumulation may overflow if and only if in 32 bits. - """ - - def __init__(self, a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point): - super().__init__('QLinearMatMul', 1, - [{'at::kByte'}, {'at::kFloat'}, {'at::kByte'}, {'at::kByte'}, {'at::kFloat'}, {'at::kByte'}, {'at::kFloat'}, {'at::kByte'}], - a,a_scale,a_zero_point,b,b_scale,b_zero_point,y_scale,y_zero_point) + """ + Matrix product that behaves like numpy.matmul: https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.matmul.html. + It consumes two quantized input tensors, their scales and zero points, scale and zero point of output, and computes the quantized output. + The quantization formula is y = saturate((x / y_scale) + y_zero_point). For (x / y_scale), it is rounding to nearest ties to even. + Refer to https://en.wikipedia.org/wiki/Rounding for details. Scale and zero point must have same shape. + They must be either scalar (per tensor) or 1-D tensor (per row for 'a' and per column for 'b'). If scale and zero point are 1-D tensor, + the number of elements of scale and zero point tensor of input 'a' and output 'y' should be equal to the number of rows of input 'a', + and the number of elements of scale and zero point tensor of input 'b' should be equal to the number of columns of input 'b'. + Production must never overflow, and accumulation may overflow if and only if in 32 bits. + """ + + def __init__(self, a, a_scale, a_zero_point, b, b_scale, b_zero_point, y_scale, y_zero_point): + super().__init__( + "QLinearMatMul", + 1, + [ + {"at::kByte"}, + {"at::kFloat"}, + {"at::kByte"}, + {"at::kByte"}, + {"at::kFloat"}, + {"at::kByte"}, + {"at::kFloat"}, + {"at::kByte"}, + ], + a, + a_scale, + a_zero_point, + b, + b_scale, + b_zero_point, + y_scale, + y_zero_point, + ) + class QuantizeLinear(ONNXOp): - """ - The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor. The scale factor can be a scalar - (per-tensor/layer quantization), or a 1-D tensor for per-axis quantization. The quantization formula is y = saturate ((x / y_scale) + y_zero_point). - For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8. - For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type. - """ - - def __init__(self, x, y_scale, y_zero_point, - axis=None): - super().__init__('QuantizeLinear', 1, - [{'at::kInt', 'at::kFloat'}, {'at::kFloat'}, {'at::kByte'}], - x,y_scale,y_zero_point, - axis=ONNXAttr(axis, AttrType.INT)) + """ + The linear quantization operator. It consumes a high precision tensor, a scale, and a zero point to compute the low precision / quantized tensor. The scale factor can be a scalar + (per-tensor/layer quantization), or a 1-D tensor for per-axis quantization. The quantization formula is y = saturate ((x / y_scale) + y_zero_point). + For saturation, it saturates to [0, 255] if it's uint8, or [-128, 127] if it's int8. + For (x / y_scale), it's rounding to nearest ties to even. Refer to https://en.wikipedia.org/wiki/Rounding for details. 'y_zero_point' and 'y' must have same type. + """ + + def __init__(self, x, y_scale, y_zero_point, axis=None): + super().__init__( + "QuantizeLinear", + 1, + [{"at::kInt", "at::kFloat"}, {"at::kFloat"}, {"at::kByte"}], + x, + y_scale, + y_zero_point, + axis=ONNXAttr(axis, AttrType.INT), + ) + class RandomNormal(ONNXOp): - """ - Generate a tensor with random values drawn from a normal distribution. The shape - of the tensor is specified by the `shape` argument and the parameter of the normal distribution - specified by `mean` and `scale`. - - The data type is specified by the 'dtype' argument. The 'dtype' argument must - be one of the data types specified in the 'DataType' enum field in the - TensorProto message. - """ - - def __init__(self, - dtype=None, - mean=None, - scale=None, - seed=None, - shape=None): - super().__init__('RandomNormal', 1, - [], - dtype=ONNXAttr(dtype, AttrType.INT), - mean=ONNXAttr(mean, AttrType.FLOAT), - scale=ONNXAttr(scale, AttrType.FLOAT), - seed=ONNXAttr(seed, AttrType.FLOAT), - shape=ONNXAttr(shape, AttrType.INTS)) + """ + Generate a tensor with random values drawn from a normal distribution. The shape + of the tensor is specified by the `shape` argument and the parameter of the normal distribution + specified by `mean` and `scale`. + + The data type is specified by the 'dtype' argument. The 'dtype' argument must + be one of the data types specified in the 'DataType' enum field in the + TensorProto message. + """ + + def __init__(self, dtype=None, mean=None, scale=None, seed=None, shape=None): + super().__init__( + "RandomNormal", + 1, + [], + dtype=ONNXAttr(dtype, AttrType.INT), + mean=ONNXAttr(mean, AttrType.FLOAT), + scale=ONNXAttr(scale, AttrType.FLOAT), + seed=ONNXAttr(seed, AttrType.FLOAT), + shape=ONNXAttr(shape, AttrType.INTS), + ) + class RandomNormalLike(ONNXOp): - """ - Generate a tensor with random values drawn from a normal distribution. - The shape of the output tensor is copied from the shape of the input tensor, - and the parameters of the normal distribution are specified by `mean` and `scale`. - - The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided. - The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the - TensorProto message, and be valid as an output type. - """ - - def __init__(self, input, - dtype=None, - mean=None, - scale=None, - seed=None): - super().__init__('RandomNormalLike', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - input, - dtype=ONNXAttr(dtype, AttrType.INT), - mean=ONNXAttr(mean, AttrType.FLOAT), - scale=ONNXAttr(scale, AttrType.FLOAT), - seed=ONNXAttr(seed, AttrType.FLOAT)) + """ + Generate a tensor with random values drawn from a normal distribution. + The shape of the output tensor is copied from the shape of the input tensor, + and the parameters of the normal distribution are specified by `mean` and `scale`. + + The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided. + The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the + TensorProto message, and be valid as an output type. + """ + + def __init__(self, input, dtype=None, mean=None, scale=None, seed=None): + super().__init__( + "RandomNormalLike", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + input, + dtype=ONNXAttr(dtype, AttrType.INT), + mean=ONNXAttr(mean, AttrType.FLOAT), + scale=ONNXAttr(scale, AttrType.FLOAT), + seed=ONNXAttr(seed, AttrType.FLOAT), + ) + class RandomUniform(ONNXOp): - """ - Generate a tensor with random values drawn from a uniform distribution. The shape - of the tensor is specified by the `shape` argument and the range by `low` and `high`. - - The data type is specified by the 'dtype' argument. The 'dtype' argument must - be one of the data types specified in the 'DataType' enum field in the - TensorProto message. - """ - - def __init__(self, - dtype=None, - high=None, - low=None, - seed=None, - shape=None): - super().__init__('RandomUniform', 1, - [], - dtype=ONNXAttr(dtype, AttrType.INT), - high=ONNXAttr(high, AttrType.FLOAT), - low=ONNXAttr(low, AttrType.FLOAT), - seed=ONNXAttr(seed, AttrType.FLOAT), - shape=ONNXAttr(shape, AttrType.INTS)) + """ + Generate a tensor with random values drawn from a uniform distribution. The shape + of the tensor is specified by the `shape` argument and the range by `low` and `high`. + + The data type is specified by the 'dtype' argument. The 'dtype' argument must + be one of the data types specified in the 'DataType' enum field in the + TensorProto message. + """ + + def __init__(self, dtype=None, high=None, low=None, seed=None, shape=None): + super().__init__( + "RandomUniform", + 1, + [], + dtype=ONNXAttr(dtype, AttrType.INT), + high=ONNXAttr(high, AttrType.FLOAT), + low=ONNXAttr(low, AttrType.FLOAT), + seed=ONNXAttr(seed, AttrType.FLOAT), + shape=ONNXAttr(shape, AttrType.INTS), + ) + class RandomUniformLike(ONNXOp): - """ - Generate a tensor with random values drawn from a uniform distribution. - The shape of the output tensor is copied from the shape of the input tensor, - and the parameters of the uniform distribution are specified by `low` and `high`. - - The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided. - The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the - TensorProto message and be valid as an output type. - """ - - def __init__(self, input, - dtype=None, - high=None, - low=None, - seed=None): - super().__init__('RandomUniformLike', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - input, - dtype=ONNXAttr(dtype, AttrType.INT), - high=ONNXAttr(high, AttrType.FLOAT), - low=ONNXAttr(low, AttrType.FLOAT), - seed=ONNXAttr(seed, AttrType.FLOAT)) + """ + Generate a tensor with random values drawn from a uniform distribution. + The shape of the output tensor is copied from the shape of the input tensor, + and the parameters of the uniform distribution are specified by `low` and `high`. + + The data type is specified by the 'dtype' argument, or copied from the input tensor if not provided. + The 'dtype' argument must be one of the data types specified in the 'DataType' enum field in the + TensorProto message and be valid as an output type. + """ + + def __init__(self, input, dtype=None, high=None, low=None, seed=None): + super().__init__( + "RandomUniformLike", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + input, + dtype=ONNXAttr(dtype, AttrType.INT), + high=ONNXAttr(high, AttrType.FLOAT), + low=ONNXAttr(low, AttrType.FLOAT), + seed=ONNXAttr(seed, AttrType.FLOAT), + ) + class Range(ONNXOp): - """ - Generate a tensor containing a sequence of numbers that begin at `start` and extends by increments of `delta` - up to `limit` (exclusive). - - The number of elements in the output of range is computed as below- - - `number_of_elements = max( ceil( (limit - start) / delta ) , 0 )` - - The pseudocode determining the contents of the output is shown below- - - `for(int i=0; i) and produces one output data - (Tensor) where the reciprocal is, y = 1/x, is applied to - the tensor elementwise. - """ + """ + Reciprocal takes one input data (Tensor) and produces one output data + (Tensor) where the reciprocal is, y = 1/x, is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__("Reciprocal", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('Reciprocal', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X) class ReduceL1(ONNXOp): - """ - Computes the L1 norm of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceL1', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the L1 norm of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceL1", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceL2(ONNXOp): - """ - Computes the L2 norm of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceL2', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the L2 norm of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceL2", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceLogSum(ONNXOp): - """ - Computes the log sum of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceLogSum', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the log sum of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceLogSum", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceLogSumExp(ONNXOp): - """ - Computes the log sum exponent of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceLogSumExp', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the log sum exponent of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceLogSumExp", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceMax(ONNXOp): - """ - Computes the max of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceMax', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the max of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceMax", + 1, + [{"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceMean(ONNXOp): - """ - Computes the mean of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceMean', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the mean of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceMean", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceMin(ONNXOp): - """ - Computes the min of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceMin', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the min of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceMin", + 1, + [{"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceProd(ONNXOp): - """ - Computes the product of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceProd', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the product of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceProd", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class ReduceSum(ONNXOp): - """ - Computes the sum of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, axes, - keepdims=None, - noop_with_empty_axes=None): - super().__init__('ReduceSum', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - data,axes, - keepdims=ONNXAttr(keepdims, AttrType.INT), - noop_with_empty_axes=ONNXAttr(noop_with_empty_axes, AttrType.INT)) + """ + Computes the sum of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes, keepdims=None, noop_with_empty_axes=None): + super().__init__( + "ReduceSum", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}, {"at::kLong"}], + data, + axes, + keepdims=ONNXAttr(keepdims, AttrType.INT), + noop_with_empty_axes=ONNXAttr(noop_with_empty_axes, AttrType.INT), + ) + class ReduceSumSquare(ONNXOp): - """ - Computes the sum square of the input tensor's element along the provided axes. The resulted - tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then - the resulted tensor have the reduced dimension pruned. - - The above behavior is similar to numpy, with the exception that numpy default keepdims to - False instead of True. - """ - - def __init__(self, data, - axes=None, - keepdims=None): - super().__init__('ReduceSumSquare', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kHalf', 'at::kFloat', 'at::kBFloat16'}], - data, - axes=ONNXAttr(axes, AttrType.INTS), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Computes the sum square of the input tensor's element along the provided axes. The resulted + tensor has the same rank as the input if keepdims equal 1. If keepdims equal 0, then + the resulted tensor have the reduced dimension pruned. + + The above behavior is similar to numpy, with the exception that numpy default keepdims to + False instead of True. + """ + + def __init__(self, data, axes=None, keepdims=None): + super().__init__( + "ReduceSumSquare", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kHalf", "at::kFloat", "at::kBFloat16"}], + data, + axes=ONNXAttr(axes, AttrType.INTS), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class Relu(ONNXOp): - """ - Relu takes one input data (Tensor) and produces one output data - (Tensor) where the rectified linear function, y = max(0, x), is applied to - the tensor elementwise. - """ + """ + Relu takes one input data (Tensor) and produces one output data + (Tensor) where the rectified linear function, y = max(0, x), is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__( + "Relu", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + X, + ) - def __init__(self, X): - super().__init__('Relu', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - X) class Reshape(ONNXOp): - """ - Reshape the input tensor similar to numpy.reshape. - First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. - At most one dimension of the new shape can be -1. In this case, the value is - inferred from the size of the tensor and the remaining dimensions. A dimension - could also be 0, in which case the actual dimension value is unchanged (i.e. taken - from the input tensor). If 'allowzero' is set, and the new shape includes 0, the - dimension will be set explicitly to zero (i.e. not taken from input tensor) - """ - - def __init__(self, data, shape, - allowzero=None): - super().__init__('Reshape', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - data,shape, - allowzero=ONNXAttr(allowzero, AttrType.INT)) + """ + Reshape the input tensor similar to numpy.reshape. + First input is the data tensor, second input is a shape tensor which specifies the output shape. It outputs the reshaped tensor. + At most one dimension of the new shape can be -1. In this case, the value is + inferred from the size of the tensor and the remaining dimensions. A dimension + could also be 0, in which case the actual dimension value is unchanged (i.e. taken + from the input tensor). If 'allowzero' is set, and the new shape includes 0, the + dimension will be set explicitly to zero (i.e. not taken from input tensor) + """ + + def __init__(self, data, shape, allowzero=None): + super().__init__( + "Reshape", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + data, + shape, + allowzero=ONNXAttr(allowzero, AttrType.INT), + ) + class Resize(ONNXOp): - """ - Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. - Each dimension value of the output tensor is: - output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified. - """ - - def __init__(self, X, roi, scales, sizes, - coordinate_transformation_mode=None, - cubic_coeff_a=None, - exclude_outside=None, - extrapolation_value=None, - mode=None, - nearest_mode=None): - super().__init__('Resize', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kFloat'}, {'at::kLong'}], - X,roi,scales,sizes, - coordinate_transformation_mode=ONNXAttr(coordinate_transformation_mode, AttrType.STRING), - cubic_coeff_a=ONNXAttr(cubic_coeff_a, AttrType.FLOAT), - exclude_outside=ONNXAttr(exclude_outside, AttrType.INT), - extrapolation_value=ONNXAttr(extrapolation_value, AttrType.FLOAT), - mode=ONNXAttr(mode, AttrType.STRING), - nearest_mode=ONNXAttr(nearest_mode, AttrType.STRING)) + """ + Resize the input tensor. In general, it calculates every value in the output tensor as a weighted average of neighborhood (a.k.a. sampling locations) in the input tensor. + Each dimension value of the output tensor is: + output_dimension = floor(input_dimension * (roi_end - roi_start) * scale) if input \"sizes\" is not specified. + """ + + def __init__( + self, + X, + roi, + scales, + sizes, + coordinate_transformation_mode=None, + cubic_coeff_a=None, + exclude_outside=None, + extrapolation_value=None, + mode=None, + nearest_mode=None, + ): + super().__init__( + "Resize", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kFloat"}, + {"at::kLong"}, + ], + X, + roi, + scales, + sizes, + coordinate_transformation_mode=ONNXAttr(coordinate_transformation_mode, AttrType.STRING), + cubic_coeff_a=ONNXAttr(cubic_coeff_a, AttrType.FLOAT), + exclude_outside=ONNXAttr(exclude_outside, AttrType.INT), + extrapolation_value=ONNXAttr(extrapolation_value, AttrType.FLOAT), + mode=ONNXAttr(mode, AttrType.STRING), + nearest_mode=ONNXAttr(nearest_mode, AttrType.STRING), + ) + class ReverseSequence(ONNXOp): - """ - Reverse batch of sequences having different lengths specified by `sequence_lens`. - - For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis, - and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed - sequences on the first sequence_lens[i] elements, then have original values copied for the other elements. - - Example 1: - input = [[0.0, 4.0, 8.0, 12.0], - [1.0, 5.0, 9.0, 13.0], - [2.0, 6.0, 10.0, 14.0], - [3.0, 7.0, 11.0, 15.0]] - sequence_lens = [4, 3, 2, 1] - time_axis = 0 - batch_axis = 1 - - output = [[3.0, 6.0, 9.0, 12.0], - [2.0, 5.0, 8.0, 13.0], - [1.0, 4.0, 10.0, 14.0], - [0.0, 7.0, 11.0, 15.0]] - - Example 2: - input = [[0.0, 1.0, 2.0, 3.0 ], - [4.0, 5.0, 6.0, 7.0 ], - [8.0, 9.0, 10.0, 11.0], - [12.0, 13.0, 14.0, 15.0]] - sequence_lens = [1, 2, 3, 4] - time_axis = 1 - batch_axis = 0 - - output = [[0.0, 1.0, 2.0, 3.0 ], - [5.0, 4.0, 6.0, 7.0 ], - [10.0, 9.0, 8.0, 11.0], - [15.0, 14.0, 13.0, 12.0]] - """ - - def __init__(self, input, sequence_lens, - batch_axis=None, - time_axis=None): - super().__init__('ReverseSequence', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kLong'}], - input,sequence_lens, - batch_axis=ONNXAttr(batch_axis, AttrType.INT), - time_axis=ONNXAttr(time_axis, AttrType.INT)) + """ + Reverse batch of sequences having different lengths specified by `sequence_lens`. + + For each slice i iterating on batch axis, the operator reverses the first sequence_lens[i] elements on time axis, + and copies elements whose index's beyond sequence_lens[i] to the output. So the output slice i contains reversed + sequences on the first sequence_lens[i] elements, then have original values copied for the other elements. + + Example 1: + input = [[0.0, 4.0, 8.0, 12.0], + [1.0, 5.0, 9.0, 13.0], + [2.0, 6.0, 10.0, 14.0], + [3.0, 7.0, 11.0, 15.0]] + sequence_lens = [4, 3, 2, 1] + time_axis = 0 + batch_axis = 1 + + output = [[3.0, 6.0, 9.0, 12.0], + [2.0, 5.0, 8.0, 13.0], + [1.0, 4.0, 10.0, 14.0], + [0.0, 7.0, 11.0, 15.0]] + + Example 2: + input = [[0.0, 1.0, 2.0, 3.0 ], + [4.0, 5.0, 6.0, 7.0 ], + [8.0, 9.0, 10.0, 11.0], + [12.0, 13.0, 14.0, 15.0]] + sequence_lens = [1, 2, 3, 4] + time_axis = 1 + batch_axis = 0 + + output = [[0.0, 1.0, 2.0, 3.0 ], + [5.0, 4.0, 6.0, 7.0 ], + [10.0, 9.0, 8.0, 11.0], + [15.0, 14.0, 13.0, 12.0]] + """ + + def __init__(self, input, sequence_lens, batch_axis=None, time_axis=None): + super().__init__( + "ReverseSequence", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + {"at::kLong"}, + ], + input, + sequence_lens, + batch_axis=ONNXAttr(batch_axis, AttrType.INT), + time_axis=ONNXAttr(time_axis, AttrType.INT), + ) + class RNN(ONNXOp): - """ - Computes an one-layer simple RNN. This operator is usually supported - via some custom implementation such as CuDNN. - - Notations: - - `X` - input tensor - - `i` - input gate - - `t` - time step (t-1 means previous time step) - - `Wi` - W parameter weight matrix for input gate - - `Ri` - R recurrence weight matrix for input gate - - `Wbi` - W parameter bias vector for input gate - - `Rbi` - R parameter bias vector for input gate - - `WBi` - W parameter weight matrix for backward input gate - - `RBi` - R recurrence weight matrix for backward input gate - - `WBbi` - WR bias vectors for backward input gate - - `RBbi` - RR bias vectors for backward input gate - - `H` - Hidden state - - `num_directions` - 2 if direction == bidirectional else 1 - - Activation functions: - - Relu(x) - max(0, x) - - Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) - - Sigmoid(x) - 1/(1 + e^{-x}) - - (NOTE: Below are optional) - - Affine(x) - alpha*x + beta - - LeakyRelu(x) - x if x >= 0 else alpha * x - - ThresholdedRelu(x) - x if x >= alpha else 0 - - ScaledTanh(x) - alpha*Tanh(beta*x) - - HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) - - Elu(x) - x if x >= 0 else alpha*(e^x - 1) - - Softsign(x) - x/(1 + |x|) - - Softplus(x) - log(1 + e^x) - - Equations (Default: f=Tanh): - - - Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) - This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. - """ - - def __init__(self, X, W, R, B, sequence_lens, initial_h, - activation_alpha=None, - activation_beta=None, - activations=None, - clip=None, - direction=None, - hidden_size=None, - layout=None): - super().__init__('RNN', 2, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kInt'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X,W,R,B,sequence_lens,initial_h, - activation_alpha=ONNXAttr(activation_alpha, AttrType.FLOATS), - activation_beta=ONNXAttr(activation_beta, AttrType.FLOATS), - activations=ONNXAttr(activations, AttrType.STRINGS), - clip=ONNXAttr(clip, AttrType.FLOAT), - direction=ONNXAttr(direction, AttrType.STRING), - hidden_size=ONNXAttr(hidden_size, AttrType.INT), - layout=ONNXAttr(layout, AttrType.INT)) + """ + Computes an one-layer simple RNN. This operator is usually supported + via some custom implementation such as CuDNN. + + Notations: + + `X` - input tensor + + `i` - input gate + + `t` - time step (t-1 means previous time step) + + `Wi` - W parameter weight matrix for input gate + + `Ri` - R recurrence weight matrix for input gate + + `Wbi` - W parameter bias vector for input gate + + `Rbi` - R parameter bias vector for input gate + + `WBi` - W parameter weight matrix for backward input gate + + `RBi` - R recurrence weight matrix for backward input gate + + `WBbi` - WR bias vectors for backward input gate + + `RBbi` - RR bias vectors for backward input gate + + `H` - Hidden state + + `num_directions` - 2 if direction == bidirectional else 1 + + Activation functions: + + Relu(x) - max(0, x) + + Tanh(x) - (1 - e^{-2x})/(1 + e^{-2x}) + + Sigmoid(x) - 1/(1 + e^{-x}) + + (NOTE: Below are optional) + + Affine(x) - alpha*x + beta + + LeakyRelu(x) - x if x >= 0 else alpha * x + + ThresholdedRelu(x) - x if x >= alpha else 0 + + ScaledTanh(x) - alpha*Tanh(beta*x) + + HardSigmoid(x) - min(max(alpha*x + beta, 0), 1) + + Elu(x) - x if x >= 0 else alpha*(e^x - 1) + + Softsign(x) - x/(1 + |x|) + + Softplus(x) - log(1 + e^x) + + Equations (Default: f=Tanh): + + - Ht = f(Xt*(Wi^T) + Ht-1*(Ri^T) + Wbi + Rbi) + This operator has **optional** inputs/outputs. See [the doc](IR.md) for more details about the representation of optional arguments. An empty string may be used in the place of an actual argument's name to indicate a missing argument. Trailing optional arguments (those not followed by an argument that is present) may also be simply omitted. + """ + + def __init__( + self, + X, + W, + R, + B, + sequence_lens, + initial_h, + activation_alpha=None, + activation_beta=None, + activations=None, + clip=None, + direction=None, + hidden_size=None, + layout=None, + ): + super().__init__( + "RNN", + 2, + [ + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + {"at::kInt"}, + {"at::kDouble", "at::kHalf", "at::kFloat"}, + ], + X, + W, + R, + B, + sequence_lens, + initial_h, + activation_alpha=ONNXAttr(activation_alpha, AttrType.FLOATS), + activation_beta=ONNXAttr(activation_beta, AttrType.FLOATS), + activations=ONNXAttr(activations, AttrType.STRINGS), + clip=ONNXAttr(clip, AttrType.FLOAT), + direction=ONNXAttr(direction, AttrType.STRING), + hidden_size=ONNXAttr(hidden_size, AttrType.INT), + layout=ONNXAttr(layout, AttrType.INT), + ) + class RoiAlign(ONNXOp): - """ - Region of Interest (RoI) align operation described in the - [Mask R-CNN paper](https://arxiv.org/abs/1703.06870). - RoiAlign consumes an input tensor X and region of interests (rois) - to apply pooling across each RoI; it produces a 4-D tensor of shape - (num_rois, C, output_height, output_width). - - RoiAlign is proposed to avoid the misalignment by removing - quantizations while converting from original image into feature - map and from feature map into RoI feature; in each ROI bin, - the value of the sampled locations are computed directly - through bilinear interpolation. - """ - - def __init__(self, X, rois, batch_indices, - mode=None, - output_height=None, - output_width=None, - sampling_ratio=None, - spatial_scale=None): - super().__init__('RoiAlign', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kDouble', 'at::kHalf', 'at::kFloat'}, {'at::kLong'}], - X,rois,batch_indices, - mode=ONNXAttr(mode, AttrType.STRING), - output_height=ONNXAttr(output_height, AttrType.INT), - output_width=ONNXAttr(output_width, AttrType.INT), - sampling_ratio=ONNXAttr(sampling_ratio, AttrType.INT), - spatial_scale=ONNXAttr(spatial_scale, AttrType.FLOAT)) + """ + Region of Interest (RoI) align operation described in the + [Mask R-CNN paper](https://arxiv.org/abs/1703.06870). + RoiAlign consumes an input tensor X and region of interests (rois) + to apply pooling across each RoI; it produces a 4-D tensor of shape + (num_rois, C, output_height, output_width). + + RoiAlign is proposed to avoid the misalignment by removing + quantizations while converting from original image into feature + map and from feature map into RoI feature; in each ROI bin, + the value of the sampled locations are computed directly + through bilinear interpolation. + """ + + def __init__( + self, + X, + rois, + batch_indices, + mode=None, + output_height=None, + output_width=None, + sampling_ratio=None, + spatial_scale=None, + ): + super().__init__( + "RoiAlign", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}, {"at::kDouble", "at::kHalf", "at::kFloat"}, {"at::kLong"}], + X, + rois, + batch_indices, + mode=ONNXAttr(mode, AttrType.STRING), + output_height=ONNXAttr(output_height, AttrType.INT), + output_width=ONNXAttr(output_width, AttrType.INT), + sampling_ratio=ONNXAttr(sampling_ratio, AttrType.INT), + spatial_scale=ONNXAttr(spatial_scale, AttrType.FLOAT), + ) + class Round(ONNXOp): - """ - Round takes one input Tensor and rounds the values, element-wise, meaning - it finds the nearest integer for each value. - In case of halfs, the rule is to round them to the nearest even integer. - The output tensor has the same shape and type as the input. - - Examples: - ``` - round([0.9]) = [1.0] - round([2.5]) = [2.0] - round([2.3]) = [2.0] - round([1.5]) = [2.0] - round([-4.5]) = [-4.0] - ``` - """ - - def __init__(self, X): - super().__init__('Round', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X) + """ + Round takes one input Tensor and rounds the values, element-wise, meaning + it finds the nearest integer for each value. + In case of halfs, the rule is to round them to the nearest even integer. + The output tensor has the same shape and type as the input. + + Examples: + ``` + round([0.9]) = [1.0] + round([2.5]) = [2.0] + round([2.3]) = [2.0] + round([1.5]) = [2.0] + round([-4.5]) = [-4.0] + ``` + """ + + def __init__(self, X): + super().__init__("Round", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X) + class Scaler(ONNXOp): - """ - Rescale input data, for example to standardize features by removing the mean and scaling to unit variance. - """ - - def __init__(self, X, - offset=None, - scale=None): - super().__init__('Scaler', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - offset=ONNXAttr(offset, AttrType.FLOATS), - scale=ONNXAttr(scale, AttrType.FLOATS)) + """ + Rescale input data, for example to standardize features by removing the mean and scaling to unit variance. + """ + + def __init__(self, X, offset=None, scale=None): + super().__init__( + "Scaler", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + offset=ONNXAttr(offset, AttrType.FLOATS), + scale=ONNXAttr(scale, AttrType.FLOATS), + ) + class Scan(ONNXOp): - """ - Scan can be used to iterate over one or more scan_input tensors, - constructing zero or more scan_output tensors. It combines ideas from general recurrences, - functional programming constructs such as scan, fold, map, and zip and is intended to enable - generalizations of RNN-like constructs for sequence-to-sequence processing. - Other tensors (referred to as state_variables here) can be used to carry a state - when iterating from one element to another (similar to hidden-state in RNNs, also referred - to as loop-carried dependences in the context of loops). - Many common usages involve a single scan_input tensor (where functionality - similar to scan, fold and map can be obtained). When more than one scan_input is used, - a behavior similar to zip is obtained. - - The attribute body must be a graph, specifying the computation to be performed in - every iteration. It takes as input the current values of the state_variables and - the current iterated element of the scan_inputs. It must return the (updated) values - of the state_variables and zero or more scan_output_element tensors. The values of the - scan_output_element tensors are concatenated over all the iterations to produce the - scan_output values of the scan construct (similar to the concatenated intermediate - hidden-state values of RNN-like constructs). All the output tensors (state_variables as - well as scan_output_element tensors) are required to have the same shape in each iteration - of the loop (a restriction imposed to enable efficient memory allocation). - - Note that the iterated element passed to the body subgraph does not have a sequence - axis. It will have a rank one less than the rank of the corresponding scan_input. - - The scan operation returns the final values of the state_variables as well as the - scan_outputs. - - The optional attribute scan_input_directions specifies the direction (forward or backward) - for each scan input. If this attribute is omitted, all sequences are scanned in the forward - direction. A bidirectional scan may be performed by specifying the same tensor input twice - in the scan_inputs, once with a forward direction, and once with a backward direction. - - The scan_output of the operation is produced by concatenating the scan_output_element - values produced by the body in each iteration. The optional attribute scan_output_directions - specifies the direction in which scan_output is constructed (by appending or prepending the - scan_output_element to scan_output in each iteration) for each scan_output. If this attribute - is omitted, the scan_output_element is appended to the scan_output in each iteration. - - The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input. - If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the - batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1. - Note that scanning a non-zero axis may be less efficient than scanning axis zero. - - The optional attribute scan_output_axes specifies the axis along which the scan_outputs - are accumulated for each scan_output. For example, if axis 1 is the time axis (to be - scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis - value of 1. - - Note that because of the ONNX restriction that only the last parameter of an operator can - be variadic, the initial-states and scan-inputs are listed together as one input parameter. - Similarly, the final-states and scan-outputs are listed together as one output parameter. - The attribute num_scan_inputs indicates the number M of scan-inputs. - - The behavior of - - Scan < - num_scan_inputs = m, - body = loop-body, - scan_input_axes = [axis_1, ..., axis_m] - > (init_1, ..., init_n, scan_1, ..., scan_m) - - is equivalent to the following pseudo-code: - - // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i - // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j. - sequence_length = scan_1.shape[axis_1]; - - // initialize state-variables - st_1 = init_1; ... st_n = init_n; - // initialize scan-output variables: [] denotes an empty tensor - scan_out_1 = []; ...; scan_out_k = []; - // identify number of iterations: - - // execute loop - for (int t = 0; t < sequence_length; ++t) { - // generate the scan-input elements: the notation T[t] indicates the sub-tensor - // of rank one less than T obtained by indexing T at position t along axis k. - si_1 = scan_1[t]; - ... ; - si_m = scan_m[t]; - // execute loop-body - st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m) - // accumulate the scan-output elements - scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k); - } - - return st_1, ..., st_n, scan_out_1, ..., scan_out_k; - - *Sample usage: Encoding RNN using a Scan* - - The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi, - recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can - be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes - %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these - values are computed in the outer graph, they need to be passed in as extra state_variables. - - graph rnn-encoding { - %H_0 = ... - %X = ... - %Y_h, %Y = Scan[body = , num_scan_inputs=1](%H_0, %X) - return %Y, %Y_h - } - - graph rnn-cell-1 ( - %H_tminus1[FLOAT, tensor] - %X_t[FLOAT, tensor] - ) { - %Wi = ... - %Ri = ... - %Wbi = ... - %Rbi = ... - %t1 = X_t * (Wi^T) - %t2 = H_tminus1*(Ri^T) - %t3 = Add(%t1, %t2) - %t4 = Add(%t3, %Wbi) - %t5 = Add(%t4, %Rbi) - %Ht = Tanh(%t5) - %Accumulate = Identity(%Ht) - return %Ht, %Accumulate - } - """ - - def __init__(self, initial_state_and_scan_inputs, - body=None, - num_scan_inputs=None, - scan_input_axes=None, - scan_input_directions=None, - scan_output_axes=None, - scan_output_directions=None): - super().__init__('Scan', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - initial_state_and_scan_inputs, - body=ONNXAttr(body, AttrType.GRAPH), - num_scan_inputs=ONNXAttr(num_scan_inputs, AttrType.INT), - scan_input_axes=ONNXAttr(scan_input_axes, AttrType.INTS), - scan_input_directions=ONNXAttr(scan_input_directions, AttrType.INTS), - scan_output_axes=ONNXAttr(scan_output_axes, AttrType.INTS), - scan_output_directions=ONNXAttr(scan_output_directions, AttrType.INTS)) + """ + Scan can be used to iterate over one or more scan_input tensors, + constructing zero or more scan_output tensors. It combines ideas from general recurrences, + functional programming constructs such as scan, fold, map, and zip and is intended to enable + generalizations of RNN-like constructs for sequence-to-sequence processing. + Other tensors (referred to as state_variables here) can be used to carry a state + when iterating from one element to another (similar to hidden-state in RNNs, also referred + to as loop-carried dependences in the context of loops). + Many common usages involve a single scan_input tensor (where functionality + similar to scan, fold and map can be obtained). When more than one scan_input is used, + a behavior similar to zip is obtained. + + The attribute body must be a graph, specifying the computation to be performed in + every iteration. It takes as input the current values of the state_variables and + the current iterated element of the scan_inputs. It must return the (updated) values + of the state_variables and zero or more scan_output_element tensors. The values of the + scan_output_element tensors are concatenated over all the iterations to produce the + scan_output values of the scan construct (similar to the concatenated intermediate + hidden-state values of RNN-like constructs). All the output tensors (state_variables as + well as scan_output_element tensors) are required to have the same shape in each iteration + of the loop (a restriction imposed to enable efficient memory allocation). + + Note that the iterated element passed to the body subgraph does not have a sequence + axis. It will have a rank one less than the rank of the corresponding scan_input. + + The scan operation returns the final values of the state_variables as well as the + scan_outputs. + + The optional attribute scan_input_directions specifies the direction (forward or backward) + for each scan input. If this attribute is omitted, all sequences are scanned in the forward + direction. A bidirectional scan may be performed by specifying the same tensor input twice + in the scan_inputs, once with a forward direction, and once with a backward direction. + + The scan_output of the operation is produced by concatenating the scan_output_element + values produced by the body in each iteration. The optional attribute scan_output_directions + specifies the direction in which scan_output is constructed (by appending or prepending the + scan_output_element to scan_output in each iteration) for each scan_output. If this attribute + is omitted, the scan_output_element is appended to the scan_output in each iteration. + + The optional attribute scan_input_axes specifies the axis to be scanned for each scan_input. + If omitted, every scan_input will be scanned in axis 0. For example, if axis 0 is the + batch axis and axis 1 is the time axis (to be scanned), specify an axis value of 1. + Note that scanning a non-zero axis may be less efficient than scanning axis zero. + + The optional attribute scan_output_axes specifies the axis along which the scan_outputs + are accumulated for each scan_output. For example, if axis 1 is the time axis (to be + scanned) for both inputs and outputs, specify a scan_input axis and scan_output axis + value of 1. + + Note that because of the ONNX restriction that only the last parameter of an operator can + be variadic, the initial-states and scan-inputs are listed together as one input parameter. + Similarly, the final-states and scan-outputs are listed together as one output parameter. + The attribute num_scan_inputs indicates the number M of scan-inputs. + + The behavior of + + Scan < + num_scan_inputs = m, + body = loop-body, + scan_input_axes = [axis_1, ..., axis_m] + > (init_1, ..., init_n, scan_1, ..., scan_m) + + is equivalent to the following pseudo-code: + + // scan_i.shape[axis_i] denotes the (max) sequence-length of scan_i + // scan_i.shape[axis_i] is required to be equal to scan_j.shape[axis_j] for all i,j. + sequence_length = scan_1.shape[axis_1]; + + // initialize state-variables + st_1 = init_1; ... st_n = init_n; + // initialize scan-output variables: [] denotes an empty tensor + scan_out_1 = []; ...; scan_out_k = []; + // identify number of iterations: + + // execute loop + for (int t = 0; t < sequence_length; ++t) { + // generate the scan-input elements: the notation T[t] indicates the sub-tensor + // of rank one less than T obtained by indexing T at position t along axis k. + si_1 = scan_1[t]; + ... ; + si_m = scan_m[t]; + // execute loop-body + st_1, ..., st_n, so_1, ..., so_k = loop-body(st_1, ..., st_n, si_1, ..., si_m) + // accumulate the scan-output elements + scan_out_1 = Concat(scan_out_1, so_1); ... ; scan_out_k = Concat(scan_out_k, so_k); + } + + return st_1, ..., st_n, scan_out_1, ..., scan_out_k; + + *Sample usage: Encoding RNN using a Scan* + + The following example shows how a simple RNN over an input tensor %X, with weight tensor %Wi, + recurrence weight tensor %Ri, bias tensors %Wbi and %Rbi, and initial hidden-state %H_0 can + be encoded as a ScanLoop. Note that the loop-body is a nested graph, and it directly computes + %Wi, %Ri, %Wbi, and %Rbi (typically constants or initializers in the body graph). If these + values are computed in the outer graph, they need to be passed in as extra state_variables. + + graph rnn-encoding { + %H_0 = ... + %X = ... + %Y_h, %Y = Scan[body = , num_scan_inputs=1](%H_0, %X) + return %Y, %Y_h + } + + graph rnn-cell-1 ( + %H_tminus1[FLOAT, tensor] + %X_t[FLOAT, tensor] + ) { + %Wi = ... + %Ri = ... + %Wbi = ... + %Rbi = ... + %t1 = X_t * (Wi^T) + %t2 = H_tminus1*(Ri^T) + %t3 = Add(%t1, %t2) + %t4 = Add(%t3, %Wbi) + %t5 = Add(%t4, %Rbi) + %Ht = Tanh(%t5) + %Accumulate = Identity(%Ht) + return %Ht, %Accumulate + } + """ + + def __init__( + self, + initial_state_and_scan_inputs, + body=None, + num_scan_inputs=None, + scan_input_axes=None, + scan_input_directions=None, + scan_output_axes=None, + scan_output_directions=None, + ): + super().__init__( + "Scan", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + initial_state_and_scan_inputs, + body=ONNXAttr(body, AttrType.GRAPH), + num_scan_inputs=ONNXAttr(num_scan_inputs, AttrType.INT), + scan_input_axes=ONNXAttr(scan_input_axes, AttrType.INTS), + scan_input_directions=ONNXAttr(scan_input_directions, AttrType.INTS), + scan_output_axes=ONNXAttr(scan_output_axes, AttrType.INTS), + scan_output_directions=ONNXAttr(scan_output_directions, AttrType.INTS), + ) + class Scatter(ONNXOp): - """ - Given `data`, `updates` and `indices` input tensors of rank r >= 1, write the values provided by `updates` - into the first input, `data`, along `axis` dimension of `data` (by default outer-most one as axis=0) at corresponding `indices`. - For each entry in `updates`, the target index in `data` is specified by corresponding entry in `indices` - for dimension = axis, and index in source for dimension != axis. For instance, in a 2-D tensor case, - data[indices[i][j]][j] = updates[i][j] if axis = 0, or data[i][indices[i][j]] = updates[i][j] if axis = 1, - where i and j are loop counters from 0 up to the respective size in `updates` - 1. - Example 1: - data = [ - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - indices = [ - [1, 0, 2], - [0, 2, 1], - ] - updates = [ - [1.0, 1.1, 1.2], - [2.0, 2.1, 2.2], - ] - output = [ - [2.0, 1.1, 0.0] - [1.0, 0.0, 2.2] - [0.0, 2.1, 1.2] - ] - Example 2: - data = [[1.0, 2.0, 3.0, 4.0, 5.0]] - indices = [[1, 3]] - updates = [[1.1, 2.1]] - axis = 1 - output = [[1.0, 1.1, 3.0, 2.1, 5.0]] - """ - - def __init__(self, data, indices, updates, - axis=None): - super().__init__('Scatter', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kLong', 'at::kInt'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - data,indices,updates, - axis=ONNXAttr(axis, AttrType.INT)) + """ + Given `data`, `updates` and `indices` input tensors of rank r >= 1, write the values provided by `updates` + into the first input, `data`, along `axis` dimension of `data` (by default outer-most one as axis=0) at corresponding `indices`. + For each entry in `updates`, the target index in `data` is specified by corresponding entry in `indices` + for dimension = axis, and index in source for dimension != axis. For instance, in a 2-D tensor case, + data[indices[i][j]][j] = updates[i][j] if axis = 0, or data[i][indices[i][j]] = updates[i][j] if axis = 1, + where i and j are loop counters from 0 up to the respective size in `updates` - 1. + Example 1: + data = [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + indices = [ + [1, 0, 2], + [0, 2, 1], + ] + updates = [ + [1.0, 1.1, 1.2], + [2.0, 2.1, 2.2], + ] + output = [ + [2.0, 1.1, 0.0] + [1.0, 0.0, 2.2] + [0.0, 2.1, 1.2] + ] + Example 2: + data = [[1.0, 2.0, 3.0, 4.0, 5.0]] + indices = [[1, 3]] + updates = [[1.1, 2.1]] + axis = 1 + output = [[1.0, 1.1, 3.0, 2.1, 5.0]] + """ + + def __init__(self, data, indices, updates, axis=None): + super().__init__( + "Scatter", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + {"at::kLong", "at::kInt"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + ], + data, + indices, + updates, + axis=ONNXAttr(axis, AttrType.INT), + ) + class ScatterElements(ONNXOp): - """ - ScatterElements takes three inputs `data`, `updates`, and `indices` of the same - rank r >= 1 and an optional attribute axis that identifies an axis of `data` - (by default, the outer-most axis, that is axis 0). The output of the operation - is produced by creating a copy of the input `data`, and then updating its value - to values specified by `updates` at specific index positions specified by - `indices`. Its output shape is the same as the shape of `data`. - - For each entry in `updates`, the target index in `data` is obtained by combining - the corresponding entry in `indices` with the index of the entry itself: the - index-value for dimension = axis is obtained from the value of the corresponding - entry in `indices` and the index-value for dimension != axis is obtained from the - index of the entry itself. - - For instance, in a 2-D tensor case, the update corresponding to the [i][j] entry - is performed as below: - ``` - output[indices[i][j]][j] = updates[i][j] if axis = 0, - output[i][indices[i][j]] = updates[i][j] if axis = 1, - ``` - - This operator is the inverse of GatherElements. It is similar to Torch's Scatter operation. - - Example 1: - ``` - data = [ - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - [0.0, 0.0, 0.0], - ] - indices = [ - [1, 0, 2], - [0, 2, 1], - ] - updates = [ - [1.0, 1.1, 1.2], - [2.0, 2.1, 2.2], - ] - output = [ - [2.0, 1.1, 0.0] - [1.0, 0.0, 2.2] - [0.0, 2.1, 1.2] - ] - ``` - Example 2: - ``` - data = [[1.0, 2.0, 3.0, 4.0, 5.0]] - indices = [[1, 3]] - updates = [[1.1, 2.1]] - axis = 1 - output = [[1.0, 1.1, 3.0, 2.1, 5.0]] - ``` - """ - - def __init__(self, data, indices, updates, - axis=None): - super().__init__('ScatterElements', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong', 'at::kInt'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - data,indices,updates, - axis=ONNXAttr(axis, AttrType.INT)) + """ + ScatterElements takes three inputs `data`, `updates`, and `indices` of the same + rank r >= 1 and an optional attribute axis that identifies an axis of `data` + (by default, the outer-most axis, that is axis 0). The output of the operation + is produced by creating a copy of the input `data`, and then updating its value + to values specified by `updates` at specific index positions specified by + `indices`. Its output shape is the same as the shape of `data`. + + For each entry in `updates`, the target index in `data` is obtained by combining + the corresponding entry in `indices` with the index of the entry itself: the + index-value for dimension = axis is obtained from the value of the corresponding + entry in `indices` and the index-value for dimension != axis is obtained from the + index of the entry itself. + + For instance, in a 2-D tensor case, the update corresponding to the [i][j] entry + is performed as below: + ``` + output[indices[i][j]][j] = updates[i][j] if axis = 0, + output[i][indices[i][j]] = updates[i][j] if axis = 1, + ``` + + This operator is the inverse of GatherElements. It is similar to Torch's Scatter operation. + + Example 1: + ``` + data = [ + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + indices = [ + [1, 0, 2], + [0, 2, 1], + ] + updates = [ + [1.0, 1.1, 1.2], + [2.0, 2.1, 2.2], + ] + output = [ + [2.0, 1.1, 0.0] + [1.0, 0.0, 2.2] + [0.0, 2.1, 1.2] + ] + ``` + Example 2: + ``` + data = [[1.0, 2.0, 3.0, 4.0, 5.0]] + indices = [[1, 3]] + updates = [[1.1, 2.1]] + axis = 1 + output = [[1.0, 1.1, 3.0, 2.1, 5.0]] + ``` + """ + + def __init__(self, data, indices, updates, axis=None): + super().__init__( + "ScatterElements", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong", "at::kInt"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + ], + data, + indices, + updates, + axis=ONNXAttr(axis, AttrType.INT), + ) + class ScatterND(ONNXOp): - """ - ScatterND takes three inputs `data` tensor of rank r >= 1, `indices` tensor of rank q >= 1, - and `updates` tensor of rank q + r - indices.shape[-1] - 1. The output of the operation - is produced by creating a copy of the input `data`, and then updating its value to values - specified by `updates` at specific index positions specified by `indices`. Its output shape - is the same as the shape of `data`. Note that `indices` should not have duplicate entries. - That is, two or more `updates` for the same index-location is not supported. - - `indices` is an integer tensor. Let k denote indices.shape[-1], the last dimension in the shape of `indices`. - `indices` is treated as a (q-1)-dimensional tensor of k-tuples, where each k-tuple is a partial-index into `data`. - Hence, k can be a value at most the rank of `data`. When k equals rank(data), each update entry specifies an - update to a single element of the tensor. When k is less than rank(data) each update entry specifies an - update to a slice of the tensor. - - `updates` is treated as a (q-1)-dimensional tensor of replacement-slice-values. Thus, the - first (q-1) dimensions of updates.shape must match the first (q-1) dimensions of indices.shape. - The remaining dimensions of `updates` correspond to the dimensions of the - replacement-slice-values. Each replacement-slice-value is a (r-k) dimensional tensor, - corresponding to the trailing (r-k) dimensions of `data`. Thus, the shape of `updates` - must equal indices.shape[0:q-1] ++ data.shape[k:r-1], where ++ denotes the concatenation - of shapes. - - The `output` is calculated via the following equation: - - output = np.copy(data) - update_indices = indices.shape[:-1] - for idx in np.ndindex(update_indices): - output[indices[idx]] = updates[idx] - - The order of iteration in the above loop is not specified. - In particular, indices should not have duplicate entries: that is, if idx1 != idx2, then indices[idx1] != indices[idx2]. - This ensures that the output value does not depend on the iteration order. - - This operator is the inverse of GatherND. - - Example 1: - ``` - data = [1, 2, 3, 4, 5, 6, 7, 8] - indices = [[4], [3], [1], [7]] - updates = [9, 10, 11, 12] - output = [1, 11, 3, 10, 9, 6, 7, 12] - ``` - - Example 2: - ``` - data = [[[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], - [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], - [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], - [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]] - indices = [[0], [2]] - updates = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], - [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]] - output = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], - [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], - [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], - [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]] - ``` - """ - - def __init__(self, data, indices, updates): - super().__init__('ScatterND', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - data,indices,updates) + """ + ScatterND takes three inputs `data` tensor of rank r >= 1, `indices` tensor of rank q >= 1, + and `updates` tensor of rank q + r - indices.shape[-1] - 1. The output of the operation + is produced by creating a copy of the input `data`, and then updating its value to values + specified by `updates` at specific index positions specified by `indices`. Its output shape + is the same as the shape of `data`. Note that `indices` should not have duplicate entries. + That is, two or more `updates` for the same index-location is not supported. + + `indices` is an integer tensor. Let k denote indices.shape[-1], the last dimension in the shape of `indices`. + `indices` is treated as a (q-1)-dimensional tensor of k-tuples, where each k-tuple is a partial-index into `data`. + Hence, k can be a value at most the rank of `data`. When k equals rank(data), each update entry specifies an + update to a single element of the tensor. When k is less than rank(data) each update entry specifies an + update to a slice of the tensor. + + `updates` is treated as a (q-1)-dimensional tensor of replacement-slice-values. Thus, the + first (q-1) dimensions of updates.shape must match the first (q-1) dimensions of indices.shape. + The remaining dimensions of `updates` correspond to the dimensions of the + replacement-slice-values. Each replacement-slice-value is a (r-k) dimensional tensor, + corresponding to the trailing (r-k) dimensions of `data`. Thus, the shape of `updates` + must equal indices.shape[0:q-1] ++ data.shape[k:r-1], where ++ denotes the concatenation + of shapes. + + The `output` is calculated via the following equation: + + output = np.copy(data) + update_indices = indices.shape[:-1] + for idx in np.ndindex(update_indices): + output[indices[idx]] = updates[idx] + + The order of iteration in the above loop is not specified. + In particular, indices should not have duplicate entries: that is, if idx1 != idx2, then indices[idx1] != indices[idx2]. + This ensures that the output value does not depend on the iteration order. + + This operator is the inverse of GatherND. + + Example 1: + ``` + data = [1, 2, 3, 4, 5, 6, 7, 8] + indices = [[4], [3], [1], [7]] + updates = [9, 10, 11, 12] + output = [1, 11, 3, 10, 9, 6, 7, 12] + ``` + + Example 2: + ``` + data = [[[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], + [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], + [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]], + [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]] + indices = [[0], [2]] + updates = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]]] + output = [[[5, 5, 5, 5], [6, 6, 6, 6], [7, 7, 7, 7], [8, 8, 8, 8]], + [[1, 2, 3, 4], [5, 6, 7, 8], [8, 7, 6, 5], [4, 3, 2, 1]], + [[1, 1, 1, 1], [2, 2, 2, 2], [3, 3, 3, 3], [4, 4, 4, 4]], + [[8, 7, 6, 5], [4, 3, 2, 1], [1, 2, 3, 4], [5, 6, 7, 8]]] + ``` + """ + + def __init__(self, data, indices, updates): + super().__init__( + "ScatterND", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + ], + data, + indices, + updates, + ) + class Selu(ONNXOp): - """ - Selu takes one input data (Tensor) and produces one output data - (Tensor) where the scaled exponential linear unit function, - `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`, - is applied to the tensor elementwise. - """ - - def __init__(self, X, - alpha=None, - gamma=None): - super().__init__('Selu', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT), - gamma=ONNXAttr(gamma, AttrType.FLOAT)) + """ + Selu takes one input data (Tensor) and produces one output data + (Tensor) where the scaled exponential linear unit function, + `y = gamma * (alpha * e^x - alpha) for x <= 0`, `y = gamma * x for x > 0`, + is applied to the tensor elementwise. + """ + + def __init__(self, X, alpha=None, gamma=None): + super().__init__( + "Selu", + 1, + [{"at::kDouble", "at::kHalf", "at::kFloat"}], + X, + alpha=ONNXAttr(alpha, AttrType.FLOAT), + gamma=ONNXAttr(gamma, AttrType.FLOAT), + ) + class SequenceAt(ONNXOp): - """ - Outputs a tensor copy from the tensor at 'position' in 'input_sequence'. - Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'. - Negative value means counting positions from the back. - """ + """ + Outputs a tensor copy from the tensor at 'position' in 'input_sequence'. + Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'. + Negative value means counting positions from the back. + """ + + def __init__(self, input_sequence, position): + super().__init__("SequenceAt", 1, [set(), {"at::kLong", "at::kInt"}], input_sequence, position) - def __init__(self, input_sequence, position): - super().__init__('SequenceAt', 1, - [set(), {'at::kLong', 'at::kInt'}], - input_sequence,position) class SequenceConstruct(ONNXOp): - """ - Construct a tensor sequence containing 'inputs' tensors. - All tensors in 'inputs' must have the same data type. - """ + """ + Construct a tensor sequence containing 'inputs' tensors. + All tensors in 'inputs' must have the same data type. + """ + + def __init__(self, inputs): + super().__init__( + "SequenceConstruct", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + inputs, + ) - def __init__(self, inputs): - super().__init__('SequenceConstruct', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - inputs) class SequenceEmpty(ONNXOp): - """ - Construct an empty tensor sequence, with given data type. - """ + """ + Construct an empty tensor sequence, with given data type. + """ + + def __init__(self, dtype=None): + super().__init__("SequenceEmpty", 1, [], dtype=ONNXAttr(dtype, AttrType.INT)) - def __init__(self, - dtype=None): - super().__init__('SequenceEmpty', 1, - [], - dtype=ONNXAttr(dtype, AttrType.INT)) class SequenceErase(ONNXOp): - """ - Outputs a tensor sequence that removes the tensor at 'position' from 'input_sequence'. - Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'. - Negative value means counting positions from the back. - 'position' is optional, by default it erases the last tensor from 'input_sequence'. - """ - - def __init__(self, input_sequence, position): - super().__init__('SequenceErase', 1, - [set(), {'at::kLong', 'at::kInt'}], - input_sequence,position) + """ + Outputs a tensor sequence that removes the tensor at 'position' from 'input_sequence'. + Accepted range for 'position' is in `[-n, n - 1]`, where `n` is the number of tensors in 'input_sequence'. + Negative value means counting positions from the back. + 'position' is optional, by default it erases the last tensor from 'input_sequence'. + """ + + def __init__(self, input_sequence, position): + super().__init__("SequenceErase", 1, [set(), {"at::kLong", "at::kInt"}], input_sequence, position) + class SequenceInsert(ONNXOp): - """ - Outputs a tensor sequence that inserts 'tensor' into 'input_sequence' at 'position'. - 'tensor' must have the same data type as 'input_sequence'. - Accepted range for 'position' is in `[-n, n]`, where `n` is the number of tensors in 'input_sequence'. - Negative value means counting positions from the back. - 'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'. - """ - - def __init__(self, input_sequence, tensor, position): - super().__init__('SequenceInsert', 1, - [set(), {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kLong', 'at::kInt'}], - input_sequence,tensor,position) + """ + Outputs a tensor sequence that inserts 'tensor' into 'input_sequence' at 'position'. + 'tensor' must have the same data type as 'input_sequence'. + Accepted range for 'position' is in `[-n, n]`, where `n` is the number of tensors in 'input_sequence'. + Negative value means counting positions from the back. + 'position' is optional, by default it inserts 'tensor' to the back of 'input_sequence'. + """ + + def __init__(self, input_sequence, tensor, position): + super().__init__( + "SequenceInsert", + 1, + [ + set(), + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + {"at::kLong", "at::kInt"}, + ], + input_sequence, + tensor, + position, + ) + class SequenceLength(ONNXOp): - """ - Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'. - """ + """ + Produces a scalar(tensor of empty shape) containing the number of tensors in 'input_sequence'. + """ + + def __init__(self, input_sequence): + super().__init__("SequenceLength", 1, [set()], input_sequence) - def __init__(self, input_sequence): - super().__init__('SequenceLength', 1, - [set()], - input_sequence) class Shape(ONNXOp): - """ - Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor. - """ + """ + Takes a tensor as input and outputs an 1D int64 tensor containing the shape of the input tensor. + """ + + def __init__(self, data): + super().__init__( + "Shape", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + data, + ) - def __init__(self, data): - super().__init__('Shape', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - data) class Shrink(ONNXOp): - """ - Shrink takes one input data (Tensor) and produces one Tensor output, - having same datatype and shape with input. It has two attributes, lambd and - bias. The formula of this operator is: If x < -lambd, y = x + bias; - If x > lambd, y = x - bias; Otherwise, y = 0. - """ - - def __init__(self, input, - bias=None, - lambd=None): - super().__init__('Shrink', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}], - input, - bias=ONNXAttr(bias, AttrType.FLOAT), - lambd=ONNXAttr(lambd, AttrType.FLOAT)) + """ + Shrink takes one input data (Tensor) and produces one Tensor output, + having same datatype and shape with input. It has two attributes, lambd and + bias. The formula of this operator is: If x < -lambd, y = x + bias; + If x > lambd, y = x - bias; Otherwise, y = 0. + """ + + def __init__(self, input, bias=None, lambd=None): + super().__init__( + "Shrink", + 1, + [{"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}], + input, + bias=ONNXAttr(bias, AttrType.FLOAT), + lambd=ONNXAttr(lambd, AttrType.FLOAT), + ) + class Sigmoid(ONNXOp): - """ - Sigmoid takes one input data (Tensor) and produces one output data - (Tensor) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the - tensor elementwise. - """ + """ + Sigmoid takes one input data (Tensor) and produces one output data + (Tensor) where the sigmoid function, y = 1 / (1 + exp(-x)), is applied to the + tensor elementwise. + """ + + def __init__(self, X): + super().__init__("Sigmoid", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('Sigmoid', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X) class Sign(ONNXOp): - """ - Calculate the sign of the given input tensor element-wise. - If input > 0, output 1. if input < 0, output -1. if input == 0, output 0. - """ + """ + Calculate the sign of the given input tensor element-wise. + If input > 0, output 1. if input < 0, output -1. if input == 0, output 0. + """ + + def __init__(self, input): + super().__init__( + "Sign", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + ) - def __init__(self, input): - super().__init__('Sign', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - input) class Sin(ONNXOp): - """ - Calculates the sine of the given input tensor, element-wise. - """ + """ + Calculates the sine of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Sin", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Sin', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Sinh(ONNXOp): - """ - Calculates the hyperbolic sine of the given input tensor element-wise. - """ + """ + Calculates the hyperbolic sine of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Sinh", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Sinh', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Size(ONNXOp): - """ - Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor. - """ + """ + Takes a tensor as input and outputs a int64 scalar that equals to the total number of elements of the input tensor. + """ + + def __init__(self, data): + super().__init__( + "Size", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + data, + ) - def __init__(self, data): - super().__init__('Size', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - data) class Slice(ONNXOp): - """ - Produces a slice of the input tensor along multiple axes. Similar to numpy: - https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html - Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end - dimension and step for each axis in the list of axes, it uses this information to - slice the input `data` tensor. If a negative value is passed for any of the - start or end indices, it represents number of elements before the end of that - dimension. If the value passed to start or end is larger than the `n` (the - number of elements in this dimension), it represents `n`. For slicing to the - end of a dimension with unknown size, it is recommended to pass in `INT_MAX` - when sclicing forward and 'INT_MIN' when slicing backward. - If a negative value is passed for step, it represents slicing backward. - However step value cannot be 0. - If `axes` are omitted, they are set to `[0, ..., ndim-1]`. - If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)` - Example 1: - data = [ - [1, 2, 3, 4], - [5, 6, 7, 8], - ] - axes = [0, 1] - starts = [1, 0] - ends = [2, 3] - steps = [1, 2] - result = [ - [5, 7], - ] - Example 2: - data = [ - [1, 2, 3, 4], - [5, 6, 7, 8], - ] - starts = [0, 1] - ends = [-1, 1000] - result = [ - [2, 3, 4], - ] - """ - - def __init__(self, data, starts, ends, axes, steps): - super().__init__('Slice', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong', 'at::kInt'}, {'at::kLong', 'at::kInt'}, {'at::kLong', 'at::kInt'}, {'at::kLong', 'at::kInt'}], - data,starts,ends,axes,steps) + """ + Produces a slice of the input tensor along multiple axes. Similar to numpy: + https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html + Slices uses `starts`, `ends`, `axes` and `steps` inputs to specify the start and end + dimension and step for each axis in the list of axes, it uses this information to + slice the input `data` tensor. If a negative value is passed for any of the + start or end indices, it represents number of elements before the end of that + dimension. If the value passed to start or end is larger than the `n` (the + number of elements in this dimension), it represents `n`. For slicing to the + end of a dimension with unknown size, it is recommended to pass in `INT_MAX` + when sclicing forward and 'INT_MIN' when slicing backward. + If a negative value is passed for step, it represents slicing backward. + However step value cannot be 0. + If `axes` are omitted, they are set to `[0, ..., ndim-1]`. + If `steps` are omitted, they are set to `[1, ..., 1]` of length `len(starts)` + Example 1: + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + axes = [0, 1] + starts = [1, 0] + ends = [2, 3] + steps = [1, 2] + result = [ + [5, 7], + ] + Example 2: + data = [ + [1, 2, 3, 4], + [5, 6, 7, 8], + ] + starts = [0, 1] + ends = [-1, 1000] + result = [ + [2, 3, 4], + ] + """ + + def __init__(self, data, starts, ends, axes, steps): + super().__init__( + "Slice", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong", "at::kInt"}, + {"at::kLong", "at::kInt"}, + {"at::kLong", "at::kInt"}, + {"at::kLong", "at::kInt"}, + ], + data, + starts, + ends, + axes, + steps, + ) + class Softmax(ONNXOp): - """ - The operator computes the normalized exponential values for the given input: - - Softmax(input, axis) = Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1) - - The input does not need to explicitly be a 2D vector. The "axis" attribute - indicates the dimension along which Softmax will be performed. - The output tensor has the same shape - and contains the Softmax values of the corresponding input. - """ - - def __init__(self, input, - axis=None): - super().__init__('Softmax', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - input, - axis=ONNXAttr(axis, AttrType.INT)) + """ + The operator computes the normalized exponential values for the given input: + + Softmax(input, axis) = Exp(input) / ReduceSum(Exp(input), axis=axis, keepdims=1) + + The input does not need to explicitly be a 2D vector. The "axis" attribute + indicates the dimension along which Softmax will be performed. + The output tensor has the same shape + and contains the Softmax values of the corresponding input. + """ + + def __init__(self, input, axis=None): + super().__init__( + "Softmax", + 1, + [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], + input, + axis=ONNXAttr(axis, AttrType.INT), + ) + class SoftmaxCrossEntropyLoss(ONNXOp): - """ - Loss function that measures the softmax cross entropy - between 'scores' and 'labels'. - This operator first computes a loss tensor whose shape is identical to the labels input. - If the input is 2-D with shape (N, C), the loss tensor may be a N-element vector L = (l_1, l_2, ..., l_N). - If the input is N-D tensor with shape (N, C, D1, D2, ..., Dk), - the loss tensor L may have (N, D1, D2, ..., Dk) as its shape and L[i,][j_1][j_2]...[j_k] denotes a scalar element in L. - After L is available, this operator can optionally do a reduction operator. - - shape(scores): (N, C) where C is the number of classes, or (N, C, D1, D2,..., Dk), - with K >= 1 in case of K-dimensional loss. - shape(labels): (N) where each value is 0 <= labels[i] <= C-1, or (N, D1, D2,..., Dk), - with K >= 1 in case of K-dimensional loss. - - The loss for one sample, l_i, can caculated as follows: - l[i][d1][d2]...[dk] = -y[i][c][d1][d2]..[dk], where i is the index of classes. - or - l[i][d1][d2]...[dk] = -y[i][c][d1][d2]..[dk] * weights[c], if 'weights' is provided. - - loss is zero for the case when label-value equals ignore_index. - l[i][d1][d2]...[dk] = 0, when labels[n][d1][d2]...[dk] = ignore_index - - where: - p = Softmax(scores) - y = Log(p) - c = labels[i][d1][d2]...[dk] - - Finally, L is optionally reduced: - If reduction = 'none', the output is L with shape (N, D1, D2, ..., Dk). - If reduction = 'sum', the output is scalar: Sum(L). - If reduction = 'mean', the output is scalar: ReduceMean(L), or if weight is provided: ReduceSum(L) / ReduceSum(W), - where tensor W is of shape (N, D1, D2, ..., Dk) and W[n][d1][d2]...[dk] = weights[labels[i][d1][d2]...[dk]]. - """ - - def __init__(self, scores, labels, weights, - ignore_index=None, - reduction=None): - super().__init__('SoftmaxCrossEntropyLoss', 2, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}, {'at::kLong', 'at::kInt'}, {'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - scores,labels,weights, - ignore_index=ONNXAttr(ignore_index, AttrType.INT), - reduction=ONNXAttr(reduction, AttrType.STRING)) + """ + Loss function that measures the softmax cross entropy + between 'scores' and 'labels'. + This operator first computes a loss tensor whose shape is identical to the labels input. + If the input is 2-D with shape (N, C), the loss tensor may be a N-element vector L = (l_1, l_2, ..., l_N). + If the input is N-D tensor with shape (N, C, D1, D2, ..., Dk), + the loss tensor L may have (N, D1, D2, ..., Dk) as its shape and L[i,][j_1][j_2]...[j_k] denotes a scalar element in L. + After L is available, this operator can optionally do a reduction operator. + + shape(scores): (N, C) where C is the number of classes, or (N, C, D1, D2,..., Dk), + with K >= 1 in case of K-dimensional loss. + shape(labels): (N) where each value is 0 <= labels[i] <= C-1, or (N, D1, D2,..., Dk), + with K >= 1 in case of K-dimensional loss. + + The loss for one sample, l_i, can caculated as follows: + l[i][d1][d2]...[dk] = -y[i][c][d1][d2]..[dk], where i is the index of classes. + or + l[i][d1][d2]...[dk] = -y[i][c][d1][d2]..[dk] * weights[c], if 'weights' is provided. + + loss is zero for the case when label-value equals ignore_index. + l[i][d1][d2]...[dk] = 0, when labels[n][d1][d2]...[dk] = ignore_index + + where: + p = Softmax(scores) + y = Log(p) + c = labels[i][d1][d2]...[dk] + + Finally, L is optionally reduced: + If reduction = 'none', the output is L with shape (N, D1, D2, ..., Dk). + If reduction = 'sum', the output is scalar: Sum(L). + If reduction = 'mean', the output is scalar: ReduceMean(L), or if weight is provided: ReduceSum(L) / ReduceSum(W), + where tensor W is of shape (N, D1, D2, ..., Dk) and W[n][d1][d2]...[dk] = weights[labels[i][d1][d2]...[dk]]. + """ + + def __init__(self, scores, labels, weights, ignore_index=None, reduction=None): + super().__init__( + "SoftmaxCrossEntropyLoss", + 2, + [ + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + {"at::kLong", "at::kInt"}, + {"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}, + ], + scores, + labels, + weights, + ignore_index=ONNXAttr(ignore_index, AttrType.INT), + reduction=ONNXAttr(reduction, AttrType.STRING), + ) + class Softplus(ONNXOp): - """ - Softplus takes one input data (Tensor) and produces one output data - (Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to - the tensor elementwise. - """ + """ + Softplus takes one input data (Tensor) and produces one output data + (Tensor) where the softplus function, y = ln(exp(x) + 1), is applied to + the tensor elementwise. + """ + + def __init__(self, X): + super().__init__("Softplus", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('Softplus', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X) class Softsign(ONNXOp): - """ - Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise. - """ + """ + Calculates the softsign (x/(1+|x|)) of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Softsign", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Softsign', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class SpaceToDepth(ONNXOp): - """ - SpaceToDepth rearranges blocks of spatial data into depth. More specifically, - this op outputs a copy of the input tensor where values from the height and width dimensions - are moved to the depth dimension. - """ - - def __init__(self, input, - blocksize=None): - super().__init__('SpaceToDepth', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - input, - blocksize=ONNXAttr(blocksize, AttrType.INT)) + """ + SpaceToDepth rearranges blocks of spatial data into depth. More specifically, + this op outputs a copy of the input tensor where values from the height and width dimensions + are moved to the depth dimension. + """ + + def __init__(self, input, blocksize=None): + super().__init__( + "SpaceToDepth", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + input, + blocksize=ONNXAttr(blocksize, AttrType.INT), + ) + class Split(ONNXOp): - """ - Split a tensor into a list of tensors, along the specified - 'axis'. Lengths of the parts can be specified using input 'split'. - Otherwise, the tensor is split to equal sized parts. - """ - - def __init__(self, input, split, - axis=None): - super().__init__('Split', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - input,split, - axis=ONNXAttr(axis, AttrType.INT)) + """ + Split a tensor into a list of tensors, along the specified + 'axis'. Lengths of the parts can be specified using input 'split'. + Otherwise, the tensor is split to equal sized parts. + """ + + def __init__(self, input, split, axis=None): + super().__init__( + "Split", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + input, + split, + axis=ONNXAttr(axis, AttrType.INT), + ) + class SplitToSequence(ONNXOp): - """ - Split a tensor into a sequence of tensors, along the specified - 'axis'. Lengths of the parts can be specified using argument 'split'. - 'split' must contain only positive numbers. - 'split' is either a scalar (tensor of empty shape), or a 1-D tensor. - If 'split' is a scalar, then 'input' will be split into equally sized chunks(if possible). - Last chunk will be smaller if the 'input' size along the given axis 'axis' is not divisible - by 'split'. - Otherwise, the tensor is split into 'size(split)' chunks, with lengths of the parts on 'axis' - specified in 'split'. In this scenario, the sum of entries in 'split' must be equal to the - dimension size of input tensor on 'axis'. - """ - - def __init__(self, input, split, - axis=None, - keepdims=None): - super().__init__('SplitToSequence', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kLong', 'at::kInt'}], - input,split, - axis=ONNXAttr(axis, AttrType.INT), - keepdims=ONNXAttr(keepdims, AttrType.INT)) + """ + Split a tensor into a sequence of tensors, along the specified + 'axis'. Lengths of the parts can be specified using argument 'split'. + 'split' must contain only positive numbers. + 'split' is either a scalar (tensor of empty shape), or a 1-D tensor. + If 'split' is a scalar, then 'input' will be split into equally sized chunks(if possible). + Last chunk will be smaller if the 'input' size along the given axis 'axis' is not divisible + by 'split'. + Otherwise, the tensor is split into 'size(split)' chunks, with lengths of the parts on 'axis' + specified in 'split'. In this scenario, the sum of entries in 'split' must be equal to the + dimension size of input tensor on 'axis'. + """ + + def __init__(self, input, split, axis=None, keepdims=None): + super().__init__( + "SplitToSequence", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + {"at::kLong", "at::kInt"}, + ], + input, + split, + axis=ONNXAttr(axis, AttrType.INT), + keepdims=ONNXAttr(keepdims, AttrType.INT), + ) + class Sqrt(ONNXOp): - """ - Square root takes one input data (Tensor) and produces one output data - (Tensor) where the square root is, y = x^0.5, is applied to - the tensor elementwise. If x is negative, then it will return NaN. - """ + """ + Square root takes one input data (Tensor) and produces one output data + (Tensor) where the square root is, y = x^0.5, is applied to + the tensor elementwise. If x is negative, then it will return NaN. + """ + + def __init__(self, X): + super().__init__("Sqrt", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], X) - def __init__(self, X): - super().__init__('Sqrt', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - X) class Squeeze(ONNXOp): - """ - Remove single-dimensional entries from the shape of a tensor. - Takes an input `axes` with a list of axes to squeeze. - If `axes` is not provided, all the single dimensions will be removed from - the shape. If an axis is selected with shape entry not equal to one, an error is raised. - """ - - def __init__(self, data, axes): - super().__init__('Squeeze', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - data,axes) + """ + Remove single-dimensional entries from the shape of a tensor. + Takes an input `axes` with a list of axes to squeeze. + If `axes` is not provided, all the single dimensions will be removed from + the shape. If an axis is selected with shape entry not equal to one, an error is raised. + """ + + def __init__(self, data, axes): + super().__init__( + "Squeeze", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + data, + axes, + ) + class StringNormalizer(ONNXOp): - """ - StringNormalization performs string operations for basic cleaning. - This operator has only one input (denoted by X) and only one output - (denoted by Y). This operator first examines the elements in the X, - and removes elements specified in "stopwords" attribute. - After removing stop words, the intermediate result can be further lowercased, - uppercased, or just returned depending the "case_change_action" attribute. - This operator only accepts [C]- and [1, C]-tensor. - If all elements in X are dropped, the output will be the empty value of string tensor with shape [1] - if input shape is [C] and shape [1, 1] if input shape is [1, C]. - """ - - def __init__(self, X, - case_change_action=None, - is_case_sensitive=None, - locale=None, - stopwords=None): - super().__init__('StringNormalizer', 1, - [set()], - X, - case_change_action=ONNXAttr(case_change_action, AttrType.STRING), - is_case_sensitive=ONNXAttr(is_case_sensitive, AttrType.INT), - locale=ONNXAttr(locale, AttrType.STRING), - stopwords=ONNXAttr(stopwords, AttrType.STRINGS)) + """ + StringNormalization performs string operations for basic cleaning. + This operator has only one input (denoted by X) and only one output + (denoted by Y). This operator first examines the elements in the X, + and removes elements specified in "stopwords" attribute. + After removing stop words, the intermediate result can be further lowercased, + uppercased, or just returned depending the "case_change_action" attribute. + This operator only accepts [C]- and [1, C]-tensor. + If all elements in X are dropped, the output will be the empty value of string tensor with shape [1] + if input shape is [C] and shape [1, 1] if input shape is [1, C]. + """ + + def __init__(self, X, case_change_action=None, is_case_sensitive=None, locale=None, stopwords=None): + super().__init__( + "StringNormalizer", + 1, + [set()], + X, + case_change_action=ONNXAttr(case_change_action, AttrType.STRING), + is_case_sensitive=ONNXAttr(is_case_sensitive, AttrType.INT), + locale=ONNXAttr(locale, AttrType.STRING), + stopwords=ONNXAttr(stopwords, AttrType.STRINGS), + ) + class Sub(ONNXOp): - """ - Performs element-wise binary subtraction (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - - (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. - """ - - def __init__(self, A, B): - super().__init__('Sub', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat', 'at::kBFloat16'}], - A,B) + """ + Performs element-wise binary subtraction (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + + (Opset 14 change): Extend supported types to include uint8, int8, uint16, and int16. + """ + + def __init__(self, A, B): + super().__init__( + "Sub", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kFloat", + "at::kBFloat16", + }, + ], + A, + B, + ) + class Sum(ONNXOp): - """ - Element-wise sum of each of the input tensors (with Numpy-style broadcasting support). - All inputs and outputs must have the same data type. - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ + """ + Element-wise sum of each of the input tensors (with Numpy-style broadcasting support). + All inputs and outputs must have the same data type. + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, data_0): + super().__init__("Sum", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], data_0) - def __init__(self, data_0): - super().__init__('Sum', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - data_0) class SVMClassifier(ONNXOp): - """ - Support Vector Machine classifier - """ - - def __init__(self, X, - classlabels_ints=None, - classlabels_strings=None, - coefficients=None, - kernel_params=None, - kernel_type=None, - post_transform=None, - prob_a=None, - prob_b=None, - rho=None, - support_vectors=None, - vectors_per_class=None): - super().__init__('SVMClassifier', 2, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - classlabels_ints=ONNXAttr(classlabels_ints, AttrType.INTS), - classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), - coefficients=ONNXAttr(coefficients, AttrType.FLOATS), - kernel_params=ONNXAttr(kernel_params, AttrType.FLOATS), - kernel_type=ONNXAttr(kernel_type, AttrType.STRING), - post_transform=ONNXAttr(post_transform, AttrType.STRING), - prob_a=ONNXAttr(prob_a, AttrType.FLOATS), - prob_b=ONNXAttr(prob_b, AttrType.FLOATS), - rho=ONNXAttr(rho, AttrType.FLOATS), - support_vectors=ONNXAttr(support_vectors, AttrType.FLOATS), - vectors_per_class=ONNXAttr(vectors_per_class, AttrType.INTS)) + """ + Support Vector Machine classifier + """ + + def __init__( + self, + X, + classlabels_ints=None, + classlabels_strings=None, + coefficients=None, + kernel_params=None, + kernel_type=None, + post_transform=None, + prob_a=None, + prob_b=None, + rho=None, + support_vectors=None, + vectors_per_class=None, + ): + super().__init__( + "SVMClassifier", + 2, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + classlabels_ints=ONNXAttr(classlabels_ints, AttrType.INTS), + classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), + coefficients=ONNXAttr(coefficients, AttrType.FLOATS), + kernel_params=ONNXAttr(kernel_params, AttrType.FLOATS), + kernel_type=ONNXAttr(kernel_type, AttrType.STRING), + post_transform=ONNXAttr(post_transform, AttrType.STRING), + prob_a=ONNXAttr(prob_a, AttrType.FLOATS), + prob_b=ONNXAttr(prob_b, AttrType.FLOATS), + rho=ONNXAttr(rho, AttrType.FLOATS), + support_vectors=ONNXAttr(support_vectors, AttrType.FLOATS), + vectors_per_class=ONNXAttr(vectors_per_class, AttrType.INTS), + ) + class SVMRegressor(ONNXOp): - """ - Support Vector Machine regression prediction and one-class SVM anomaly detection. - """ - - def __init__(self, X, - coefficients=None, - kernel_params=None, - kernel_type=None, - n_supports=None, - one_class=None, - post_transform=None, - rho=None, - support_vectors=None): - super().__init__('SVMRegressor', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - coefficients=ONNXAttr(coefficients, AttrType.FLOATS), - kernel_params=ONNXAttr(kernel_params, AttrType.FLOATS), - kernel_type=ONNXAttr(kernel_type, AttrType.STRING), - n_supports=ONNXAttr(n_supports, AttrType.INT), - one_class=ONNXAttr(one_class, AttrType.INT), - post_transform=ONNXAttr(post_transform, AttrType.STRING), - rho=ONNXAttr(rho, AttrType.FLOATS), - support_vectors=ONNXAttr(support_vectors, AttrType.FLOATS)) + """ + Support Vector Machine regression prediction and one-class SVM anomaly detection. + """ + + def __init__( + self, + X, + coefficients=None, + kernel_params=None, + kernel_type=None, + n_supports=None, + one_class=None, + post_transform=None, + rho=None, + support_vectors=None, + ): + super().__init__( + "SVMRegressor", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + coefficients=ONNXAttr(coefficients, AttrType.FLOATS), + kernel_params=ONNXAttr(kernel_params, AttrType.FLOATS), + kernel_type=ONNXAttr(kernel_type, AttrType.STRING), + n_supports=ONNXAttr(n_supports, AttrType.INT), + one_class=ONNXAttr(one_class, AttrType.INT), + post_transform=ONNXAttr(post_transform, AttrType.STRING), + rho=ONNXAttr(rho, AttrType.FLOATS), + support_vectors=ONNXAttr(support_vectors, AttrType.FLOATS), + ) + class Tan(ONNXOp): - """ - Calculates the tangent of the given input tensor, element-wise. - """ + """ + Calculates the tangent of the given input tensor, element-wise. + """ + + def __init__(self, input): + super().__init__("Tan", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Tan', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - input) class Tanh(ONNXOp): - """ - Calculates the hyperbolic tangent of the given input tensor element-wise. - """ + """ + Calculates the hyperbolic tangent of the given input tensor element-wise. + """ + + def __init__(self, input): + super().__init__("Tanh", 1, [{"at::kDouble", "at::kBFloat16", "at::kHalf", "at::kFloat"}], input) - def __init__(self, input): - super().__init__('Tanh', 1, - [{'at::kDouble', 'at::kBFloat16', 'at::kHalf', 'at::kFloat'}], - input) class TfIdfVectorizer(ONNXOp): - """ - This transform extracts n-grams from the input sequence and save them as a vector. Input can - be either a 1-D or 2-D tensor. For 1-D input, output is the n-gram representation of that input. - For 2-D input, the output is also a 2-D tensor whose i-th row is the n-gram representation of the i-th input row. - More specifically, if input shape is [C], the corresponding output shape would be [max(ngram_indexes) + 1]. - If input shape is [N, C], this operator produces a [N, max(ngram_indexes) + 1]-tensor. - - In contrast to standard n-gram extraction, here, the indexes of extracting an n-gram from the original - sequence are not necessarily consecutive numbers. The discontinuity between indexes are controlled by the number of skips. - If the number of skips is 2, we should skip two tokens when scanning through the original sequence. - Let's consider an example. Assume that input sequence is [94, 17, 36, 12, 28] and the number of skips is 2. - The associated 2-grams are [94, 12] and [17, 28] respectively indexed by [0, 3] and [1, 4]. - If the number of skips becomes 0, the 2-grams generated are [94, 17], [17, 36], [36, 12], [12, 28] - indexed by [0, 1], [1, 2], [2, 3], [3, 4], respectively. - - The output vector (denoted by Y) stores the count of each n-gram; - Y[ngram_indexes[i]] indicates the times that the i-th n-gram is found. The attribute ngram_indexes is used to determine the mapping - between index i and the corresponding n-gram's output coordinate. If pool_int64s is [94, 17, 17, 36], ngram_indexes is [1, 0], - ngram_counts=[0, 0], then the Y[0] (first element in Y) and Y[1] (second element in Y) are the counts of [17, 36] and [94, 17], - respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output. - Note that we may consider all skips up to S when generating the n-grams. - - The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and - the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF", - this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute. - - Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor. - If pool_strings is set, the input must be a string tensor. - """ - - def __init__(self, X, - max_gram_length=None, - max_skip_count=None, - min_gram_length=None, - mode=None, - ngram_counts=None, - ngram_indexes=None, - pool_int64s=None, - pool_strings=None, - weights=None): - super().__init__('TfIdfVectorizer', 1, - [{'at::kLong', 'at::kInt'}], - X, - max_gram_length=ONNXAttr(max_gram_length, AttrType.INT), - max_skip_count=ONNXAttr(max_skip_count, AttrType.INT), - min_gram_length=ONNXAttr(min_gram_length, AttrType.INT), - mode=ONNXAttr(mode, AttrType.STRING), - ngram_counts=ONNXAttr(ngram_counts, AttrType.INTS), - ngram_indexes=ONNXAttr(ngram_indexes, AttrType.INTS), - pool_int64s=ONNXAttr(pool_int64s, AttrType.INTS), - pool_strings=ONNXAttr(pool_strings, AttrType.STRINGS), - weights=ONNXAttr(weights, AttrType.FLOATS)) + """ + This transform extracts n-grams from the input sequence and save them as a vector. Input can + be either a 1-D or 2-D tensor. For 1-D input, output is the n-gram representation of that input. + For 2-D input, the output is also a 2-D tensor whose i-th row is the n-gram representation of the i-th input row. + More specifically, if input shape is [C], the corresponding output shape would be [max(ngram_indexes) + 1]. + If input shape is [N, C], this operator produces a [N, max(ngram_indexes) + 1]-tensor. + + In contrast to standard n-gram extraction, here, the indexes of extracting an n-gram from the original + sequence are not necessarily consecutive numbers. The discontinuity between indexes are controlled by the number of skips. + If the number of skips is 2, we should skip two tokens when scanning through the original sequence. + Let's consider an example. Assume that input sequence is [94, 17, 36, 12, 28] and the number of skips is 2. + The associated 2-grams are [94, 12] and [17, 28] respectively indexed by [0, 3] and [1, 4]. + If the number of skips becomes 0, the 2-grams generated are [94, 17], [17, 36], [36, 12], [12, 28] + indexed by [0, 1], [1, 2], [2, 3], [3, 4], respectively. + + The output vector (denoted by Y) stores the count of each n-gram; + Y[ngram_indexes[i]] indicates the times that the i-th n-gram is found. The attribute ngram_indexes is used to determine the mapping + between index i and the corresponding n-gram's output coordinate. If pool_int64s is [94, 17, 17, 36], ngram_indexes is [1, 0], + ngram_counts=[0, 0], then the Y[0] (first element in Y) and Y[1] (second element in Y) are the counts of [17, 36] and [94, 17], + respectively. An n-gram which cannot be found in pool_strings/pool_int64s should be ignored and has no effect on the output. + Note that we may consider all skips up to S when generating the n-grams. + + The examples used above are true if mode is "TF". If mode is "IDF", all the counts larger than 1 would be truncated to 1 and + the i-th element in weights would be used to scale (by multiplication) the count of the i-th n-gram in pool. If mode is "TFIDF", + this operator first computes the counts of all n-grams and then scale them by the associated values in the weights attribute. + + Only one of pool_strings and pool_int64s can be set. If pool_int64s is set, the input should be an integer tensor. + If pool_strings is set, the input must be a string tensor. + """ + + def __init__( + self, + X, + max_gram_length=None, + max_skip_count=None, + min_gram_length=None, + mode=None, + ngram_counts=None, + ngram_indexes=None, + pool_int64s=None, + pool_strings=None, + weights=None, + ): + super().__init__( + "TfIdfVectorizer", + 1, + [{"at::kLong", "at::kInt"}], + X, + max_gram_length=ONNXAttr(max_gram_length, AttrType.INT), + max_skip_count=ONNXAttr(max_skip_count, AttrType.INT), + min_gram_length=ONNXAttr(min_gram_length, AttrType.INT), + mode=ONNXAttr(mode, AttrType.STRING), + ngram_counts=ONNXAttr(ngram_counts, AttrType.INTS), + ngram_indexes=ONNXAttr(ngram_indexes, AttrType.INTS), + pool_int64s=ONNXAttr(pool_int64s, AttrType.INTS), + pool_strings=ONNXAttr(pool_strings, AttrType.STRINGS), + weights=ONNXAttr(weights, AttrType.FLOATS), + ) + class ThresholdedRelu(ONNXOp): - """ - ThresholdedRelu takes one input data (Tensor) and produces one output data - (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, - is applied to the tensor elementwise. - """ - - def __init__(self, X, - alpha=None): - super().__init__('ThresholdedRelu', 1, - [{'at::kDouble', 'at::kHalf', 'at::kFloat'}], - X, - alpha=ONNXAttr(alpha, AttrType.FLOAT)) + """ + ThresholdedRelu takes one input data (Tensor) and produces one output data + (Tensor) where the rectified linear function, y = x for x > alpha, y = 0 otherwise, + is applied to the tensor elementwise. + """ + + def __init__(self, X, alpha=None): + super().__init__( + "ThresholdedRelu", 1, [{"at::kDouble", "at::kHalf", "at::kFloat"}], X, alpha=ONNXAttr(alpha, AttrType.FLOAT) + ) + class Tile(ONNXOp): - """ - Constructs a tensor by tiling a given tensor. - This is the same as function `tile` in Numpy, but no broadcast. - For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]] - """ + """ + Constructs a tensor by tiling a given tensor. + This is the same as function `tile` in Numpy, but no broadcast. + For example A = [[1, 2], [3, 4]], B = [1, 2], tile(A, B) = [[1, 2, 1, 2], [3, 4, 3, 4]] + """ + + def __init__(self, input, repeats): + super().__init__( + "Tile", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + input, + repeats, + ) - def __init__(self, input, repeats): - super().__init__('Tile', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - input,repeats) class TopK(ONNXOp): - """ - Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of - shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs: - -Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] - which contains the values of the top k elements along the specified axis - -Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which - contains the indices of the top k elements (original indices from the input - tensor). - - If "largest" is 1 (the default value) then the k largest elements are returned. - If "sorted" is 1 (the default value) then the resulting k elements will be sorted. - If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined. - - Given two equivalent values, this operator uses the indices along the axis as - a tiebreaker. That is, the element with the lower index will appear first. - """ - - def __init__(self, X, K, - axis=None, - largest=None, - sorted=None): - super().__init__('TopK', 2, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kFloat'}, {'at::kLong'}], - X,K, - axis=ONNXAttr(axis, AttrType.INT), - largest=ONNXAttr(largest, AttrType.INT), - sorted=ONNXAttr(sorted, AttrType.INT)) + """ + Retrieve the top-K largest or smallest elements along a specified axis. Given an input tensor of + shape [a_1, a_2, ..., a_n, r] and integer argument k, return two outputs: + -Value tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] + which contains the values of the top k elements along the specified axis + -Index tensor of shape [a_1, a_2, ..., a_{axis-1}, k, a_{axis+1}, ... a_n] which + contains the indices of the top k elements (original indices from the input + tensor). + + If "largest" is 1 (the default value) then the k largest elements are returned. + If "sorted" is 1 (the default value) then the resulting k elements will be sorted. + If "sorted" is 0, order of returned 'Values' and 'Indices' are undefined. + + Given two equivalent values, this operator uses the indices along the axis as + a tiebreaker. That is, the element with the lower index will appear first. + """ + + def __init__(self, X, K, axis=None, largest=None, sorted=None): + super().__init__( + "TopK", + 2, + [ + {"at::kDouble", "at::kLong", "at::kByte", "at::kInt", "at::kHalf", "at::kShort", "at::kFloat"}, + {"at::kLong"}, + ], + X, + K, + axis=ONNXAttr(axis, AttrType.INT), + largest=ONNXAttr(largest, AttrType.INT), + sorted=ONNXAttr(sorted, AttrType.INT), + ) + class Transpose(ONNXOp): - """ - Transpose the input tensor similar to numpy.transpose. For example, when - perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape - will be (2, 1, 3). - """ - - def __init__(self, data, - perm=None): - super().__init__('Transpose', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}], - data, - perm=ONNXAttr(perm, AttrType.INTS)) + """ + Transpose the input tensor similar to numpy.transpose. For example, when + perm=(1, 0, 2), given an input tensor of shape (1, 2, 3), the output shape + will be (2, 1, 3). + """ + + def __init__(self, data, perm=None): + super().__init__( + "Transpose", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + } + ], + data, + perm=ONNXAttr(perm, AttrType.INTS), + ) + class TreeEnsembleClassifier(ONNXOp): - """ - Tree Ensemble classifier. Returns the top class for each of N inputs.
- The attributes named 'nodes_X' form a sequence of tuples, associated by - index into the sequences, which must all be of equal length. These tuples - define the nodes.
- Similarly, all fields prefixed with 'class_' are tuples of votes at the leaves. - A leaf may have multiple votes, where each vote is weighted by - the associated class_weights index.
- One and only one of classlabels_strings or classlabels_int64s - will be defined. The class_ids are indices into this list. - """ - - def __init__(self, X, - base_values=None, - class_ids=None, - class_nodeids=None, - class_treeids=None, - class_weights=None, - classlabels_int64s=None, - classlabels_strings=None, - nodes_falsenodeids=None, - nodes_featureids=None, - nodes_hitrates=None, - nodes_missing_value_tracks_true=None, - nodes_modes=None, - nodes_nodeids=None, - nodes_treeids=None, - nodes_truenodeids=None, - nodes_values=None, - post_transform=None): - super().__init__('TreeEnsembleClassifier', 2, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - base_values=ONNXAttr(base_values, AttrType.FLOATS), - class_ids=ONNXAttr(class_ids, AttrType.INTS), - class_nodeids=ONNXAttr(class_nodeids, AttrType.INTS), - class_treeids=ONNXAttr(class_treeids, AttrType.INTS), - class_weights=ONNXAttr(class_weights, AttrType.FLOATS), - classlabels_int64s=ONNXAttr(classlabels_int64s, AttrType.INTS), - classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), - nodes_falsenodeids=ONNXAttr(nodes_falsenodeids, AttrType.INTS), - nodes_featureids=ONNXAttr(nodes_featureids, AttrType.INTS), - nodes_hitrates=ONNXAttr(nodes_hitrates, AttrType.FLOATS), - nodes_missing_value_tracks_true=ONNXAttr(nodes_missing_value_tracks_true, AttrType.INTS), - nodes_modes=ONNXAttr(nodes_modes, AttrType.STRINGS), - nodes_nodeids=ONNXAttr(nodes_nodeids, AttrType.INTS), - nodes_treeids=ONNXAttr(nodes_treeids, AttrType.INTS), - nodes_truenodeids=ONNXAttr(nodes_truenodeids, AttrType.INTS), - nodes_values=ONNXAttr(nodes_values, AttrType.FLOATS), - post_transform=ONNXAttr(post_transform, AttrType.STRING)) + """ + Tree Ensemble classifier. Returns the top class for each of N inputs.
+ The attributes named 'nodes_X' form a sequence of tuples, associated by + index into the sequences, which must all be of equal length. These tuples + define the nodes.
+ Similarly, all fields prefixed with 'class_' are tuples of votes at the leaves. + A leaf may have multiple votes, where each vote is weighted by + the associated class_weights index.
+ One and only one of classlabels_strings or classlabels_int64s + will be defined. The class_ids are indices into this list. + """ + + def __init__( + self, + X, + base_values=None, + class_ids=None, + class_nodeids=None, + class_treeids=None, + class_weights=None, + classlabels_int64s=None, + classlabels_strings=None, + nodes_falsenodeids=None, + nodes_featureids=None, + nodes_hitrates=None, + nodes_missing_value_tracks_true=None, + nodes_modes=None, + nodes_nodeids=None, + nodes_treeids=None, + nodes_truenodeids=None, + nodes_values=None, + post_transform=None, + ): + super().__init__( + "TreeEnsembleClassifier", + 2, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + base_values=ONNXAttr(base_values, AttrType.FLOATS), + class_ids=ONNXAttr(class_ids, AttrType.INTS), + class_nodeids=ONNXAttr(class_nodeids, AttrType.INTS), + class_treeids=ONNXAttr(class_treeids, AttrType.INTS), + class_weights=ONNXAttr(class_weights, AttrType.FLOATS), + classlabels_int64s=ONNXAttr(classlabels_int64s, AttrType.INTS), + classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), + nodes_falsenodeids=ONNXAttr(nodes_falsenodeids, AttrType.INTS), + nodes_featureids=ONNXAttr(nodes_featureids, AttrType.INTS), + nodes_hitrates=ONNXAttr(nodes_hitrates, AttrType.FLOATS), + nodes_missing_value_tracks_true=ONNXAttr(nodes_missing_value_tracks_true, AttrType.INTS), + nodes_modes=ONNXAttr(nodes_modes, AttrType.STRINGS), + nodes_nodeids=ONNXAttr(nodes_nodeids, AttrType.INTS), + nodes_treeids=ONNXAttr(nodes_treeids, AttrType.INTS), + nodes_truenodeids=ONNXAttr(nodes_truenodeids, AttrType.INTS), + nodes_values=ONNXAttr(nodes_values, AttrType.FLOATS), + post_transform=ONNXAttr(post_transform, AttrType.STRING), + ) + class TreeEnsembleRegressor(ONNXOp): - """ - Tree Ensemble regressor. Returns the regressed values for each input in N.
- All args with nodes_ are fields of a tuple of tree nodes, and - it is assumed they are the same length, and an index i will decode the - tuple across these inputs. Each node id can appear only once - for each tree id.
- All fields prefixed with target_ are tuples of votes at the leaves.
- A leaf may have multiple votes, where each vote is weighted by - the associated target_weights index.
- All trees must have their node ids start at 0 and increment by 1.
- Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF - """ - - def __init__(self, X, - aggregate_function=None, - base_values=None, - n_targets=None, - nodes_falsenodeids=None, - nodes_featureids=None, - nodes_hitrates=None, - nodes_missing_value_tracks_true=None, - nodes_modes=None, - nodes_nodeids=None, - nodes_treeids=None, - nodes_truenodeids=None, - nodes_values=None, - post_transform=None, - target_ids=None, - target_nodeids=None, - target_treeids=None, - target_weights=None): - super().__init__('TreeEnsembleRegressor', 1, - [{'at::kDouble', 'at::kLong', 'at::kInt', 'at::kFloat'}], - X, - aggregate_function=ONNXAttr(aggregate_function, AttrType.STRING), - base_values=ONNXAttr(base_values, AttrType.FLOATS), - n_targets=ONNXAttr(n_targets, AttrType.INT), - nodes_falsenodeids=ONNXAttr(nodes_falsenodeids, AttrType.INTS), - nodes_featureids=ONNXAttr(nodes_featureids, AttrType.INTS), - nodes_hitrates=ONNXAttr(nodes_hitrates, AttrType.FLOATS), - nodes_missing_value_tracks_true=ONNXAttr(nodes_missing_value_tracks_true, AttrType.INTS), - nodes_modes=ONNXAttr(nodes_modes, AttrType.STRINGS), - nodes_nodeids=ONNXAttr(nodes_nodeids, AttrType.INTS), - nodes_treeids=ONNXAttr(nodes_treeids, AttrType.INTS), - nodes_truenodeids=ONNXAttr(nodes_truenodeids, AttrType.INTS), - nodes_values=ONNXAttr(nodes_values, AttrType.FLOATS), - post_transform=ONNXAttr(post_transform, AttrType.STRING), - target_ids=ONNXAttr(target_ids, AttrType.INTS), - target_nodeids=ONNXAttr(target_nodeids, AttrType.INTS), - target_treeids=ONNXAttr(target_treeids, AttrType.INTS), - target_weights=ONNXAttr(target_weights, AttrType.FLOATS)) + """ + Tree Ensemble regressor. Returns the regressed values for each input in N.
+ All args with nodes_ are fields of a tuple of tree nodes, and + it is assumed they are the same length, and an index i will decode the + tuple across these inputs. Each node id can appear only once + for each tree id.
+ All fields prefixed with target_ are tuples of votes at the leaves.
+ A leaf may have multiple votes, where each vote is weighted by + the associated target_weights index.
+ All trees must have their node ids start at 0 and increment by 1.
+ Mode enum is BRANCH_LEQ, BRANCH_LT, BRANCH_GTE, BRANCH_GT, BRANCH_EQ, BRANCH_NEQ, LEAF + """ + + def __init__( + self, + X, + aggregate_function=None, + base_values=None, + n_targets=None, + nodes_falsenodeids=None, + nodes_featureids=None, + nodes_hitrates=None, + nodes_missing_value_tracks_true=None, + nodes_modes=None, + nodes_nodeids=None, + nodes_treeids=None, + nodes_truenodeids=None, + nodes_values=None, + post_transform=None, + target_ids=None, + target_nodeids=None, + target_treeids=None, + target_weights=None, + ): + super().__init__( + "TreeEnsembleRegressor", + 1, + [{"at::kDouble", "at::kLong", "at::kInt", "at::kFloat"}], + X, + aggregate_function=ONNXAttr(aggregate_function, AttrType.STRING), + base_values=ONNXAttr(base_values, AttrType.FLOATS), + n_targets=ONNXAttr(n_targets, AttrType.INT), + nodes_falsenodeids=ONNXAttr(nodes_falsenodeids, AttrType.INTS), + nodes_featureids=ONNXAttr(nodes_featureids, AttrType.INTS), + nodes_hitrates=ONNXAttr(nodes_hitrates, AttrType.FLOATS), + nodes_missing_value_tracks_true=ONNXAttr(nodes_missing_value_tracks_true, AttrType.INTS), + nodes_modes=ONNXAttr(nodes_modes, AttrType.STRINGS), + nodes_nodeids=ONNXAttr(nodes_nodeids, AttrType.INTS), + nodes_treeids=ONNXAttr(nodes_treeids, AttrType.INTS), + nodes_truenodeids=ONNXAttr(nodes_truenodeids, AttrType.INTS), + nodes_values=ONNXAttr(nodes_values, AttrType.FLOATS), + post_transform=ONNXAttr(post_transform, AttrType.STRING), + target_ids=ONNXAttr(target_ids, AttrType.INTS), + target_nodeids=ONNXAttr(target_nodeids, AttrType.INTS), + target_treeids=ONNXAttr(target_treeids, AttrType.INTS), + target_weights=ONNXAttr(target_weights, AttrType.FLOATS), + ) + class Trilu(ONNXOp): - """ - Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s). - The attribute "upper" determines whether the upper or lower part is retained. If set to true, - the upper triangular matrix is retained. Lower triangular matrix is retained otherwise. - Default value for the "upper" attribute is true. - Trilu takes one input tensor of shape [*, N, M], where * is zero or more batch dimensions. The upper triangular part consists - of the elements on and above the given diagonal (k). The lower triangular part consists of elements on and below the diagonal. - All other elements in the matrix are set to zero. - If k = 0, the triangular part on and above/below the main diagonal is retained. - If upper is set to true, a positive k retains the upper triangular matrix excluding the main diagonal and (k-1) diagonals above it. - A negative k value retains the main diagonal and |k| diagonals below it. - If upper is set to false, a positive k retains the lower triangular matrix including the main diagonal and k diagonals above it. - A negative k value excludes the main diagonal and (|k|-1) diagonals below it. - """ - - def __init__(self, input, k, - upper=None): - super().__init__('Trilu', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - input,k, - upper=ONNXAttr(upper, AttrType.INT)) + """ + Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s). + The attribute "upper" determines whether the upper or lower part is retained. If set to true, + the upper triangular matrix is retained. Lower triangular matrix is retained otherwise. + Default value for the "upper" attribute is true. + Trilu takes one input tensor of shape [*, N, M], where * is zero or more batch dimensions. The upper triangular part consists + of the elements on and above the given diagonal (k). The lower triangular part consists of elements on and below the diagonal. + All other elements in the matrix are set to zero. + If k = 0, the triangular part on and above/below the main diagonal is retained. + If upper is set to true, a positive k retains the upper triangular matrix excluding the main diagonal and (k-1) diagonals above it. + A negative k value retains the main diagonal and |k| diagonals below it. + If upper is set to false, a positive k retains the lower triangular matrix including the main diagonal and k diagonals above it. + A negative k value excludes the main diagonal and (|k|-1) diagonals below it. + """ + + def __init__(self, input, k, upper=None): + super().__init__( + "Trilu", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + input, + k, + upper=ONNXAttr(upper, AttrType.INT), + ) + class Unique(ONNXOp): - """ - Find the unique elements of a tensor. When an optional attribute 'axis' is provided, unique subtensors sliced along the 'axis' are returned. - Otherwise the input tensor is flattened and unique values of the flattened tensor are returned. - - This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. - The first output tensor 'Y' contains all unique values or subtensors of the input. - The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. - The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". - The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. - - Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. - - https://docs.scipy.org/doc/numpy/reference/generated/numpy.unique.html - - Example 1: - input_X = [2, 1, 1, 3, 4, 3] - attribute_sorted = 0 - attribute_axis = None - output_Y = [2, 1, 3, 4] - output_indices = [0, 1, 3, 4] - output_inverse_indices = [0, 1, 1, 2, 3, 2] - output_counts = [1, 2, 2, 1] - - Example 2: - input_X = [[1, 3], [2, 3]] - attribute_sorted = 1 - attribute_axis = None - output_Y = [1, 2, 3] - output_indices = [0, 2, 1] - output_inverse_indices = [0, 2, 1, 2] - output_counts = [1, 1, 2] - - Example 3: - input_X = [[1, 0, 0], [1, 0, 0], [2, 3, 4]] - attribute_sorted = 1 - attribute_axis = 0 - output_Y = [[1, 0, 0], [2, 3, 4]] - output_indices = [0, 2] - output_inverse_indices = [0, 0, 1] - output_counts = [2, 1] - - Example 4: - input_x = [[[1., 1.], [0., 1.], [2., 1.], [0., 1.]], - [[1., 1.], [0., 1.], [2., 1.], [0., 1.]]] - attribute_sorted = 1 - attribute_axis = 1 - - intermediate data are presented below for better understanding: - - there are 4 subtensors sliced along axis 1 of input_x (shape = (2, 4, 2)): - A: [[1, 1], [1, 1]], - [[0, 1], [0, 1]], - [[2, 1], [2, 1]], - [[0, 1], [0, 1]]. - - there are 3 unique subtensors: - [[1, 1], [1, 1]], - [[0, 1], [0, 1]], - [[2, 1], [2, 1]]. - - sorted unique subtensors: - B: [[0, 1], [0, 1]], - [[1, 1], [1, 1]], - [[2, 1], [2, 1]]. - - output_Y is constructed from B: - [[[0. 1.], [1. 1.], [2. 1.]], - [[0. 1.], [1. 1.], [2. 1.]]] - - output_indices is to map from B to A: - [1, 0, 2] - - output_inverse_indices is to map from A to B: - [1, 0, 2, 0] - - output_counts = [2 1 1] - """ - - def __init__(self, X, - axis=None, - sorted=None): - super().__init__('Unique', 4, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - X, - axis=ONNXAttr(axis, AttrType.INT), - sorted=ONNXAttr(sorted, AttrType.INT)) + """ + Find the unique elements of a tensor. When an optional attribute 'axis' is provided, unique subtensors sliced along the 'axis' are returned. + Otherwise the input tensor is flattened and unique values of the flattened tensor are returned. + + This operator returns the unique values or sliced unique subtensors of the input tensor and three optional outputs. + The first output tensor 'Y' contains all unique values or subtensors of the input. + The second optional output tensor 'indices' contains indices of 'Y' elements' first occurance in 'X'.. + The third optional output tensor 'inverse_indices' contains, for elements of 'X', its corresponding indices in 'Y'. ". + The fourth optional output tensor 'counts' contains the count of each element of 'Y' in the input. + + Outputs are either sorted in ascending order or optionally in the order of the first occurrence of the values in the input. + + https://docs.scipy.org/doc/numpy/reference/generated/numpy.unique.html + + Example 1: + input_X = [2, 1, 1, 3, 4, 3] + attribute_sorted = 0 + attribute_axis = None + output_Y = [2, 1, 3, 4] + output_indices = [0, 1, 3, 4] + output_inverse_indices = [0, 1, 1, 2, 3, 2] + output_counts = [1, 2, 2, 1] + + Example 2: + input_X = [[1, 3], [2, 3]] + attribute_sorted = 1 + attribute_axis = None + output_Y = [1, 2, 3] + output_indices = [0, 2, 1] + output_inverse_indices = [0, 2, 1, 2] + output_counts = [1, 1, 2] + + Example 3: + input_X = [[1, 0, 0], [1, 0, 0], [2, 3, 4]] + attribute_sorted = 1 + attribute_axis = 0 + output_Y = [[1, 0, 0], [2, 3, 4]] + output_indices = [0, 2] + output_inverse_indices = [0, 0, 1] + output_counts = [2, 1] + + Example 4: + input_x = [[[1., 1.], [0., 1.], [2., 1.], [0., 1.]], + [[1., 1.], [0., 1.], [2., 1.], [0., 1.]]] + attribute_sorted = 1 + attribute_axis = 1 + + intermediate data are presented below for better understanding: + + there are 4 subtensors sliced along axis 1 of input_x (shape = (2, 4, 2)): + A: [[1, 1], [1, 1]], + [[0, 1], [0, 1]], + [[2, 1], [2, 1]], + [[0, 1], [0, 1]]. + + there are 3 unique subtensors: + [[1, 1], [1, 1]], + [[0, 1], [0, 1]], + [[2, 1], [2, 1]]. + + sorted unique subtensors: + B: [[0, 1], [0, 1]], + [[1, 1], [1, 1]], + [[2, 1], [2, 1]]. + + output_Y is constructed from B: + [[[0. 1.], [1. 1.], [2. 1.]], + [[0. 1.], [1. 1.], [2. 1.]]] + + output_indices is to map from B to A: + [1, 0, 2] + + output_inverse_indices is to map from A to B: + [1, 0, 2, 0] + + output_counts = [2 1 1] + """ + + def __init__(self, X, axis=None, sorted=None): + super().__init__( + "Unique", + 4, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + } + ], + X, + axis=ONNXAttr(axis, AttrType.INT), + sorted=ONNXAttr(sorted, AttrType.INT), + ) + class Unsqueeze(ONNXOp): - """ - Insert single-dimensional entries to the shape of an input tensor (`data`). - Takes one required input `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`). - - For example: - Given an input tensor (`data`) of shape [3, 4, 5], then - Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1]. - - The input `axes` should not contain any duplicate entries. It is an error if it contains duplicates. - The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`. - Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1]. - The order of values in `axes` does not matter and can come in any order. - """ - - def __init__(self, data, axes): - super().__init__('Unsqueeze', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat', 'at::kBFloat16'}, {'at::kLong'}], - data,axes) + """ + Insert single-dimensional entries to the shape of an input tensor (`data`). + Takes one required input `axes` - which contains a list of dimension indices and this operator will insert a dimension of value `1` into the corresponding index of the output tensor (`expanded`). + + For example: + Given an input tensor (`data`) of shape [3, 4, 5], then + Unsqueeze(data, axes=[0, 4]) outputs a tensor (`expanded`) containing same data as `data` but with shape [1, 3, 4, 5, 1]. + + The input `axes` should not contain any duplicate entries. It is an error if it contains duplicates. + The rank of the output tensor (`output_rank`) is the rank of the input tensor (`data`) plus the number of values in `axes`. + Each value in `axes` should be within the (inclusive) range [-output_rank , output_rank - 1]. + The order of values in `axes` does not matter and can come in any order. + """ + + def __init__(self, data, axes): + super().__init__( + "Unsqueeze", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + "at::kBFloat16", + }, + {"at::kLong"}, + ], + data, + axes, + ) + class Upsample(ONNXOp): - """ - Upsample the input tensor. - Each dimension value of the output tensor is: - output_dimension = floor(input_dimension * scale). - """ - - def __init__(self, X, scales, - mode=None): - super().__init__('Upsample', 1, - [{'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kFloat'}], - X,scales, - mode=ONNXAttr(mode, AttrType.STRING)) + """ + Upsample the input tensor. + Each dimension value of the output tensor is: + output_dimension = floor(input_dimension * scale). + """ + + def __init__(self, X, scales, mode=None): + super().__init__( + "Upsample", + 1, + [ + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + {"at::kFloat"}, + ], + X, + scales, + mode=ONNXAttr(mode, AttrType.STRING), + ) + class Where(ONNXOp): - """ - Return elements, either from X or Y, depending on condition - (with Numpy-style broadcasting support). - Where behaves like numpy.where with three parameters: - https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html - """ - - def __init__(self, condition, X, Y): - super().__init__('Where', 1, - [{'at::kBool'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}, {'at::kDouble', 'at::kLong', 'at::kByte', 'at::kInt', 'at::kHalf', 'at::kShort', 'at::kBool', 'at::kFloat'}], - condition,X,Y) + """ + Return elements, either from X or Y, depending on condition + (with Numpy-style broadcasting support). + Where behaves like numpy.where with three parameters: + https://docs.scipy.org/doc/numpy/reference/generated/numpy.where.html + """ + + def __init__(self, condition, X, Y): + super().__init__( + "Where", + 1, + [ + {"at::kBool"}, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + { + "at::kDouble", + "at::kLong", + "at::kByte", + "at::kInt", + "at::kHalf", + "at::kShort", + "at::kBool", + "at::kFloat", + }, + ], + condition, + X, + Y, + ) + class Xor(ONNXOp): - """ - Returns the tensor resulted from performing the `xor` logical operation - elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). - - This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). - """ - - def __init__(self, A, B): - super().__init__('Xor', 1, - [{'at::kBool'}, {'at::kBool'}], - A,B) + """ + Returns the tensor resulted from performing the `xor` logical operation + elementwise on the input tensors `A` and `B` (with Numpy-style broadcasting support). + + This operator supports **multidirectional (i.e., Numpy-style) broadcasting**; for more details please check [the doc](Broadcasting.md). + """ + + def __init__(self, A, B): + super().__init__("Xor", 1, [{"at::kBool"}, {"at::kBool"}], A, B) + class ZipMap(ONNXOp): - """ - Creates a map from the input and the attributes.
- The values are provided by the input tensor, while the keys are specified by the attributes. - Must provide keys in either classlabels_strings or classlabels_int64s (but not both).
- The columns of the tensor correspond one-by-one to the keys specified by the attributes. There must be as many columns as keys.
- """ - - def __init__(self, X, - classlabels_int64s=None, - classlabels_strings=None): - super().__init__('ZipMap', 1, - [{'at::kFloat'}], - X, - classlabels_int64s=ONNXAttr(classlabels_int64s, AttrType.INTS), - classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS)) + """ + Creates a map from the input and the attributes.
+ The values are provided by the input tensor, while the keys are specified by the attributes. + Must provide keys in either classlabels_strings or classlabels_int64s (but not both).
+ The columns of the tensor correspond one-by-one to the keys specified by the attributes. There must be as many columns as keys.
+ """ + + def __init__(self, X, classlabels_int64s=None, classlabels_strings=None): + super().__init__( + "ZipMap", + 1, + [{"at::kFloat"}], + X, + classlabels_int64s=ONNXAttr(classlabels_int64s, AttrType.INTS), + classlabels_strings=ONNXAttr(classlabels_strings, AttrType.STRINGS), + ) + onnx_ops = { - 'adam': Adam, - 'adagrad': Adagrad, - 'momentum': Momentum, - 'gradient': Gradient, - 'zipmap': ZipMap, - 'onehotencoder': OneHotEncoder, - 'normalizer': Normalizer, - 'linearclassifier': LinearClassifier, - 'labelencoder': LabelEncoder, - 'imputer': Imputer, - 'featurevectorizer': FeatureVectorizer, - 'treeensembleregressor': TreeEnsembleRegressor, - 'dictvectorizer': DictVectorizer, - 'castmap': CastMap, - 'shape': Shape, - 'reshape': Reshape, - 'binarizer': Binarizer, - 'reciprocal': Reciprocal, - 'leakyrelu': LeakyRelu, - 'hardsigmoid': HardSigmoid, - 'treeensembleclassifier': TreeEnsembleClassifier, - 'reducemin': ReduceMin, - 'div': Div, - 'randomnormallike': RandomNormalLike, - 'randomnormal': RandomNormal, - 'greaterorequal': GreaterOrEqual, - 'pow': Pow, - 'or': Or, - 'mul': Mul, - 'min': Min, - 'floor': Floor, - 'mean': Mean, - 'lrn': LRN, - 'scaler': Scaler, - 'max': Max, - 'round': Round, - 'lppool': LpPool, - 'sigmoid': Sigmoid, - 'relu': Relu, - 'quantizelinear': QuantizeLinear, - 'logsoftmax': LogSoftmax, - 'randomuniform': RandomUniform, - 'depthtospace': DepthToSpace, - 'concat': Concat, - 'bitshift': BitShift, - 'ceil': Ceil, - 'gather': Gather, - 'log': Log, - 'reducesumsquare': ReduceSumSquare, - 'dropout': Dropout, - 'greater': Greater, - 'reducesum': ReduceSum, - 'sequenceempty': SequenceEmpty, - 'neg': Neg, - 'constant': Constant, - 'maxpool': MaxPool, - 'sub': Sub, - 'reducelogsumexp': ReduceLogSumExp, - 'xor': Xor, - 'globallppool': GlobalLpPool, - 'upsample': Upsample, - 'prelu': PRelu, - 'loop': Loop, - 'lpnormalization': LpNormalization, - 'dynamicquantizelinear': DynamicQuantizeLinear, - 'splittosequence': SplitToSequence, - 'linearregressor': LinearRegressor, - 'add': Add, - 'selu': Selu, - 'reducemax': ReduceMax, - 'and': And, - 'abs': Abs, - 'qlinearmatmul': QLinearMatMul, - 'lessorequal': LessOrEqual, - 'clip': Clip, - 'argmax': ArgMax, - 'einsum': Einsum, - 'hardmax': Hardmax, - 'conv': Conv, - 'globalmaxpool': GlobalMaxPool, - 'maxunpool': MaxUnpool, - 'argmin': ArgMin, - 'averagepool': AveragePool, - 'sqrt': Sqrt, - 'size': Size, - 'instancenormalization': InstanceNormalization, - 'gemm': Gemm, - 'reducelogsum': ReduceLogSum, - 'cos': Cos, - 'not': Not, - 'eyelike': EyeLike, - 'equal': Equal, - 'cast': Cast, - 'exp': Exp, - 'flatten': Flatten, - 'svmclassifier': SVMClassifier, - 'roialign': RoiAlign, - 'reducemean': ReduceMean, - 'scatter': Scatter, - 'split': Split, - 'identity': Identity, - 'reducel2': ReduceL2, - 'globalaveragepool': GlobalAveragePool, - 'tan': Tan, - 'reducel1': ReduceL1, - 'lstm': LSTM, - 'slice': Slice, - 'softmax': Softmax, - 'softmaxcrossentropyloss': SoftmaxCrossEntropyLoss, - 'categorymapper': CategoryMapper, - 'maxroipool': MaxRoiPool, - 'softsign': Softsign, - 'gathernd': GatherND, - 'batchnormalization': BatchNormalization, - 'spacetodepth': SpaceToDepth, - 'squeeze': Squeeze, - 'unique': Unique, - 'sum': Sum, - 'sinh': Sinh, - 'less': Less, - 'tanh': Tanh, - 'isnan': IsNaN, - 'tile': Tile, - 'multinomial': Multinomial, - 'topk': TopK, - 'reversesequence': ReverseSequence, - 'transpose': Transpose, - 'stringnormalizer': StringNormalizer, - 'acos': Acos, - 'asin': Asin, - 'gru': GRU, - 'atan': Atan, - 'sign': Sign, - 'trilu': Trilu, - 'where': Where, - 'sin': Sin, - 'shrink': Shrink, - 'matmul': MatMul, - 'expand': Expand, - 'scan': Scan, - 'compress': Compress, - 'elu': Elu, - 'unsqueeze': Unsqueeze, - 'constantofshape': ConstantOfShape, - 'onehot': OneHot, - 'sequenceat': SequenceAt, - 'cosh': Cosh, - 'asinh': Asinh, - 'rnn': RNN, - 'acosh': Acosh, - 'atanh': Atanh, - 'erf': Erf, - 'nonzero': NonZero, - 'meanvariancenormalization': MeanVarianceNormalization, - 'scatternd': ScatterND, - 'randomuniformlike': RandomUniformLike, - 'resize': Resize, - 'mod': Mod, - 'thresholdedrelu': ThresholdedRelu, - 'matmulinteger': MatMulInteger, - 'pad': Pad, - 'convinteger': ConvInteger, - 'qlinearconv': QLinearConv, - 'celu': Celu, - 'convtranspose': ConvTranspose, - 'dequantizelinear': DequantizeLinear, - 'sequencelength': SequenceLength, - 'nonmaxsuppression': NonMaxSuppression, - 'isinf': IsInf, - 'cumsum': CumSum, - 'softplus': Softplus, - 'gatherelements': GatherElements, - 'scatterelements': ScatterElements, - 'range': Range, - 'svmregressor': SVMRegressor, - 'negativeloglikelihoodloss': NegativeLogLikelihoodLoss, - 'det': Det, - 'sequenceconstruct': SequenceConstruct, - 'if': If, - 'sequenceinsert': SequenceInsert, - 'tfidfvectorizer': TfIdfVectorizer, - 'sequenceerase': SequenceErase, - 'concatfromsequence': ConcatFromSequence, - 'hardswish': HardSwish, - 'reduceprod': ReduceProd, - 'arrayfeatureextractor': ArrayFeatureExtractor, -} \ No newline at end of file + "adam": Adam, + "adagrad": Adagrad, + "momentum": Momentum, + "gradient": Gradient, + "zipmap": ZipMap, + "onehotencoder": OneHotEncoder, + "normalizer": Normalizer, + "linearclassifier": LinearClassifier, + "labelencoder": LabelEncoder, + "imputer": Imputer, + "featurevectorizer": FeatureVectorizer, + "treeensembleregressor": TreeEnsembleRegressor, + "dictvectorizer": DictVectorizer, + "castmap": CastMap, + "shape": Shape, + "reshape": Reshape, + "binarizer": Binarizer, + "reciprocal": Reciprocal, + "leakyrelu": LeakyRelu, + "hardsigmoid": HardSigmoid, + "treeensembleclassifier": TreeEnsembleClassifier, + "reducemin": ReduceMin, + "div": Div, + "randomnormallike": RandomNormalLike, + "randomnormal": RandomNormal, + "greaterorequal": GreaterOrEqual, + "pow": Pow, + "or": Or, + "mul": Mul, + "min": Min, + "floor": Floor, + "mean": Mean, + "lrn": LRN, + "scaler": Scaler, + "max": Max, + "round": Round, + "lppool": LpPool, + "sigmoid": Sigmoid, + "relu": Relu, + "quantizelinear": QuantizeLinear, + "logsoftmax": LogSoftmax, + "randomuniform": RandomUniform, + "depthtospace": DepthToSpace, + "concat": Concat, + "bitshift": BitShift, + "ceil": Ceil, + "gather": Gather, + "log": Log, + "reducesumsquare": ReduceSumSquare, + "dropout": Dropout, + "greater": Greater, + "reducesum": ReduceSum, + "sequenceempty": SequenceEmpty, + "neg": Neg, + "constant": Constant, + "maxpool": MaxPool, + "sub": Sub, + "reducelogsumexp": ReduceLogSumExp, + "xor": Xor, + "globallppool": GlobalLpPool, + "upsample": Upsample, + "prelu": PRelu, + "loop": Loop, + "lpnormalization": LpNormalization, + "dynamicquantizelinear": DynamicQuantizeLinear, + "splittosequence": SplitToSequence, + "linearregressor": LinearRegressor, + "add": Add, + "selu": Selu, + "reducemax": ReduceMax, + "and": And, + "abs": Abs, + "qlinearmatmul": QLinearMatMul, + "lessorequal": LessOrEqual, + "clip": Clip, + "argmax": ArgMax, + "einsum": Einsum, + "hardmax": Hardmax, + "conv": Conv, + "globalmaxpool": GlobalMaxPool, + "maxunpool": MaxUnpool, + "argmin": ArgMin, + "averagepool": AveragePool, + "sqrt": Sqrt, + "size": Size, + "instancenormalization": InstanceNormalization, + "gemm": Gemm, + "reducelogsum": ReduceLogSum, + "cos": Cos, + "not": Not, + "eyelike": EyeLike, + "equal": Equal, + "cast": Cast, + "exp": Exp, + "flatten": Flatten, + "svmclassifier": SVMClassifier, + "roialign": RoiAlign, + "reducemean": ReduceMean, + "scatter": Scatter, + "split": Split, + "identity": Identity, + "reducel2": ReduceL2, + "globalaveragepool": GlobalAveragePool, + "tan": Tan, + "reducel1": ReduceL1, + "lstm": LSTM, + "slice": Slice, + "softmax": Softmax, + "softmaxcrossentropyloss": SoftmaxCrossEntropyLoss, + "categorymapper": CategoryMapper, + "maxroipool": MaxRoiPool, + "softsign": Softsign, + "gathernd": GatherND, + "batchnormalization": BatchNormalization, + "spacetodepth": SpaceToDepth, + "squeeze": Squeeze, + "unique": Unique, + "sum": Sum, + "sinh": Sinh, + "less": Less, + "tanh": Tanh, + "isnan": IsNaN, + "tile": Tile, + "multinomial": Multinomial, + "topk": TopK, + "reversesequence": ReverseSequence, + "transpose": Transpose, + "stringnormalizer": StringNormalizer, + "acos": Acos, + "asin": Asin, + "gru": GRU, + "atan": Atan, + "sign": Sign, + "trilu": Trilu, + "where": Where, + "sin": Sin, + "shrink": Shrink, + "matmul": MatMul, + "expand": Expand, + "scan": Scan, + "compress": Compress, + "elu": Elu, + "unsqueeze": Unsqueeze, + "constantofshape": ConstantOfShape, + "onehot": OneHot, + "sequenceat": SequenceAt, + "cosh": Cosh, + "asinh": Asinh, + "rnn": RNN, + "acosh": Acosh, + "atanh": Atanh, + "erf": Erf, + "nonzero": NonZero, + "meanvariancenormalization": MeanVarianceNormalization, + "scatternd": ScatterND, + "randomuniformlike": RandomUniformLike, + "resize": Resize, + "mod": Mod, + "thresholdedrelu": ThresholdedRelu, + "matmulinteger": MatMulInteger, + "pad": Pad, + "convinteger": ConvInteger, + "qlinearconv": QLinearConv, + "celu": Celu, + "convtranspose": ConvTranspose, + "dequantizelinear": DequantizeLinear, + "sequencelength": SequenceLength, + "nonmaxsuppression": NonMaxSuppression, + "isinf": IsInf, + "cumsum": CumSum, + "softplus": Softplus, + "gatherelements": GatherElements, + "scatterelements": ScatterElements, + "range": Range, + "svmregressor": SVMRegressor, + "negativeloglikelihoodloss": NegativeLogLikelihoodLoss, + "det": Det, + "sequenceconstruct": SequenceConstruct, + "if": If, + "sequenceinsert": SequenceInsert, + "tfidfvectorizer": TfIdfVectorizer, + "sequenceerase": SequenceErase, + "concatfromsequence": ConcatFromSequence, + "hardswish": HardSwish, + "reduceprod": ReduceProd, + "arrayfeatureextractor": ArrayFeatureExtractor, +} diff --git a/orttraining/orttraining/eager/opgen/opgen/parser.py b/orttraining/orttraining/eager/opgen/opgen/parser.py index ba6ef9379566b..1e6a03b04a37e 100644 --- a/orttraining/orttraining/eager/opgen/opgen/parser.py +++ b/orttraining/orttraining/eager/opgen/opgen/parser.py @@ -5,363 +5,358 @@ from opgen.ast import * from typing import List, Tuple, Union, Optional + class UnexpectedTokenError(RuntimeError): - def __init__(self, expected: TokenKind, actual: Token): - self.expected = expected - self.actual = actual - super().__init__(f"unexpected token {actual}; expected {expected}") + def __init__(self, expected: TokenKind, actual: Token): + self.expected = expected + self.actual = actual + super().__init__(f"unexpected token {actual}; expected {expected}") + class ExpectedSyntaxError(RuntimeError): - def __init__(self, expected: str, actual: Token = None): - super().__init__(f"expected {expected}; actual {actual}") + def __init__(self, expected: str, actual: Token = None): + super().__init__(f"expected {expected}; actual {actual}") + class ParserBase(object): - _peek_queue: List[Token] - - def __init__(self, lexer: Union[Lexer, Reader]): - self._own_lexer = False - if isinstance(lexer, Reader): - self._own_lexer = True - lexer = Lexer(lexer) - elif not isinstance(lexer, Lexer): - raise TypeError("lexer must be a Lexer or Reader") - self._lexer = lexer - self._peek_queue = [] - - def __enter__(self): - if self._own_lexer: - self._lexer.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._own_lexer: - self._lexer.__exit__(exc_type, exc_val, exc_tb) - - def set_source_location(self, origin: SourceLocation): - self._lexer.set_source_location(origin) - - def _peek_token( - self, - kinds: Union[TokenKind, List[TokenKind]] = None, - value: str = None, - look_ahead: int = 1) -> Optional[Token]: - if look_ahead < 1: - raise IndexError("look_ahead must be at least 1") - if look_ahead >= len(self._peek_queue): - for _ in range(look_ahead - len(self._peek_queue)): - self._peek_queue = [self._lexer.lex()] + self._peek_queue - peek = self._peek_queue[-look_ahead] - if not kinds: - return peek - if not isinstance(kinds, list): - kinds = [kinds] - for kind in kinds: - if peek.kind == kind: - if value: - return peek if peek.value == value else None - return peek - return None - - def _read_token(self) -> Token: - return self._peek_queue.pop() if self._peek_queue else self._lexer.lex() - - def _expect_token(self, kind: TokenKind) -> Token: - token = self._read_token() - if token.kind != kind: - raise UnexpectedTokenError(kind, token) - return token - - def _parse_list( - self, - open_token_kind: TokenKind, - separator_token_kind: TokenKind, - close_token_kind: TokenKind, - member_parser: callable) -> SyntaxList: - syntax_list = SyntaxList() - if open_token_kind: - syntax_list.open_token = self._expect_token(open_token_kind) - while True: - if close_token_kind and self._peek_token(close_token_kind): - break - member = member_parser() - if not self._peek_token(separator_token_kind): - syntax_list.append(member, None) - break - syntax_list.append(member, self._read_token()) - if close_token_kind: - syntax_list.close_token = self._expect_token(close_token_kind) - return syntax_list - - def parse_translation_unit(self) -> TranslationUnitDecl: - decls = [] - while not self._peek_token(TokenKind.EOF): - decls.append(self.parse_function()) - return TranslationUnitDecl(decls) - - def parse_function_parameter_default_value_expression(self) -> Expression: - return self.parse_expression() - - def parse_function_parameter(self) -> ParameterDecl: - parameter_type = self.parse_type() - - if not self._peek_token(TokenKind.IDENTIFIER): - return ParameterDecl(parameter_type) - - parameter_name = self._read_token() - - if not self._peek_token(TokenKind.EQUALS): - return ParameterDecl(parameter_type, parameter_name) - - return ParameterDecl( - parameter_type, - parameter_name, - self._read_token(), - self.parse_function_parameter_default_value_expression()) - - def parse_function_parameters(self) -> SyntaxList: - return self._parse_list( - TokenKind.OPEN_PAREN, - TokenKind.COMMA, - TokenKind.CLOSE_PAREN, - self.parse_function_parameter) - - def parse_function(self) -> FunctionDecl: - raise NotImplementedError() - - def parse_expression(self) -> Expression: - raise NotImplementedError() - - def parse_type(self) -> Type: - raise NotImplementedError() + _peek_queue: List[Token] + + def __init__(self, lexer: Union[Lexer, Reader]): + self._own_lexer = False + if isinstance(lexer, Reader): + self._own_lexer = True + lexer = Lexer(lexer) + elif not isinstance(lexer, Lexer): + raise TypeError("lexer must be a Lexer or Reader") + self._lexer = lexer + self._peek_queue = [] + + def __enter__(self): + if self._own_lexer: + self._lexer.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._own_lexer: + self._lexer.__exit__(exc_type, exc_val, exc_tb) + + def set_source_location(self, origin: SourceLocation): + self._lexer.set_source_location(origin) + + def _peek_token( + self, kinds: Union[TokenKind, List[TokenKind]] = None, value: str = None, look_ahead: int = 1 + ) -> Optional[Token]: + if look_ahead < 1: + raise IndexError("look_ahead must be at least 1") + if look_ahead >= len(self._peek_queue): + for _ in range(look_ahead - len(self._peek_queue)): + self._peek_queue = [self._lexer.lex()] + self._peek_queue + peek = self._peek_queue[-look_ahead] + if not kinds: + return peek + if not isinstance(kinds, list): + kinds = [kinds] + for kind in kinds: + if peek.kind == kind: + if value: + return peek if peek.value == value else None + return peek + return None + + def _read_token(self) -> Token: + return self._peek_queue.pop() if self._peek_queue else self._lexer.lex() + + def _expect_token(self, kind: TokenKind) -> Token: + token = self._read_token() + if token.kind != kind: + raise UnexpectedTokenError(kind, token) + return token + + def _parse_list( + self, + open_token_kind: TokenKind, + separator_token_kind: TokenKind, + close_token_kind: TokenKind, + member_parser: callable, + ) -> SyntaxList: + syntax_list = SyntaxList() + if open_token_kind: + syntax_list.open_token = self._expect_token(open_token_kind) + while True: + if close_token_kind and self._peek_token(close_token_kind): + break + member = member_parser() + if not self._peek_token(separator_token_kind): + syntax_list.append(member, None) + break + syntax_list.append(member, self._read_token()) + if close_token_kind: + syntax_list.close_token = self._expect_token(close_token_kind) + return syntax_list + + def parse_translation_unit(self) -> TranslationUnitDecl: + decls = [] + while not self._peek_token(TokenKind.EOF): + decls.append(self.parse_function()) + return TranslationUnitDecl(decls) + + def parse_function_parameter_default_value_expression(self) -> Expression: + return self.parse_expression() + + def parse_function_parameter(self) -> ParameterDecl: + parameter_type = self.parse_type() + + if not self._peek_token(TokenKind.IDENTIFIER): + return ParameterDecl(parameter_type) + + parameter_name = self._read_token() + + if not self._peek_token(TokenKind.EQUALS): + return ParameterDecl(parameter_type, parameter_name) + + return ParameterDecl( + parameter_type, parameter_name, self._read_token(), self.parse_function_parameter_default_value_expression() + ) + + def parse_function_parameters(self) -> SyntaxList: + return self._parse_list( + TokenKind.OPEN_PAREN, TokenKind.COMMA, TokenKind.CLOSE_PAREN, self.parse_function_parameter + ) + + def parse_function(self) -> FunctionDecl: + raise NotImplementedError() + + def parse_expression(self) -> Expression: + raise NotImplementedError() + + def parse_type(self) -> Type: + raise NotImplementedError() + class CPPParser(ParserBase): - def parse_function(self) -> FunctionDecl: - return_type = self.parse_type() - return FunctionDecl( - self._expect_token(TokenKind.IDENTIFIER), - self.parse_function_parameters(), - return_type, - semicolon = self._expect_token(TokenKind.SEMICOLON)) - - def parse_expression(self) -> Expression: - if self._peek_token(TokenKind.IDENTIFIER) or \ - self._peek_token(TokenKind.NUMBER) or \ - self._peek_token(TokenKind.STRING): - return LiteralExpression(self._read_token()) - else: - raise UnexpectedTokenError("expression", self._peek_token()) - - def parse_type(self) -> Type: - if self._peek_token(TokenKind.IDENTIFIER, "const"): - parsed_type = ConstType( - self._read_token(), - self.parse_type()) - elif self._peek_token([TokenKind.IDENTIFIER, TokenKind.DOUBLECOLON]): - identifiers = [] - while True: - token = self._peek_token([TokenKind.IDENTIFIER, TokenKind.DOUBLECOLON]) - if not token: - break - identifiers.append(self._read_token()) - if token.has_trailing_trivia(TokenKind.WHITESPACE): - break - if self._peek_token(TokenKind.LESS_THAN): - parsed_type = TemplateType( - identifiers, - self._parse_list( - TokenKind.LESS_THAN, - TokenKind.COMMA, - TokenKind.GREATER_THAN, - self._parse_template_type_argument)) - elif identifiers[-1].value == "TensorOptions": - parsed_type = TensorOptionsType(identifiers) - else: - parsed_type = ConcreteType(identifiers) - else: - raise ExpectedSyntaxError("type", self._peek_token()) - - while True: - if self._peek_token(TokenKind.AND): - parsed_type = ReferenceType( - parsed_type, - self._read_token()) - else: - return parsed_type + def parse_function(self) -> FunctionDecl: + return_type = self.parse_type() + return FunctionDecl( + self._expect_token(TokenKind.IDENTIFIER), + self.parse_function_parameters(), + return_type, + semicolon=self._expect_token(TokenKind.SEMICOLON), + ) + + def parse_expression(self) -> Expression: + if ( + self._peek_token(TokenKind.IDENTIFIER) + or self._peek_token(TokenKind.NUMBER) + or self._peek_token(TokenKind.STRING) + ): + return LiteralExpression(self._read_token()) + else: + raise UnexpectedTokenError("expression", self._peek_token()) + + def parse_type(self) -> Type: + if self._peek_token(TokenKind.IDENTIFIER, "const"): + parsed_type = ConstType(self._read_token(), self.parse_type()) + elif self._peek_token([TokenKind.IDENTIFIER, TokenKind.DOUBLECOLON]): + identifiers = [] + while True: + token = self._peek_token([TokenKind.IDENTIFIER, TokenKind.DOUBLECOLON]) + if not token: + break + identifiers.append(self._read_token()) + if token.has_trailing_trivia(TokenKind.WHITESPACE): + break + if self._peek_token(TokenKind.LESS_THAN): + parsed_type = TemplateType( + identifiers, + self._parse_list( + TokenKind.LESS_THAN, TokenKind.COMMA, TokenKind.GREATER_THAN, self._parse_template_type_argument + ), + ) + elif identifiers[-1].value == "TensorOptions": + parsed_type = TensorOptionsType(identifiers) + else: + parsed_type = ConcreteType(identifiers) + else: + raise ExpectedSyntaxError("type", self._peek_token()) + + while True: + if self._peek_token(TokenKind.AND): + parsed_type = ReferenceType(parsed_type, self._read_token()) + else: + return parsed_type + + def _parse_template_type_argument(self) -> Type: + if self._peek_token(TokenKind.NUMBER): + return ExpressionType(self.parse_expression()) + return self.parse_type() - def _parse_template_type_argument(self) -> Type: - if self._peek_token(TokenKind.NUMBER): - return ExpressionType(self.parse_expression()) - return self.parse_type() class TorchParser(ParserBase): - def __init__(self, lexer: Union[Lexer, Reader]): - super().__init__(lexer) - self._next_anonymous_alias_id = 0 - - def parse_function(self) -> FunctionDecl: - return FunctionDecl( - self._expect_token(TokenKind.IDENTIFIER), - self.parse_function_parameters(), - arrow = self._expect_token(TokenKind.ARROW), - return_type = self.parse_type()) - - def parse_expression(self) -> Expression: - if self._peek_token(TokenKind.NUMBER) or \ - self._peek_token(TokenKind.IDENTIFIER) or \ - self._peek_token(TokenKind.STRING): - return LiteralExpression(self._read_token()) - elif self._peek_token(TokenKind.OPEN_BRACKET): - return ArrayExpression(self._parse_list( - TokenKind.OPEN_BRACKET, - TokenKind.COMMA, - TokenKind.CLOSE_BRACKET, - self.parse_expression)) - else: - raise UnexpectedTokenError("expression", self._peek_token()) - - def _create_alias_info_type(self, parsed_type: Type, alias_info: AliasInfo) -> AliasInfoType: - if isinstance(parsed_type, ModifiedType): - parsed_type.base_type = AliasInfoType(parsed_type.base_type, alias_info) - else: - parsed_type = AliasInfoType(parsed_type, alias_info) - return parsed_type - - def parse_type(self) -> Type: - parsed_type, alias_info = self._parse_type_and_alias() - if not alias_info: - return parsed_type - return self._create_alias_info_type(parsed_type, alias_info) - - def _parse_type_and_alias(self) -> Tuple[Type, AliasInfo]: - parsed_type: Type = None - alias_info: AliasInfo = None - - if self._peek_token(TokenKind.MUL): - return (KWArgsSentinelType(self._read_token()), None) - - if self._peek_token(TokenKind.OPEN_PAREN): - def parse_tuple_element(): - element_type, element_alias_info = self._parse_type_and_alias() - if element_alias_info: - element_type = self._create_alias_info_type(element_type, element_alias_info) - - return TupleMemberType( - element_type, - self._read_token() \ - if self._peek_token(TokenKind.IDENTIFIER) \ - else None) - - parsed_type = TupleType(self._parse_list( - TokenKind.OPEN_PAREN, - TokenKind.COMMA, - TokenKind.CLOSE_PAREN, - parse_tuple_element)) - elif self._peek_token(TokenKind.IDENTIFIER, "Tensor"): - parsed_type = TensorType(self._read_token()) - alias_info = self._parse_torch_alias_info() - else: - parsed_type = self._parse_torch_base_type() - alias_info = self._parse_torch_alias_info() - - while True: - if self._peek_token(TokenKind.OPEN_BRACKET): - parsed_type = ArrayType( - parsed_type, - self._read_token(), - self._read_token() \ - if self._peek_token(TokenKind.NUMBER) \ - else None, - self._expect_token(TokenKind.CLOSE_BRACKET)) - elif self._peek_token(TokenKind.QUESTION_MARK): - parsed_type = OptionalType( - parsed_type, - self._read_token()) - else: - return (parsed_type, alias_info) - - def _parse_torch_base_type(self) -> Type: - base_type_parsers = { - "int": IntType, - "float": FloatType, - "bool": BoolType, - "str": StrType, - "Scalar": ScalarType, - "ScalarType": ScalarTypeType, - "Dimname": DimnameType, - "Layout": LayoutType, - "Device": DeviceType, - "Generator": GeneratorType, - "MemoryFormat": MemoryFormatType, - "QScheme": QSchemeType, - "Storage": StorageType, - "ConstQuantizerPtr": ConstQuantizerPtrType, - "Stream": StreamType - } - identifier = self._expect_token(TokenKind.IDENTIFIER) - base_type_parser = base_type_parsers.get(identifier.value) - if not base_type_parser: - raise ExpectedSyntaxError( - "|".join(base_type_parsers.keys()), - identifier) - base_type = base_type_parser(identifier) - return base_type - - def _parse_torch_alias_info(self) -> AliasInfo: - alias_info = AliasInfo() - - def parse_set(alias_set: List[str]): - while True: - if self._peek_token(TokenKind.MUL): - alias_info.tokens.append(self._read_token()) - alias_set.append('*') - elif '*' not in alias_set: - identifier = self._expect_token(TokenKind.IDENTIFIER) - alias_info.tokens.append(identifier) - alias_set.append(identifier.value) + def __init__(self, lexer: Union[Lexer, Reader]): + super().__init__(lexer) + self._next_anonymous_alias_id = 0 + + def parse_function(self) -> FunctionDecl: + return FunctionDecl( + self._expect_token(TokenKind.IDENTIFIER), + self.parse_function_parameters(), + arrow=self._expect_token(TokenKind.ARROW), + return_type=self.parse_type(), + ) + + def parse_expression(self) -> Expression: + if ( + self._peek_token(TokenKind.NUMBER) + or self._peek_token(TokenKind.IDENTIFIER) + or self._peek_token(TokenKind.STRING) + ): + return LiteralExpression(self._read_token()) + elif self._peek_token(TokenKind.OPEN_BRACKET): + return ArrayExpression( + self._parse_list( + TokenKind.OPEN_BRACKET, TokenKind.COMMA, TokenKind.CLOSE_BRACKET, self.parse_expression + ) + ) else: - raise ExpectedSyntaxError( - "alias wildcard * or alias identifier", - self._peek_token()) + raise UnexpectedTokenError("expression", self._peek_token()) - if self._peek_token(TokenKind.OR): - alias_info.tokens.append(self._read_token()) + def _create_alias_info_type(self, parsed_type: Type, alias_info: AliasInfo) -> AliasInfoType: + if isinstance(parsed_type, ModifiedType): + parsed_type.base_type = AliasInfoType(parsed_type.base_type, alias_info) else: - return + parsed_type = AliasInfoType(parsed_type, alias_info) + return parsed_type + + def parse_type(self) -> Type: + parsed_type, alias_info = self._parse_type_and_alias() + if not alias_info: + return parsed_type + return self._create_alias_info_type(parsed_type, alias_info) + + def _parse_type_and_alias(self) -> Tuple[Type, AliasInfo]: + parsed_type: Type = None + alias_info: AliasInfo = None + + if self._peek_token(TokenKind.MUL): + return (KWArgsSentinelType(self._read_token()), None) - if self._peek_token(TokenKind.OPEN_PAREN): - alias_info.tokens.append(self._read_token()) + if self._peek_token(TokenKind.OPEN_PAREN): - parse_set(alias_info.before_set) + def parse_tuple_element(): + element_type, element_alias_info = self._parse_type_and_alias() + if element_alias_info: + element_type = self._create_alias_info_type(element_type, element_alias_info) - if self._peek_token(TokenKind.EXCLAIMATION_MARK): - alias_info.tokens.append(self._read_token()) - alias_info.is_writable = True + return TupleMemberType( + element_type, self._read_token() if self._peek_token(TokenKind.IDENTIFIER) else None + ) - if self._peek_token(TokenKind.ARROW): - alias_info.tokens.append(self._read_token()) - parse_set(alias_info.after_set) - else: - # no '->' so assume before and after are identical - alias_info.after_set = alias_info.before_set + parsed_type = TupleType( + self._parse_list(TokenKind.OPEN_PAREN, TokenKind.COMMA, TokenKind.CLOSE_PAREN, parse_tuple_element) + ) + elif self._peek_token(TokenKind.IDENTIFIER, "Tensor"): + parsed_type = TensorType(self._read_token()) + alias_info = self._parse_torch_alias_info() + else: + parsed_type = self._parse_torch_base_type() + alias_info = self._parse_torch_alias_info() + + while True: + if self._peek_token(TokenKind.OPEN_BRACKET): + parsed_type = ArrayType( + parsed_type, + self._read_token(), + self._read_token() if self._peek_token(TokenKind.NUMBER) else None, + self._expect_token(TokenKind.CLOSE_BRACKET), + ) + elif self._peek_token(TokenKind.QUESTION_MARK): + parsed_type = OptionalType(parsed_type, self._read_token()) + else: + return (parsed_type, alias_info) + + def _parse_torch_base_type(self) -> Type: + base_type_parsers = { + "int": IntType, + "float": FloatType, + "bool": BoolType, + "str": StrType, + "Scalar": ScalarType, + "ScalarType": ScalarTypeType, + "Dimname": DimnameType, + "Layout": LayoutType, + "Device": DeviceType, + "Generator": GeneratorType, + "MemoryFormat": MemoryFormatType, + "QScheme": QSchemeType, + "Storage": StorageType, + "ConstQuantizerPtr": ConstQuantizerPtrType, + "Stream": StreamType, + } + identifier = self._expect_token(TokenKind.IDENTIFIER) + base_type_parser = base_type_parsers.get(identifier.value) + if not base_type_parser: + raise ExpectedSyntaxError("|".join(base_type_parsers.keys()), identifier) + base_type = base_type_parser(identifier) + return base_type + + def _parse_torch_alias_info(self) -> AliasInfo: + alias_info = AliasInfo() + + def parse_set(alias_set: List[str]): + while True: + if self._peek_token(TokenKind.MUL): + alias_info.tokens.append(self._read_token()) + alias_set.append("*") + elif "*" not in alias_set: + identifier = self._expect_token(TokenKind.IDENTIFIER) + alias_info.tokens.append(identifier) + alias_set.append(identifier.value) + else: + raise ExpectedSyntaxError("alias wildcard * or alias identifier", self._peek_token()) + + if self._peek_token(TokenKind.OR): + alias_info.tokens.append(self._read_token()) + else: + return + + if self._peek_token(TokenKind.OPEN_PAREN): + alias_info.tokens.append(self._read_token()) + + parse_set(alias_info.before_set) + + if self._peek_token(TokenKind.EXCLAIMATION_MARK): + alias_info.tokens.append(self._read_token()) + alias_info.is_writable = True + + if self._peek_token(TokenKind.ARROW): + alias_info.tokens.append(self._read_token()) + parse_set(alias_info.after_set) + else: + # no '->' so assume before and after are identical + alias_info.after_set = alias_info.before_set + + alias_info.tokens.append(self._expect_token(TokenKind.CLOSE_PAREN)) + elif self._peek_token(TokenKind.EXCLAIMATION_MARK): + alias_info.is_writable = True + alias_info.before_set.append(str(self._next_anonymous_alias_id)) + self._next_anonymous_alias_id += 1 + else: + return None - alias_info.tokens.append(self._expect_token(TokenKind.CLOSE_PAREN)) - elif self._peek_token(TokenKind.EXCLAIMATION_MARK): - alias_info.is_writable = True - alias_info.before_set.append(str(self._next_anonymous_alias_id)) - self._next_anonymous_alias_id += 1 - else: - return None + return alias_info - return alias_info def cpp_create_from_file(path: str) -> CPPParser: - return CPPParser(FileReader(path)) + return CPPParser(FileReader(path)) + def cpp_create_from_string(buffer: str) -> CPPParser: - return CPPParser(StringReader(buffer)) + return CPPParser(StringReader(buffer)) + def torch_create_from_file(path: str) -> TorchParser: - return TorchParser(FileReader(path)) + return TorchParser(FileReader(path)) + def torch_create_from_string(buffer: str) -> TorchParser: - return TorchParser(StringReader(buffer)) \ No newline at end of file + return TorchParser(StringReader(buffer)) diff --git a/orttraining/orttraining/eager/opgen/opgen/writer.py b/orttraining/orttraining/eager/opgen/opgen/writer.py index 3e610520fb3fb..460a29a879dfc 100644 --- a/orttraining/orttraining/eager/opgen/opgen/writer.py +++ b/orttraining/orttraining/eager/opgen/opgen/writer.py @@ -3,60 +3,61 @@ from typing import TextIO, List + class SourceWriter: - _writer: TextIO - _indent_str: str - _indent_depth: int - _needs_indent: bool - _namespaces: List[str] - - def __init__(self, base_writer: TextIO, indent_str: str = ' '): - self._writer = base_writer - self._indent_str = indent_str - self._indent_depth = 0 - self._needs_indent = False - self._namespaces = [] - - def __enter__(self): - self._writer.__enter__() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._writer.__exit__(exc_type, exc_val, exc_tb) - - def write(self, str: str): - if not str or len(str) <=0: - return - - for c in str: - if self._needs_indent: + _writer: TextIO + _indent_str: str + _indent_depth: int + _needs_indent: bool + _namespaces: List[str] + + def __init__(self, base_writer: TextIO, indent_str: str = " "): + self._writer = base_writer + self._indent_str = indent_str + self._indent_depth = 0 self._needs_indent = False - if self._indent_depth > 0: - self._writer.write(self._indent_str * self._indent_depth) - - if c == '\n': - self._needs_indent = True - - self._writer.write(c) - - def writeline(self, str: str = None): - if str: - self.write(str) - self.write('\n') - - def push_indent(self): - self._indent_depth += 1 - - def pop_indent(self): - self._indent_depth -= 1 - - def push_namespace(self, namespace: str): - self._namespaces.append(namespace) - self.writeline(f"namespace {namespace} {{") - - def pop_namespace(self): - self.writeline(f"}} // namespace {self._namespaces.pop()}") - - def pop_namespaces(self): - while len(self._namespaces) > 0: - self.pop_namespace() \ No newline at end of file + self._namespaces = [] + + def __enter__(self): + self._writer.__enter__() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self._writer.__exit__(exc_type, exc_val, exc_tb) + + def write(self, str: str): + if not str or len(str) <= 0: + return + + for c in str: + if self._needs_indent: + self._needs_indent = False + if self._indent_depth > 0: + self._writer.write(self._indent_str * self._indent_depth) + + if c == "\n": + self._needs_indent = True + + self._writer.write(c) + + def writeline(self, str: str = None): + if str: + self.write(str) + self.write("\n") + + def push_indent(self): + self._indent_depth += 1 + + def pop_indent(self): + self._indent_depth -= 1 + + def push_namespace(self, namespace: str): + self._namespaces.append(namespace) + self.writeline(f"namespace {namespace} {{") + + def pop_namespace(self): + self.writeline(f"}} // namespace {self._namespaces.pop()}") + + def pop_namespaces(self): + while len(self._namespaces) > 0: + self.pop_namespace() diff --git a/orttraining/orttraining/eager/opgen/opgen_test/__init__.py b/orttraining/orttraining/eager/opgen/opgen_test/__init__.py index 6fcf0de4918d2..5b7f7a925cc05 100644 --- a/orttraining/orttraining/eager/opgen/opgen_test/__init__.py +++ b/orttraining/orttraining/eager/opgen/opgen_test/__init__.py @@ -1,2 +1,2 @@ # Copyright (c) Microsoft Corporation. All rights reserved. -# Licensed under the MIT License. \ No newline at end of file +# Licensed under the MIT License. diff --git a/orttraining/orttraining/eager/opgen/opgen_test/lexer_test.py b/orttraining/orttraining/eager/opgen/opgen_test/lexer_test.py index 230e0c4b5caec..30e78377b2445 100644 --- a/orttraining/orttraining/eager/opgen/opgen_test/lexer_test.py +++ b/orttraining/orttraining/eager/opgen/opgen_test/lexer_test.py @@ -5,83 +5,75 @@ from opgen.lexer import StringReader, Lexer, Token, TokenKind, SourceLocation + class LexerTestCase(unittest.TestCase): - def create_lexer(self, buffer: str) -> Lexer: - return Lexer(StringReader(buffer)) + def create_lexer(self, buffer: str) -> Lexer: + return Lexer(StringReader(buffer)) - def lex_single(self, buffer: str, expected_kind: TokenKind): - lexer = self.create_lexer(buffer) - token = lexer.lex() - self.assertEqual(expected_kind, token.kind) - self.assertIsNone(token.leading_trivia) - self.assertIsNone(token.trailing_trivia) - eof = lexer.lex() - self.assertEqual(TokenKind.EOF, eof.kind) - self.assertIsNone(eof.value) - self.assertIsNone(eof.leading_trivia) - self.assertIsNone(eof.trailing_trivia) - return token + def lex_single(self, buffer: str, expected_kind: TokenKind): + lexer = self.create_lexer(buffer) + token = lexer.lex() + self.assertEqual(expected_kind, token.kind) + self.assertIsNone(token.leading_trivia) + self.assertIsNone(token.trailing_trivia) + eof = lexer.lex() + self.assertEqual(TokenKind.EOF, eof.kind) + self.assertIsNone(eof.value) + self.assertIsNone(eof.leading_trivia) + self.assertIsNone(eof.trailing_trivia) + return token - def test_empty(self): - lexer = self.create_lexer("") - self.assertEqual(lexer.lex().kind, TokenKind.EOF) + def test_empty(self): + lexer = self.create_lexer("") + self.assertEqual(lexer.lex().kind, TokenKind.EOF) - def test_trivia(self): - lexer = self.create_lexer(" // hello\nid\r\n// world") - self.assertEqual( - Token( - (13, 2, 1), - TokenKind.IDENTIFIER, - "id", - leading_trivia = [ - Token((0, 1, 1), TokenKind.WHITESPACE, " "), - Token((4, 1, 5), TokenKind.SINGLE_LINE_COMMENT, "// hello"), - Token((12, 1, 13), TokenKind.WHITESPACE, "\n") - ], - trailing_trivia = [ - Token((15, 2, 3), TokenKind.WHITESPACE, "\r\n"), - Token((17, 3, 1), TokenKind.SINGLE_LINE_COMMENT, "// world") - ]), - lexer.lex()) - self.assertEqual( - Token( - (25, 3, 9), - TokenKind.EOF, - None), - lexer.lex()) + def test_trivia(self): + lexer = self.create_lexer(" // hello\nid\r\n// world") + self.assertEqual( + Token( + (13, 2, 1), + TokenKind.IDENTIFIER, + "id", + leading_trivia=[ + Token((0, 1, 1), TokenKind.WHITESPACE, " "), + Token((4, 1, 5), TokenKind.SINGLE_LINE_COMMENT, "// hello"), + Token((12, 1, 13), TokenKind.WHITESPACE, "\n"), + ], + trailing_trivia=[ + Token((15, 2, 3), TokenKind.WHITESPACE, "\r\n"), + Token((17, 3, 1), TokenKind.SINGLE_LINE_COMMENT, "// world"), + ], + ), + lexer.lex(), + ) + self.assertEqual(Token((25, 3, 9), TokenKind.EOF, None), lexer.lex()) - def test_number(self): - def assert_number(number): - self.assertEqual( - number, - self.lex_single( - number, - TokenKind.NUMBER).value) + def test_number(self): + def assert_number(number): + self.assertEqual(number, self.lex_single(number, TokenKind.NUMBER).value) - for number in [ - "0", "01", "-1", "-.5", ".5", "0.5", "0e1", "1E10", "-0.5e6", - "-0.123e-456", "1234567891011231314151617181920", "-12345.6789E-123456"]: - assert_number(number) + for number in [ + "0", + "01", + "-1", + "-.5", + ".5", + "0.5", + "0e1", + "1E10", + "-0.5e6", + "-0.123e-456", + "1234567891011231314151617181920", + "-12345.6789E-123456", + ]: + assert_number(number) - for number in [ - "1.2.3", "e1", "-e1", "123e0.5"]: - self.assertRaises( - BaseException, - lambda: assert_number(number)) + for number in ["1.2.3", "e1", "-e1", "123e0.5"]: + self.assertRaises(BaseException, lambda: assert_number(number)) - lexer = self.create_lexer("1.2.3.4e5.6") - self.assertEqual( - Token((0, 1, 1), TokenKind.NUMBER, "1.2"), - lexer.lex()) - self.assertEqual( - Token((3, 1, 4), TokenKind.NUMBER, ".3"), - lexer.lex()) - self.assertEqual( - Token((5, 1, 6), TokenKind.NUMBER, ".4e5"), - lexer.lex()) - self.assertEqual( - Token((9, 1, 10), TokenKind.NUMBER, ".6"), - lexer.lex()) - self.assertEqual( - Token((11, 1, 12), TokenKind.EOF, None), - lexer.lex()) \ No newline at end of file + lexer = self.create_lexer("1.2.3.4e5.6") + self.assertEqual(Token((0, 1, 1), TokenKind.NUMBER, "1.2"), lexer.lex()) + self.assertEqual(Token((3, 1, 4), TokenKind.NUMBER, ".3"), lexer.lex()) + self.assertEqual(Token((5, 1, 6), TokenKind.NUMBER, ".4e5"), lexer.lex()) + self.assertEqual(Token((9, 1, 10), TokenKind.NUMBER, ".6"), lexer.lex()) + self.assertEqual(Token((11, 1, 12), TokenKind.EOF, None), lexer.lex()) diff --git a/orttraining/orttraining/eager/opgen/opgen_test/parser_test.py b/orttraining/orttraining/eager/opgen/opgen_test/parser_test.py index efa62a27d1475..1a5f61941d602 100644 --- a/orttraining/orttraining/eager/opgen/opgen_test/parser_test.py +++ b/orttraining/orttraining/eager/opgen/opgen_test/parser_test.py @@ -6,63 +6,64 @@ from opgen.lexer import TokenKind from opgen.parser import create_from_string as Parser + class ParserLookaheadTests(unittest.TestCase): - def test_no_peeks(self): - parser = Parser("1 2 3 4 5") - self.assertEqual("1", parser._read_token().value) - self.assertEqual("2", parser._read_token().value) - self.assertEqual("3", parser._read_token().value) - self.assertEqual("4", parser._read_token().value) - self.assertEqual("5", parser._read_token().value) + def test_no_peeks(self): + parser = Parser("1 2 3 4 5") + self.assertEqual("1", parser._read_token().value) + self.assertEqual("2", parser._read_token().value) + self.assertEqual("3", parser._read_token().value) + self.assertEqual("4", parser._read_token().value) + self.assertEqual("5", parser._read_token().value) - def test_peek_0(self): - parser = Parser("1 2 3 4 5") - self.assertRaises(IndexError, lambda: parser._peek_token(look_ahead = 0)) + def test_peek_0(self): + parser = Parser("1 2 3 4 5") + self.assertRaises(IndexError, lambda: parser._peek_token(look_ahead=0)) - def test_backward_peek_no_reads(self): - parser = Parser("1 2 3 4 5") - self.assertEqual("1", parser._peek_token(look_ahead = 1).value) - self.assertEqual("2", parser._peek_token(look_ahead = 2).value) - self.assertEqual("3", parser._peek_token(look_ahead = 3).value) - self.assertEqual("4", parser._peek_token(look_ahead = 4).value) - self.assertEqual("5", parser._peek_token(look_ahead = 5).value) + def test_backward_peek_no_reads(self): + parser = Parser("1 2 3 4 5") + self.assertEqual("1", parser._peek_token(look_ahead=1).value) + self.assertEqual("2", parser._peek_token(look_ahead=2).value) + self.assertEqual("3", parser._peek_token(look_ahead=3).value) + self.assertEqual("4", parser._peek_token(look_ahead=4).value) + self.assertEqual("5", parser._peek_token(look_ahead=5).value) - def test_forward_peek_no_reads(self): - parser = Parser("1 2 3 4 5") - self.assertEqual("5", parser._peek_token(look_ahead = 5).value) - self.assertEqual("4", parser._peek_token(look_ahead = 4).value) - self.assertEqual("3", parser._peek_token(look_ahead = 3).value) - self.assertEqual("2", parser._peek_token(look_ahead = 2).value) - self.assertEqual("1", parser._peek_token(look_ahead = 1).value) + def test_forward_peek_no_reads(self): + parser = Parser("1 2 3 4 5") + self.assertEqual("5", parser._peek_token(look_ahead=5).value) + self.assertEqual("4", parser._peek_token(look_ahead=4).value) + self.assertEqual("3", parser._peek_token(look_ahead=3).value) + self.assertEqual("2", parser._peek_token(look_ahead=2).value) + self.assertEqual("1", parser._peek_token(look_ahead=1).value) - def test_peek_and_read(self): - parser = Parser("1 2 3 4 5 6 7 8") - self.assertEqual("1", parser._read_token().value) - self.assertEqual("2", parser._read_token().value) - self.assertEqual("4", parser._peek_token(look_ahead = 2).value) - self.assertEqual("3", parser._peek_token(look_ahead = 1).value) - self.assertEqual("6", parser._peek_token(look_ahead = 4).value) - self.assertEqual("3", parser._read_token().value) - self.assertEqual("4", parser._read_token().value) - self.assertEqual("5", parser._peek_token(look_ahead = 1).value) - self.assertEqual("8", parser._peek_token(look_ahead = 4).value) - self.assertEqual("5", parser._read_token().value) - self.assertEqual("6", parser._read_token().value) - self.assertEqual("7", parser._read_token().value) - self.assertEqual("8", parser._peek_token(look_ahead = 1).value) + def test_peek_and_read(self): + parser = Parser("1 2 3 4 5 6 7 8") + self.assertEqual("1", parser._read_token().value) + self.assertEqual("2", parser._read_token().value) + self.assertEqual("4", parser._peek_token(look_ahead=2).value) + self.assertEqual("3", parser._peek_token(look_ahead=1).value) + self.assertEqual("6", parser._peek_token(look_ahead=4).value) + self.assertEqual("3", parser._read_token().value) + self.assertEqual("4", parser._read_token().value) + self.assertEqual("5", parser._peek_token(look_ahead=1).value) + self.assertEqual("8", parser._peek_token(look_ahead=4).value) + self.assertEqual("5", parser._read_token().value) + self.assertEqual("6", parser._read_token().value) + self.assertEqual("7", parser._read_token().value) + self.assertEqual("8", parser._peek_token(look_ahead=1).value) - def test_peek_value_and_read(self): - parser = Parser("1 2 3 4 5 6 7 8") - self.assertEqual("1", parser._read_token().value) - self.assertEqual("2", parser._read_token().value) - self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "4", look_ahead = 2)) - self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "3", look_ahead = 1)) - self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "6", look_ahead = 4)) - self.assertEqual("3", parser._read_token().value) - self.assertEqual("4", parser._read_token().value) - self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "5", look_ahead = 1)) - self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "8", look_ahead = 4)) - self.assertEqual("5", parser._read_token().value) - self.assertEqual("6", parser._read_token().value) - self.assertEqual("7", parser._read_token().value) - self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "8", look_ahead = 1)) \ No newline at end of file + def test_peek_value_and_read(self): + parser = Parser("1 2 3 4 5 6 7 8") + self.assertEqual("1", parser._read_token().value) + self.assertEqual("2", parser._read_token().value) + self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "4", look_ahead=2)) + self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "3", look_ahead=1)) + self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "6", look_ahead=4)) + self.assertEqual("3", parser._read_token().value) + self.assertEqual("4", parser._read_token().value) + self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "5", look_ahead=1)) + self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "8", look_ahead=4)) + self.assertEqual("5", parser._read_token().value) + self.assertEqual("6", parser._read_token().value) + self.assertEqual("7", parser._read_token().value) + self.assertIsNotNone(parser._peek_token(TokenKind.NUMBER, "8", look_ahead=1)) diff --git a/orttraining/orttraining/eager/test/__main__.py b/orttraining/orttraining/eager/test/__main__.py index 76923bed67a4b..f188f3c1fc3c3 100644 --- a/orttraining/orttraining/eager/test/__main__.py +++ b/orttraining/orttraining/eager/test/__main__.py @@ -8,9 +8,10 @@ selfdir = os.path.dirname(os.path.realpath(__file__)) -for testpath in glob.glob(os.path.join(selfdir, '*')): - if not os.path.basename(testpath).startswith('_') and \ - not (sys.platform.startswith("win") and os.path.basename(testpath).startswith('linux_only_')): - print(f'Running tests for {testpath} ...') - subprocess.check_call([sys.executable, testpath]) - print() \ No newline at end of file +for testpath in glob.glob(os.path.join(selfdir, "*")): + if not os.path.basename(testpath).startswith("_") and not ( + sys.platform.startswith("win") and os.path.basename(testpath).startswith("linux_only_") + ): + print(f"Running tests for {testpath} ...") + subprocess.check_call([sys.executable, testpath]) + print() diff --git a/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py b/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py index d010c95134e0f..96698ca6dcd5a 100644 --- a/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py +++ b/orttraining/orttraining/eager/test/linux_only_ortmodule_eager_test.py @@ -1,4 +1,3 @@ - import torch from onnxruntime.capi import _pybind_state as torch_ort_eager from onnxruntime.training import ORTModule @@ -9,9 +8,11 @@ import os import unittest + def my_loss(x, target): return F.nll_loss(F.log_softmax(x, dim=1), target) + class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() @@ -25,24 +26,25 @@ def forward(self, x): out = self.fc2(out) return out + class NoOpNet(torch.nn.Module): def __init__(self): super(NoOpNet, self).__init__() - self.dummy_weight = torch.nn.Parameter( - torch.ones(128, 128, dtype=torch.float16)) + self.dummy_weight = torch.nn.Parameter(torch.ones(128, 128, dtype=torch.float16)) def forward(self, input): return input + class OrtModuleEagerTest(unittest.TestCase): def test_half_type(self): model = NoOpNet() - device = torch.device('ort') + device = torch.device("ort") model.to(device) model = ORTModule(model) - input = torch.ones(2,2).to(torch.float16) + input = torch.ones(2, 2).to(torch.float16) y = model(input.to(device)) - assert(y.dtype == torch.float16) + assert y.dtype == torch.float16 def test_ort_module_and_eager_mode(self): input_size = 784 @@ -51,26 +53,26 @@ def test_ort_module_and_eager_mode(self): batch_size = 128 model = NeuralNet(input_size, hidden_size, num_classes) optimizer = optim.SGD(model.parameters(), lr=0.01) - + data = torch.rand(batch_size, input_size) target = torch.randint(0, 10, (batch_size,)) # save the initial state initial_state = model.state_dict() - # run on cpu first + # run on cpu first x = model(data) loss = my_loss(x, target) loss.backward() optimizer.step() optimizer.zero_grad() - - #record the updated parameters + + # record the updated parameters cpu_updated_state = model.state_dict() - #reload initial state + # reload initial state model.load_state_dict(initial_state) - #run on ort with ORTModule and eager mode - #use device_idx 1 to test non-zero device - torch_ort_eager.set_device(1, 'CPUExecutionProvider', {'dummy':'dummy'}) - device = torch.device('ort', index=0) + # run on ort with ORTModule and eager mode + # use device_idx 1 to test non-zero device + torch_ort_eager.set_device(1, "CPUExecutionProvider", {"dummy": "dummy"}) + device = torch.device("ort", index=0) model.to(device) model = ORTModule(model) ort_optimizer = optim.SGD(model.parameters(), lr=0.01) @@ -81,10 +83,11 @@ def test_ort_module_and_eager_mode(self): ort_optimizer.zero_grad() ort_updated_state = model.state_dict() - #compare the updated state + # compare the updated state for state_tensor in cpu_updated_state: assert state_tensor in ort_updated_state assert torch.allclose(cpu_updated_state[state_tensor], ort_updated_state[state_tensor].cpu(), atol=1e-3) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/eager/test/ort_eps_test.py b/orttraining/orttraining/eager/test/ort_eps_test.py index 9a5c8ba32b914..8764132a1aa6b 100644 --- a/orttraining/orttraining/eager/test/ort_eps_test.py +++ b/orttraining/orttraining/eager/test/ort_eps_test.py @@ -7,8 +7,11 @@ import os import sys + def is_windows(): return sys.platform.startswith("win") + + from io import StringIO import sys import threading @@ -19,6 +22,7 @@ class OutputGrabber(object): """ Class used to grab standard output or another stream. """ + escape_char = "\b" def __init__(self, stream=None, threaded=False): @@ -53,7 +57,7 @@ def start(self): self.workerThread.start() # Make sure that the thread is running and os.read() has executed: time.sleep(0.01) - + def stop(self): """ Stop capturing the stream data and save the text in `capturedtext`. @@ -83,49 +87,51 @@ def readOutput(self): and save the text in `capturedtext`. """ while True: - char = os.read(self.pipe_out,1).decode(self.origstream.encoding) + char = os.read(self.pipe_out, 1).decode(self.origstream.encoding) if not char or self.escape_char in char: break self.capturedtext += char + class OrtEPTests(unittest.TestCase): - def get_test_execution_provider_path(self): - if is_windows(): - return os.path.join('.', 'test_execution_provider.dll') - else: - return os.path.join('.', 'libtest_execution_provider.so') - - def test_import_custom_eps(self): - torch_ort.set_device(0, 'CPUExecutionProvider', {}) - - torch_ort._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {}) - # capture std out - with OutputGrabber() as out: - torch_ort.set_device(1, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'}) - ort_device = torch_ort.device(1) - assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext - with OutputGrabber() as out: - torch_ort.set_device(2, 'TestExecutionProvider', {'device_id':'1', 'some_config':'val'}) - ort_device = torch_ort.device(1) - assert 'My EP provider created, with device id: 1, some_option: val' in out.capturedtext - # test the reusing EP instance - with OutputGrabber() as out: - torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'}) - ort_device = torch_ort.device(1) - assert 'My EP provider created, with device id: 0, some_option: val' not in out.capturedtext - # test clear training ep instance pool - torch_ort.clear_training_ep_instances() - with OutputGrabber() as out: - torch_ort.set_device(3, 'TestExecutionProvider', {'device_id':'0', 'some_config':'val'}) - ort_device = torch_ort.device(1) - assert 'My EP provider created, with device id: 0, some_option: val' in out.capturedtext - - def test_print(self): - x = torch.ones(1, 2) - ort_x = x.to('ort') - with OutputGrabber() as out: - print(ort_x) - assert "tensor([[1., 1.]], device='ort:0')" in out.capturedtext - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + def get_test_execution_provider_path(self): + if is_windows(): + return os.path.join(".", "test_execution_provider.dll") + else: + return os.path.join(".", "libtest_execution_provider.so") + + def test_import_custom_eps(self): + torch_ort.set_device(0, "CPUExecutionProvider", {}) + + torch_ort._register_provider_lib("TestExecutionProvider", self.get_test_execution_provider_path(), {}) + # capture std out + with OutputGrabber() as out: + torch_ort.set_device(1, "TestExecutionProvider", {"device_id": "0", "some_config": "val"}) + ort_device = torch_ort.device(1) + assert "My EP provider created, with device id: 0, some_option: val" in out.capturedtext + with OutputGrabber() as out: + torch_ort.set_device(2, "TestExecutionProvider", {"device_id": "1", "some_config": "val"}) + ort_device = torch_ort.device(1) + assert "My EP provider created, with device id: 1, some_option: val" in out.capturedtext + # test the reusing EP instance + with OutputGrabber() as out: + torch_ort.set_device(3, "TestExecutionProvider", {"device_id": "0", "some_config": "val"}) + ort_device = torch_ort.device(1) + assert "My EP provider created, with device id: 0, some_option: val" not in out.capturedtext + # test clear training ep instance pool + torch_ort.clear_training_ep_instances() + with OutputGrabber() as out: + torch_ort.set_device(3, "TestExecutionProvider", {"device_id": "0", "some_config": "val"}) + ort_device = torch_ort.device(1) + assert "My EP provider created, with device id: 0, some_option: val" in out.capturedtext + + def test_print(self): + x = torch.ones(1, 2) + ort_x = x.to("ort") + with OutputGrabber() as out: + print(ort_x) + assert "tensor([[1., 1.]], device='ort:0')" in out.capturedtext + + +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/eager/test/ort_init.py b/orttraining/orttraining/eager/test/ort_init.py index 3a40058701ee1..43602cc6a5fdb 100644 --- a/orttraining/orttraining/eager/test/ort_init.py +++ b/orttraining/orttraining/eager/test/ort_init.py @@ -10,20 +10,23 @@ import unittest import torch + class OrtInitTests(unittest.TestCase): - def test_ort_init(self): - config_match = 'ORT is enabled' + def test_ort_init(self): + config_match = "ORT is enabled" + + def ort_alloc(): + torch.zeros(5, 5, device="ort") + + self.assertNotIn(config_match, torch._C._show_config()) + with self.assertRaises(BaseException): + ort_alloc() - def ort_alloc(): - torch.zeros(5, 5, device='ort') + import onnxruntime_pybind11_state as torch_ort - self.assertNotIn(config_match, torch._C._show_config()) - with self.assertRaises(BaseException): - ort_alloc() + ort_alloc() + self.assertIn(config_match, torch._C._show_config()) - import onnxruntime_pybind11_state as torch_ort - ort_alloc() - self.assertIn(config_match, torch._C._show_config()) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/eager/test/ort_ops.py b/orttraining/orttraining/eager/test/ort_ops.py index 668f26694c199..dd57b729c32d4 100644 --- a/orttraining/orttraining/eager/test/ort_ops.py +++ b/orttraining/orttraining/eager/test/ort_ops.py @@ -6,135 +6,135 @@ import onnxruntime_pybind11_state as torch_ort import numpy as np + class OrtOpTests(unittest.TestCase): - def get_device(self): - return torch_ort.device() - - def test_add(self): - device = self.get_device() - cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) - ort_ones = cpu_ones.to(device) - cpu_twos = cpu_ones + cpu_ones - ort_twos = ort_ones + ort_ones - assert torch.allclose(cpu_twos, ort_twos.cpu()) - - def test_type_promotion_add(self): - device = self.get_device() - x = torch.ones(2, 5, dtype = torch.int64) - y = torch.ones(2, 5, dtype = torch.float32) - ort_x = x.to(device) - ort_y = y.to(device) - ort_z = ort_x + ort_y - assert ort_z.dtype == torch.float32 - assert torch.allclose(ort_z.cpu(), (x + y)) - - def test_add_alpha(self): - device = self.get_device() - cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) - ort_ones = cpu_ones.to(device) - assert torch.allclose( - torch.add(cpu_ones, cpu_ones, alpha=2.5), - torch.add(ort_ones, ort_ones, alpha=2.5).cpu()) - - def test_mul_bool(self): - device = self.get_device() - cpu_ones = torch.ones(3, 3, dtype=bool) - ort_ones = cpu_ones.to(device) - assert torch.allclose( - torch.mul(cpu_ones, cpu_ones), - torch.mul(ort_ones, ort_ones).cpu()) - - # TODO: Add BFloat16 test coverage - def test_add_(self): - device = self.get_device() - cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) - ort_ones = cpu_ones.to(device) - cpu_twos = cpu_ones - cpu_twos += cpu_ones - ort_twos = ort_ones - ort_twos += ort_ones - assert torch.allclose(cpu_twos, ort_twos.cpu()) - - def test_sin_(self): - device = self.get_device() - cpu_sin_pi_ = torch.Tensor([np.pi]) - torch.sin_(cpu_sin_pi_) - ort_sin_pi_ = torch.Tensor([np.pi]).to(device) - torch.sin_(ort_sin_pi_) - cpu_sin_pi = torch.sin(torch.Tensor([np.pi])) - ort_sin_pi = torch.sin(torch.Tensor([np.pi]).to(device)) - assert torch.allclose(cpu_sin_pi, ort_sin_pi.cpu()) - assert torch.allclose(cpu_sin_pi_, ort_sin_pi_.cpu()) - assert torch.allclose(ort_sin_pi.cpu(), ort_sin_pi_.cpu()) - - def test_sin(self): - device = self.get_device() - cpu_sin_pi = torch.sin(torch.Tensor([np.pi])) - ort_sin_pi = torch.sin(torch.Tensor([np.pi]).to(device)) - assert torch.allclose(cpu_sin_pi, ort_sin_pi.cpu()) - - def test_zero_like(self): - device = self.get_device() - ones = torch.ones((10, 10), dtype=torch.float32) - cpu_zeros = torch.zeros_like(ones) - ort_zeros = torch.zeros_like(ones.to(device)) - assert torch.allclose(cpu_zeros, ort_zeros.cpu()) - - def test_gemm(self): - device = self.get_device() - cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) - ort_ones = cpu_ones.to(device) - cpu_ans = cpu_ones * 4 - ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0) - assert torch.allclose(cpu_ans, ort_ans.cpu()) - - def test_batchnormalization_inplace(self): - device = self.get_device() - x = torch.Tensor([[[[-1, 0, 1]], [[2., 3., 4.]]]]).to(device) - s = torch.Tensor([1.0, 1.5]).to(device) - bias = torch.Tensor([0., 1.]).to(device) - mean = torch.Tensor([0., 3.]).to(device) - var = torch.Tensor([1., 1.5]).to(device) - y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9) - assert torch.allclose(x.cpu(), y.cpu()), "x != y" - assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out" - assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out" - - def test_max(self): - cpu_tensor = torch.rand(10, 10) - ort_tensor = cpu_tensor.to('ort') - y = ort_tensor.max() - x = cpu_tensor.max() - assert torch.allclose(x, y.cpu()) - - def test_min(self): - cpu_tensor = torch.rand(10, 10) - ort_tensor = cpu_tensor.to('ort') - y = ort_tensor.min() - x = cpu_tensor.min() - assert torch.allclose(x, y.cpu()) - - def test_torch_ones(self): - device = self.get_device() - cpu_ones = torch.ones((10,10)) - ort_ones = cpu_ones.to(device) - ort_ones_device = torch.ones((10, 10), device = device) - assert torch.allclose(cpu_ones, ort_ones.cpu()) - assert torch.allclose(cpu_ones, ort_ones_device.cpu()) - - def test_narrow(self): - cpu_tensor = torch.rand(10, 10) - cpu_narrow = cpu_tensor.narrow(0, 5, 5) - ort_narrow = cpu_narrow.to('ort') - assert torch.allclose(cpu_narrow, ort_narrow.cpu()) - - def test_zero_stride(self): - print('ssssss') - device = self.get_device() - t = torch.empty_strided(size=(6, 1024, 512), stride=(0, 0, 0)) - assert(t.storage().size() == 1) # This test is trying to confirm that transferring a tensor with a storage size of 1 works - ort_t = t.to(device) - assert torch.allclose(t, ort_t.cpu()) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + def get_device(self): + return torch_ort.device() + + def test_add(self): + device = self.get_device() + cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + ort_ones = cpu_ones.to(device) + cpu_twos = cpu_ones + cpu_ones + ort_twos = ort_ones + ort_ones + assert torch.allclose(cpu_twos, ort_twos.cpu()) + + def test_type_promotion_add(self): + device = self.get_device() + x = torch.ones(2, 5, dtype=torch.int64) + y = torch.ones(2, 5, dtype=torch.float32) + ort_x = x.to(device) + ort_y = y.to(device) + ort_z = ort_x + ort_y + assert ort_z.dtype == torch.float32 + assert torch.allclose(ort_z.cpu(), (x + y)) + + def test_add_alpha(self): + device = self.get_device() + cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + ort_ones = cpu_ones.to(device) + assert torch.allclose(torch.add(cpu_ones, cpu_ones, alpha=2.5), torch.add(ort_ones, ort_ones, alpha=2.5).cpu()) + + def test_mul_bool(self): + device = self.get_device() + cpu_ones = torch.ones(3, 3, dtype=bool) + ort_ones = cpu_ones.to(device) + assert torch.allclose(torch.mul(cpu_ones, cpu_ones), torch.mul(ort_ones, ort_ones).cpu()) + + # TODO: Add BFloat16 test coverage + def test_add_(self): + device = self.get_device() + cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + ort_ones = cpu_ones.to(device) + cpu_twos = cpu_ones + cpu_twos += cpu_ones + ort_twos = ort_ones + ort_twos += ort_ones + assert torch.allclose(cpu_twos, ort_twos.cpu()) + + def test_sin_(self): + device = self.get_device() + cpu_sin_pi_ = torch.Tensor([np.pi]) + torch.sin_(cpu_sin_pi_) + ort_sin_pi_ = torch.Tensor([np.pi]).to(device) + torch.sin_(ort_sin_pi_) + cpu_sin_pi = torch.sin(torch.Tensor([np.pi])) + ort_sin_pi = torch.sin(torch.Tensor([np.pi]).to(device)) + assert torch.allclose(cpu_sin_pi, ort_sin_pi.cpu()) + assert torch.allclose(cpu_sin_pi_, ort_sin_pi_.cpu()) + assert torch.allclose(ort_sin_pi.cpu(), ort_sin_pi_.cpu()) + + def test_sin(self): + device = self.get_device() + cpu_sin_pi = torch.sin(torch.Tensor([np.pi])) + ort_sin_pi = torch.sin(torch.Tensor([np.pi]).to(device)) + assert torch.allclose(cpu_sin_pi, ort_sin_pi.cpu()) + + def test_zero_like(self): + device = self.get_device() + ones = torch.ones((10, 10), dtype=torch.float32) + cpu_zeros = torch.zeros_like(ones) + ort_zeros = torch.zeros_like(ones.to(device)) + assert torch.allclose(cpu_zeros, ort_zeros.cpu()) + + def test_gemm(self): + device = self.get_device() + cpu_ones = torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]) + ort_ones = cpu_ones.to(device) + cpu_ans = cpu_ones * 4 + ort_ans = torch_ort.custom_ops.gemm(ort_ones, ort_ones, ort_ones, 1.0, 1.0, 0, 0) + assert torch.allclose(cpu_ans, ort_ans.cpu()) + + def test_batchnormalization_inplace(self): + device = self.get_device() + x = torch.Tensor([[[[-1, 0, 1]], [[2.0, 3.0, 4.0]]]]).to(device) + s = torch.Tensor([1.0, 1.5]).to(device) + bias = torch.Tensor([0.0, 1.0]).to(device) + mean = torch.Tensor([0.0, 3.0]).to(device) + var = torch.Tensor([1.0, 1.5]).to(device) + y, mean_out, var_out = torch_ort.custom_ops.batchnorm_inplace(x, s, bias, mean, var, 1e-5, 0.9) + assert torch.allclose(x.cpu(), y.cpu()), "x != y" + assert torch.allclose(mean.cpu(), mean_out.cpu()), "mean != mean_out" + assert torch.allclose(var.cpu(), var_out.cpu()), "var != var_out" + + def test_max(self): + cpu_tensor = torch.rand(10, 10) + ort_tensor = cpu_tensor.to("ort") + y = ort_tensor.max() + x = cpu_tensor.max() + assert torch.allclose(x, y.cpu()) + + def test_min(self): + cpu_tensor = torch.rand(10, 10) + ort_tensor = cpu_tensor.to("ort") + y = ort_tensor.min() + x = cpu_tensor.min() + assert torch.allclose(x, y.cpu()) + + def test_torch_ones(self): + device = self.get_device() + cpu_ones = torch.ones((10, 10)) + ort_ones = cpu_ones.to(device) + ort_ones_device = torch.ones((10, 10), device=device) + assert torch.allclose(cpu_ones, ort_ones.cpu()) + assert torch.allclose(cpu_ones, ort_ones_device.cpu()) + + def test_narrow(self): + cpu_tensor = torch.rand(10, 10) + cpu_narrow = cpu_tensor.narrow(0, 5, 5) + ort_narrow = cpu_narrow.to("ort") + assert torch.allclose(cpu_narrow, ort_narrow.cpu()) + + def test_zero_stride(self): + print("ssssss") + device = self.get_device() + t = torch.empty_strided(size=(6, 1024, 512), stride=(0, 0, 0)) + assert ( + t.storage().size() == 1 + ) # This test is trying to confirm that transferring a tensor with a storage size of 1 works + ort_t = t.to(device) + assert torch.allclose(t, ort_t.cpu()) + + +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/eager/test/ort_tensor.py b/orttraining/orttraining/eager/test/ort_tensor.py index 1f79f624b03fe..28b0d5455b1ec 100644 --- a/orttraining/orttraining/eager/test/ort_tensor.py +++ b/orttraining/orttraining/eager/test/ort_tensor.py @@ -5,63 +5,65 @@ import torch import onnxruntime_pybind11_state as torch_ort + class OrtTensorTests(unittest.TestCase): - def test_is_ort_via_alloc(self): - cpu_ones = torch.zeros(10, 10) - assert not cpu_ones.is_ort - ort_ones = torch.zeros(10, 10, device='ort') - assert ort_ones.is_ort - assert torch.allclose(cpu_ones, ort_ones.cpu()) + def test_is_ort_via_alloc(self): + cpu_ones = torch.zeros(10, 10) + assert not cpu_ones.is_ort + ort_ones = torch.zeros(10, 10, device="ort") + assert ort_ones.is_ort + assert torch.allclose(cpu_ones, ort_ones.cpu()) + + def test_is_ort_via_to(self): + cpu_ones = torch.ones(10, 10) + assert not cpu_ones.is_ort + ort_ones = cpu_ones.to("ort") + assert ort_ones.is_ort + assert torch.allclose(cpu_ones, ort_ones.cpu()) - def test_is_ort_via_to(self): - cpu_ones = torch.ones(10, 10) - assert not cpu_ones.is_ort - ort_ones = cpu_ones.to('ort') - assert ort_ones.is_ort - assert torch.allclose(cpu_ones, ort_ones.cpu()) + def test_reshape(self): + cpu_ones = torch.ones(10, 10) + ort_ones = cpu_ones.to("ort") + y = ort_ones.reshape(-1) + assert len(y.size()) == 1 + assert y.size()[0] == 100 - def test_reshape(self): - cpu_ones = torch.ones(10, 10) - ort_ones = cpu_ones.to('ort') - y = ort_ones.reshape(-1) - assert len(y.size()) == 1 - assert y.size()[0] == 100 + def test_view(self): + cpu_ones = torch.ones(2048) + ort_ones = cpu_ones.to("ort") + y = ort_ones.view(4, 512) + assert y.size() == (4, 512) - def test_view(self): - cpu_ones = torch.ones(2048) - ort_ones = cpu_ones.to('ort') - y = ort_ones.view(4, 512) - assert y.size() == (4, 512) + def test_view_neg1(self): + cpu_ones = torch.ones(784, 256) + ort_ones = cpu_ones.to("ort") + y = ort_ones.view(-1) + assert y.size()[0] == 200704 - def test_view_neg1(self): - cpu_ones = torch.ones(784, 256) - ort_ones = cpu_ones.to('ort') - y = ort_ones.view(-1) - assert y.size()[0] == 200704 + def test_stride(self): + cpu_ones = torch.ones(3, 3) + ort_ones = cpu_ones.to("ort") + y = torch.as_strided(ort_ones, (2, 2), (1, 2)) + assert y.size() == (2, 2) + assert y.is_contiguous() == False + contiguous_y = y.contiguous() + w = torch.ones((2, 3)) + ort_w = w.to("ort") + z = torch.zeros((2, 3)) + ort_z = z.to("ort") + ort_z = torch.addmm(ort_z, contiguous_y, ort_w) + cpu_z = torch.addmm(z, torch.ones(2, 2), w) + assert torch.allclose(ort_z.cpu(), cpu_z) - def test_stride(self): - cpu_ones = torch.ones(3, 3) - ort_ones = cpu_ones.to('ort') - y = torch.as_strided(ort_ones, (2, 2), (1, 2)) - assert y.size() == (2, 2) - assert y.is_contiguous() == False - contiguous_y = y.contiguous() - w = torch.ones((2,3)) - ort_w = w.to('ort') - z = torch.zeros((2, 3)) - ort_z = z.to('ort') - ort_z = torch.addmm(ort_z, contiguous_y, ort_w) - cpu_z = torch.addmm(z, torch.ones(2, 2), w) - assert torch.allclose(ort_z.cpu(), cpu_z) + def test_slice(self): + cpu_ones = torch.ones((128, 256), dtype=torch.bfloat16) + ort_ones = cpu_ones.to("ort") + y_cpu = cpu_ones[0:128, :128] + y = ort_ones[0:128, :128] + assert y.is_contiguous() == False + assert y.size() == (128, 128) + assert torch.allclose(y.cpu(), y_cpu) - def test_slice(self): - cpu_ones = torch.ones((128, 256), dtype=torch.bfloat16) - ort_ones = cpu_ones.to('ort') - y_cpu = cpu_ones[0:128, :128] - y = ort_ones[0:128, :128] - assert y.is_contiguous() == False - assert y.size() == (128, 128) - assert torch.allclose(y.cpu(), y_cpu) -if __name__ == '__main__': - unittest.main() +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/eager/test_model/mnist_fc_training.py b/orttraining/orttraining/eager/test_model/mnist_fc_training.py index 32bb94328ffbc..b96544e4f58a1 100644 --- a/orttraining/orttraining/eager/test_model/mnist_fc_training.py +++ b/orttraining/orttraining/eager/test_model/mnist_fc_training.py @@ -16,6 +16,7 @@ import numpy as np import os + class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() @@ -29,9 +30,11 @@ def forward(self, x): out = self.fc2(out) return out + def my_loss(x, target): return F.nll_loss(F.log_softmax(x, dim=1), target) + def train_with_eager(args, model, optimizer, device, train_loader, epoch): for batch_idx, (data, target) in enumerate(train_loader): data_cpu = data.reshape(data.shape[0], -1) @@ -44,54 +47,72 @@ def train_with_eager(args, model, optimizer, device, train_loader, epoch): optimizer.step() optimizer.zero_grad() - # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data_cpu), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss)) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data_cpu), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss, + ) + ) -def main(): -#Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) - kwargs = {'num_workers': 0, 'pin_memory': True} + kwargs = {"num_workers": 0, "pin_memory": True} train_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, **kwargs) + datasets.MNIST( + "./data", + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.batch_size, + shuffle=True, + **kwargs + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True, **kwargs) + datasets.MNIST( + "./data", + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + **kwargs + ) # set device - torch_ort_eager.set_device(0, 'CPUExecutionProvider', {}) + torch_ort_eager.set_device(0, "CPUExecutionProvider", {}) - device = torch.device('ort', index=0) + device = torch.device("ort", index=0) input_size = 784 hidden_size = 500 num_classes = 10 @@ -100,11 +121,11 @@ def main(): model = ORTModule(model) optimizer = optim.SGD(model.parameters(), lr=0.01) - print('\nStart Training.') + print("\nStart Training.") for epoch in range(1, args.epochs + 1): train_with_eager(args, model, optimizer, device, train_loader, epoch) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orttraining/orttraining/eager/test_models/mnist_fc.py b/orttraining/orttraining/eager/test_models/mnist_fc.py index 1ed647c320240..0f0b3bb604149 100644 --- a/orttraining/orttraining/eager/test_models/mnist_fc.py +++ b/orttraining/orttraining/eager/test_models/mnist_fc.py @@ -7,6 +7,7 @@ import os import onnxruntime_pybind11_state as torch_ort + class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() @@ -20,6 +21,7 @@ def forward(self, x): out = self.fc2(out) return out + input_size = 784 hidden_size = 500 num_classes = 10 @@ -42,4 +44,4 @@ def forward(self, x): print("ORT inference result is:") print(ort_pred.cpu()) print("Compare result:") - print(torch.allclose(pred, ort_pred.cpu(), atol=1e-6)) \ No newline at end of file + print(torch.allclose(pred, ort_pred.cpu(), atol=1e-6)) diff --git a/orttraining/orttraining/eager/test_models/mnist_fc_training.py b/orttraining/orttraining/eager/test_models/mnist_fc_training.py index a2b0a859fd418..869d9cf8f721f 100644 --- a/orttraining/orttraining/eager/test_models/mnist_fc_training.py +++ b/orttraining/orttraining/eager/test_models/mnist_fc_training.py @@ -15,9 +15,8 @@ import numpy as np import os -dataset_root_dir = os.path.join( - os.path.dirname(os.path.realpath(__file__)), - 'data') +dataset_root_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "data") + class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -32,9 +31,11 @@ def forward(self, x): out = self.fc2(out) return out + def my_loss(x, target): return F.nll_loss(F.log_softmax(x, dim=1), target) + def train_with_eager(args, model, optimizer, device, train_loader, epoch): for batch_idx, (data, target) in enumerate(train_loader): data_cpu = data.reshape(data.shape[0], -1) @@ -47,51 +48,69 @@ def train_with_eager(args, model, optimizer, device, train_loader, epoch): optimizer.step() optimizer.zero_grad() - # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data_cpu), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss)) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data_cpu), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss, + ) + ) -def main(): -#Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) - kwargs = {'num_workers': 0, 'pin_memory': True} + kwargs = {"num_workers": 0, "pin_memory": True} train_loader = torch.utils.data.DataLoader( - datasets.MNIST(dataset_root_dir, train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, **kwargs) + datasets.MNIST( + dataset_root_dir, + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.batch_size, + shuffle=True, + **kwargs + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST(dataset_root_dir, train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True, **kwargs) - - device = torch.device('ort') + datasets.MNIST( + dataset_root_dir, + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + **kwargs + ) + + device = torch.device("ort") input_size = 784 hidden_size = 500 num_classes = 10 @@ -99,11 +118,11 @@ def main(): model.to(device) optimizer = optim.SGD(model.parameters(), lr=0.01) - print('\nStart Training.') + print("\nStart Training.") for epoch in range(1, args.epochs + 1): train_with_eager(args, model, optimizer, device, train_loader, epoch) -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/orttraining/orttraining/eager/test_models/scratchpad.py b/orttraining/orttraining/eager/test_models/scratchpad.py index f92190879db32..049aa859c842c 100644 --- a/orttraining/orttraining/eager/test_models/scratchpad.py +++ b/orttraining/orttraining/eager/test_models/scratchpad.py @@ -18,11 +18,7 @@ fours = twos * twos print(fours.cpu()) -fenced_ten = torch.tensor( - [[-1, -1, -1], - [-1, 10, -1], - [-1, -1, -1]], - device = device, dtype=torch.float) +fenced_ten = torch.tensor([[-1, -1, -1], [-1, 10, -1], [-1, -1, -1]], device=device, dtype=torch.float) print(fenced_ten.numel()) print(fenced_ten.size()) @@ -32,13 +28,13 @@ a = torch.ones(3, 3).to(device) b = torch.ones(3, 3) c = a + b -d = torch.sin (c) -e = torch.tan (c) +d = torch.sin(c) +e = torch.tan(c) torch.sin_(c) -print ("sin-in-place:") +print("sin-in-place:") print(c.cpu()) -print ("sin explicit:") -print (d.cpu ()) +print("sin explicit:") +print(d.cpu()) a = torch.tensor([[10, 10]], dtype=torch.float).to(device) b = torch.tensor([[3.3, 3.3]]).to(device) @@ -47,7 +43,7 @@ print(c.cpu()) a = torch.tensor([[5, 3, -5]], dtype=torch.float).to(device) -b = torch.hardshrink(a, 3) #should be [5, 0, -5] -c = torch.nn.functional.softshrink(a, 3) #should be [2, 0, -2] +b = torch.hardshrink(a, 3) # should be [5, 0, -5] +c = torch.nn.functional.softshrink(a, 3) # should be [2, 0, -2] print(b.cpu()) -print(c.cpu()) \ No newline at end of file +print(c.cpu()) diff --git a/orttraining/orttraining/python/checkpointing_utils.py b/orttraining/orttraining/python/checkpointing_utils.py index 95367db2aecd7..359f6a8c53552 100644 --- a/orttraining/orttraining/python/checkpointing_utils.py +++ b/orttraining/orttraining/python/checkpointing_utils.py @@ -2,21 +2,25 @@ import torch -def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension='.ort.pt'): +def list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] - assert len(ckpt_file_names) > 0, "No checkpoint files found with prefix \"{}\" in directory {}.".format(checkpoint_prefix, checkpoint_dir) + assert len(ckpt_file_names) > 0, 'No checkpoint files found with prefix "{}" in directory {}.'.format( + checkpoint_prefix, checkpoint_dir + ) return ckpt_file_names def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = '{prefix}.ort.pt' - MULTIPLE_CHECKPOINT_FILENAME = '{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt' + SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" + MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format(prefix=prefix, world_rank=world_rank, world_size=(world_size-1)) + filename = MULTIPLE_CHECKPOINT_FILENAME.format( + prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) + ) else: filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) @@ -24,16 +28,16 @@ def get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None def _split_state_dict(state_dict): - optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step'] - split_sd = {'optimizer': {}, 'fp32_param': {}, 'fp16_param': {}} + optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] + split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} for k, v in state_dict.items(): - mode = 'fp32_param' + mode = "fp32_param" for optim_key in optimizer_keys: if k.startswith(optim_key): - mode = 'optimizer' + mode = "optimizer" break - if k.endswith('_fp16'): - mode = 'fp16_param' + if k.endswith("_fp16"): + mode = "fp16_param" split_sd[mode][k] = v return split_sd @@ -44,30 +48,30 @@ def __init__(self, checkpoint_files, clean_state_dict=None): assert len(checkpoint_files) > 0, "No checkpoint files passed" self.checkpoint_files = checkpoint_files self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split('ZeRO')[1].split('.')[2]) + 1 + self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" self.weight_shape_map = dict() self.sharded_params = set() def _split_name(self, name): - name_split = name.split('_view_') + name_split = name.split("_view_") view_num = None - if(len(name_split) > 1): + if len(name_split) > 1: view_num = int(name_split[1]) - optimizer_key = '' - mp_suffix = '' - if name_split[0].startswith('Moment_1'): - optimizer_key = 'Moment_1_' - elif name_split[0].startswith('Moment_2'): - optimizer_key = 'Moment_2_' - elif name_split[0].startswith('Update_Count'): - optimizer_key = 'Update_Count_' - elif name_split[0].endswith('_fp16'): - mp_suffix = '_fp16' + optimizer_key = "" + mp_suffix = "" + if name_split[0].startswith("Moment_1"): + optimizer_key = "Moment_1_" + elif name_split[0].startswith("Moment_2"): + optimizer_key = "Moment_2_" + elif name_split[0].startswith("Update_Count"): + optimizer_key = "Update_Count_" + elif name_split[0].endswith("_fp16"): + mp_suffix = "_fp16" param_name = name_split[0] - if optimizer_key != '': + if optimizer_key != "": param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split('_fp16')[0] + param_name = param_name.split("_fp16")[0] return param_name, optimizer_key, view_num, mp_suffix def _update_weight_statistics(self, name, value): @@ -87,7 +91,7 @@ def _aggregate(self, param_dict): # parameter is sharded param_name = optimizer_key + weight_name + mp_suffix - if param_name in self.aggregate_state_dict and optimizer_key not in ['Update_Count_']: + if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: self.sharded_params.add(param_name) # Found a previous shard of the param, concatenate shards ordered by ranks self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) @@ -101,22 +105,22 @@ def _aggregate(self, param_dict): self._update_weight_statistics(weight_name, v) def aggregate_checkpoints(self): - checkpoint_prefix = self.checkpoint_files[0].split('.ZeRO')[0] + checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] self.aggregate_state_dict = dict() for i in range(self.world_size): checkpoint_name = get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if 'model' in rank_state_dict: - rank_state_dict = rank_state_dict['model'] + if "model" in rank_state_dict: + rank_state_dict = rank_state_dict["model"] if self.clean_state_dict: rank_state_dict = self.clean_state_dict(rank_state_dict) rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict['fp16_param']) - self._aggregate(rank_state_dict['fp32_param']) - self._aggregate(rank_state_dict['optimizer']) + self._aggregate(rank_state_dict["fp16_param"]) + self._aggregate(rank_state_dict["fp32_param"]) + self._aggregate(rank_state_dict["optimizer"]) for k in self.sharded_params: self._reshape_tensor(k) diff --git a/orttraining/orttraining/python/deprecated/training_session.py b/orttraining/orttraining/python/deprecated/training_session.py index 51eda9b283b64..b6a63dbee35d2 100644 --- a/orttraining/orttraining/python/deprecated/training_session.py +++ b/orttraining/orttraining/python/deprecated/training_session.py @@ -1,14 +1,18 @@ -#------------------------------------------------------------------------- +# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -#-------------------------------------------------------------------------- +# -------------------------------------------------------------------------- import sys import os from onnxruntime.capi import _pybind_state as C -from onnxruntime.capi.onnxruntime_inference_collection import (Session, InferenceSession, IOBinding, - check_and_normalize_provider_args) +from onnxruntime.capi.onnxruntime_inference_collection import ( + Session, + InferenceSession, + IOBinding, + check_and_normalize_provider_args, +) class TrainingSession(InferenceSession): @@ -25,8 +29,9 @@ def __init__(self, path_or_bytes, parameters, sess_options=None, providers=None, if providers is None: providers = C.get_available_providers() - providers, provider_options = check_and_normalize_provider_args(providers, provider_options, - C.get_available_providers()) + providers, provider_options = check_and_normalize_provider_args( + providers, provider_options, C.get_available_providers() + ) if isinstance(path_or_bytes, str): config_result = self._sess.load_model(path_or_bytes, parameters, providers, provider_options) @@ -49,7 +54,7 @@ def get_state(self): def get_model_state(self, include_mixed_precision_weights=False): return self._sess.get_model_state(include_mixed_precision_weights) - + def get_optimizer_state(self): return self._sess.get_optimizer_state() diff --git a/orttraining/orttraining/python/ort_trainer.py b/orttraining/orttraining/python/ort_trainer.py index 6090099d5d739..80c04fc0ed751 100644 --- a/orttraining/orttraining/python/ort_trainer.py +++ b/orttraining/orttraining/python/ort_trainer.py @@ -20,20 +20,24 @@ DEFAULT_OPSET_VERSION = 14 -class IODescription(): + +class IODescription: def __init__(self, name, shape, dtype=None, num_classes=None): self.name_ = name self.shape_ = shape self.dtype_ = dtype self.num_classes_ = num_classes -class ModelDescription(): + +class ModelDescription: def __init__(self, inputs, outputs): self.inputs_ = inputs self.outputs_ = outputs + def resolve_symbolic_dimensions(inputs, input_descs, output_descs): import copy + output_descs_copy = copy.deepcopy(output_descs) resolved_dims = {} for input, input_desc in zip(inputs, input_descs): @@ -60,12 +64,14 @@ def generate_sample(desc, device=None): else: return torch.randn(size, dtype=desc.dtype_).to(device) + def get_device_index(device): if type(device) == str: # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 device = torch.device(device) return 0 if device.index is None else device.index + def input_get_device_index(input): if isinstance(input, (list, tuple)): device_index = get_device_index(input[0].device) @@ -74,39 +80,61 @@ def input_get_device_index(input): return device_index + def get_all_gradients_finite_arg_name(session): - all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if 'all_gradients_finite' in x.name] + all_fp16_or_fp32_gradients_finite_node_args = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] if len(all_fp16_or_fp32_gradients_finite_node_args) < 1: - raise RuntimeError("Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ - from the training session.") + raise RuntimeError( + "Failed to find a group NodeArg with name that matches 'all_gradients_finite'\ + from the training session." + ) return all_fp16_or_fp32_gradients_finite_node_args[0].name + def get_group_accumulated_gradients_output_node_arg_name(session): # TODO: get the constant string via pybind. # optimizer_graph_builder BuildGroupNode with fixed string: 'Group_Accumulated_Gradients' - accumulated_gradients_output_node_args = [x for x in session._outputs_meta if 'Group_Accumulated_Gradients' in x.name] + accumulated_gradients_output_node_args = [ + x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name + ] if len(accumulated_gradients_output_node_args) != 1: - raise RuntimeError("Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ - from the training session.") + raise RuntimeError( + "Failed to find a group NodeArg with name that matches 'Group_Accumulated_Gradients'\ + from the training session." + ) return accumulated_gradients_output_node_args[0].name + def ort_training_session_run_helper(session, iobinding, inputs, input_descs, output_descs, device, run_options=None): for input, input_desc in zip(inputs, input_descs): device_index = input_get_device_index(input) - iobinding.bind_input(input_desc.name_, input.device.type, device_index, dtype_torch_to_numpy(input.dtype), - list(input.size()), input.data_ptr()) + iobinding.bind_input( + input_desc.name_, + input.device.type, + device_index, + dtype_torch_to_numpy(input.dtype), + list(input.size()), + input.data_ptr(), + ) output_descs_resolved = resolve_symbolic_dimensions(inputs, input_descs, output_descs) torch_outputs = {} for output_desc in output_descs_resolved: - torch_tensor = torch.zeros(output_desc.shape_, device=device, - dtype=output_desc.eval_dtype_ if hasattr(output_desc, 'eval_dtype_') - else output_desc.dtype_) - iobinding.bind_output(output_desc.name_, torch_tensor.device.type, get_device_index(device), - dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), torch_tensor.data_ptr()) + torch_tensor = torch.zeros( + output_desc.shape_, + device=device, + dtype=output_desc.eval_dtype_ if hasattr(output_desc, "eval_dtype_") else output_desc.dtype_, + ) + iobinding.bind_output( + output_desc.name_, + torch_tensor.device.type, + get_device_index(device), + dtype_torch_to_numpy(torch_tensor.dtype), + list(torch_tensor.size()), + torch_tensor.data_ptr(), + ) torch_outputs[output_desc.name_] = torch_tensor session.run_with_iobinding(iobinding, run_options) @@ -161,10 +189,19 @@ def FuseSofmaxNLLToSoftmaxCE(onnx_model): probability_output_name = softmax_node.output[0] node = onnx_model.graph.node.add() - inputs = [softmax_node.input[0], label_input_name, weight_input_name] if weight_input_name else [softmax_node.input[0], label_input_name] - node.CopyFrom(onnx.helper.make_node("SparseSoftmaxCrossEntropy", inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count))) + inputs = ( + [softmax_node.input[0], label_input_name, weight_input_name] + if weight_input_name + else [softmax_node.input[0], label_input_name] + ) + node.CopyFrom( + onnx.helper.make_node( + "SparseSoftmaxCrossEntropy", + inputs, + [nll_loss_node.output[0], probability_output_name], + "nll_loss_node_" + str(nll_count), + ) + ) return onnx_model @@ -201,6 +238,7 @@ def dtype_torch_to_numpy(torch_dtype): else: raise Exception("Torch type to numpy type mapping unavailable for: " + str(torch_dtype)) + class model_loss_cls(torch.nn.Module): def __init__(self, model, loss_fn): super(model_loss_cls, self).__init__() @@ -213,6 +251,7 @@ def forward(self, *inputs): preds = self.model_(*input) return self.loss_fn_(preds, label), preds + class WrapModel(torch.nn.Module): def __init__(self, model, loss_fn, input_names): super(WrapModel, self).__init__() @@ -222,6 +261,7 @@ def __init__(self, model, loss_fn, input_names): def forward(self, *inputs): import inspect + # *inputs is given by torch trace. It is in the order of input_names. # model_ takes input in a order (which can be obtained via inspect.signature(model.forward)) different than input_names. sig = inspect.signature(self.model_.forward) @@ -240,8 +280,10 @@ def forward(self, *inputs): preds = model_out return self.loss_fn_(preds, label), preds + def wrap_for_input_match(model, loss_fn, input_names): import inspect + sig = inspect.signature(model.forward) ordered_list_keys = list(sig.parameters.keys()) if loss_fn: @@ -281,6 +323,7 @@ def wrap_for_input_match(model, loss_fn, input_names): return model + def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, opset_version=DEFAULT_OPSET_VERSION): # example: {input0:{0:'batch'}, input1:{0:'batch'}} dynamic_axes = {} @@ -320,6 +363,7 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op model.eval() with torch.no_grad(): import copy + # Deepcopy inputs, since input values may change after model run. sample_inputs_copy = copy.deepcopy(sample_inputs) try: @@ -327,8 +371,10 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op model_copy = copy.deepcopy(model) except Exception: model_copy = model - warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!") + warnings.warn( + "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." + " Compute will continue, but unexpected results may occur!" + ) sample_outputs = model_copy(*sample_inputs_copy) if isinstance(sample_outputs, torch.Tensor): @@ -341,31 +387,38 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op # Other export options to use(this is for backward compatibility). other_export_options = {} - other_export_options['training'] = True + other_export_options["training"] = True # This option was added after 1.4 release. - if (LooseVersion(torch.__version__) > LooseVersion('1.4.0') and - LooseVersion(torch.__version__) < LooseVersion('1.10.0')): - other_export_options['enable_onnx_checker'] = False + if LooseVersion(torch.__version__) > LooseVersion("1.4.0") and LooseVersion(torch.__version__) < LooseVersion( + "1.10.0" + ): + other_export_options["enable_onnx_checker"] = False # This option was added after 1.6 release. - if LooseVersion(torch.__version__) >= LooseVersion('1.6.0'): - other_export_options['training'] = torch.onnx.TrainingMode.TRAINING + if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): + other_export_options["training"] = torch.onnx.TrainingMode.TRAINING # Deepcopy inputs, since input values may change after model run. import copy + sample_inputs_copy = copy.deepcopy(sample_inputs) # Enable contrib ops export from PyTorch from onnxruntime.tools import pytorch_export_contrib_ops + pytorch_export_contrib_ops.register() - torch.onnx._export(model, tuple(sample_inputs_copy), f, - input_names=input_names, - output_names=output_names, - opset_version=opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - **other_export_options) + torch.onnx._export( + model, + tuple(sample_inputs_copy), + f, + input_names=input_names, + output_names=output_names, + opset_version=opset_version, + dynamic_axes=dynamic_axes, + do_constant_folding=False, + **other_export_options, + ) onnx_model = onnx.load_model_from_string(f.getvalue()) @@ -373,8 +426,8 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op if isinstance(model, WrapModel) or isinstance(model, model_loss_cls): replace_name_dict = {} for n in onnx_model.graph.initializer: - if n.name.startswith('model_.'): - replace_name_dict[n.name] = n.name[len('model_.'):] + if n.name.startswith("model_."): + replace_name_dict[n.name] = n.name[len("model_.") :] n.name = replace_name_dict[n.name] for n in onnx_model.graph.node: for i, name in enumerate(n.input): @@ -383,17 +436,28 @@ def convert_model_loss_fn_to_onnx(model, loss_fn, model_desc, device, inputs, op return onnx_model -def create_ort_training_session_with_optimizer(model, device, training_optimizer_name, lr_params_feed_name, - map_optimizer_attributes, world_rank=-1, world_size=1, - gradient_accumulation_steps=1, bind_parameters=False, - use_mixed_precision=False, allreduce_post_accumulation=False, - deepspeed_zero_stage=0, - enable_grad_norm_clip=True, - frozen_weights=[], opset_version=DEFAULT_OPSET_VERSION, - use_deterministic_compute=False, - use_memory_efficient_gradient=False, - enable_adasum=False, - optimized_model_filepath=""): + +def create_ort_training_session_with_optimizer( + model, + device, + training_optimizer_name, + lr_params_feed_name, + map_optimizer_attributes, + world_rank=-1, + world_size=1, + gradient_accumulation_steps=1, + bind_parameters=False, + use_mixed_precision=False, + allreduce_post_accumulation=False, + deepspeed_zero_stage=0, + enable_grad_norm_clip=True, + frozen_weights=[], + opset_version=DEFAULT_OPSET_VERSION, + use_deterministic_compute=False, + use_memory_efficient_gradient=False, + enable_adasum=False, + optimized_model_filepath="", +): output_name = model.graph.output[0].name ort_parameters = ort.TrainingParameters() ort_parameters.loss_output_name = output_name @@ -446,7 +510,8 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer torch_tensor = torch.nn.Parameter(torch.as_tensor(numpy_helper.to_array(initializer), device=device)) delete_input_with_name(model.graph.input, initializer.name) model.graph.input.extend( - [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)]) + [helper.make_tensor_value_info(initializer.name, initializer.data_type, initializer.dims)] + ) torch_params[initializer.name] = torch_tensor del model.graph.initializer[:] @@ -469,24 +534,39 @@ def create_ort_training_session_with_optimizer(model, device, training_optimizer for param in torch_params.keys(): torch_tensor = torch_params[param] - train_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()), - torch_tensor.data_ptr()) - eval_io_binding.bind_input(param, torch_tensor.device.type, get_device_index(torch_tensor.device), - dtype_torch_to_numpy(torch_params[param].dtype), list(torch_tensor.size()), - torch_tensor.data_ptr()) + train_io_binding.bind_input( + param, + torch_tensor.device.type, + get_device_index(torch_tensor.device), + dtype_torch_to_numpy(torch_params[param].dtype), + list(torch_tensor.size()), + torch_tensor.data_ptr(), + ) + eval_io_binding.bind_input( + param, + torch_tensor.device.type, + get_device_index(torch_tensor.device), + dtype_torch_to_numpy(torch_params[param].dtype), + list(torch_tensor.size()), + torch_tensor.data_ptr(), + ) return session, train_io_binding, eval_io_binding, output_name, torch_params, output_types -def save_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True): - if checkpoint_state_dict==None: - checkpoint_state_dict={'model': model.state_dict(include_optimizer_state)} + +def save_checkpoint( + model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True +): + if checkpoint_state_dict == None: + checkpoint_state_dict = {"model": model.state_dict(include_optimizer_state)} else: - checkpoint_state_dict.update({'model': model.state_dict(include_optimizer_state)}) + checkpoint_state_dict.update({"model": model.state_dict(include_optimizer_state)}) assert os.path.exists(checkpoint_dir), "ERROR: Checkpoint directory doesn't exist: {}".format(checkpoint_dir) - checkpoint_name = get_checkpoint_name(checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size) + checkpoint_name = get_checkpoint_name( + checkpoint_prefix, model.deepspeed_zero_stage_, model.world_rank, model.world_size + ) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if os.path.exists(checkpoint_file): @@ -494,25 +574,29 @@ def save_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", c torch.save(checkpoint_state_dict, checkpoint_file) + def _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): checkpoint_name = get_checkpoint_name(checkpoint_prefix, is_partitioned, model.world_rank, model.world_size) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if is_partitioned: - assert_msg = ("Couldn't find checkpoint file {}." + - "Optimizer partitioning is enabled using ZeRO. Please make sure that the "+ - "checkpoint file exists for rank {} of {}.").format(checkpoint_file,model.world_rank, model.world_size) + assert_msg = ( + "Couldn't find checkpoint file {}." + + "Optimizer partitioning is enabled using ZeRO. Please make sure that the " + + "checkpoint file exists for rank {} of {}." + ).format(checkpoint_file, model.world_rank, model.world_size) else: assert_msg = "Couldn't find checkpoint file {}.".format(checkpoint_file) assert os.path.exists(checkpoint_file), assert_msg - checkpoint_state = torch.load(checkpoint_file, map_location='cpu') + checkpoint_state = torch.load(checkpoint_file, map_location="cpu") - model.load_state_dict(checkpoint_state['model'], strict=strict) - del(checkpoint_state['model']) + model.load_state_dict(checkpoint_state["model"], strict=strict) + del checkpoint_state["model"] return checkpoint_state + def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) @@ -523,35 +607,59 @@ def _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict): # aggregate other keys in the state_dict. # Values will be overwritten for matching keys among workers - all_checkpoint_states=dict() + all_checkpoint_states = dict() for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location='cpu') - del(checkpoint_state['model']) + checkpoint_state = torch.load(checkpoint_file, map_location="cpu") + del checkpoint_state["model"] all_checkpoint_states.update(checkpoint_state) return all_checkpoint_states + def load_checkpoint(model, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): checkpoint_files = list_checkpoint_files(checkpoint_dir, checkpoint_prefix) is_partitioned = False if len(checkpoint_files) > 1: - warnings.warn(f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." + - "Attempting to load ZeRO checkpoint.") + warnings.warn( + f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." + + "Attempting to load ZeRO checkpoint." + ) is_partitioned = True if (not model.deepspeed_zero_stage_) and is_partitioned: return _load_multi_checkpoint(model, checkpoint_dir, checkpoint_prefix, strict) else: return _load_single_checkpoint(model, checkpoint_dir, checkpoint_prefix, is_partitioned, strict) -class ORTTrainer(): - def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_optimizer_attributes, - learning_rate_description, device, gradient_accumulation_steps=1, - world_rank=0, world_size=1, use_mixed_precision=False, allreduce_post_accumulation=False, - global_step=0, get_lr_this_step=None, loss_scaler=None, deepspeed_zero_stage=0, - enable_grad_norm_clip=True, frozen_weights=[], _opset_version=DEFAULT_OPSET_VERSION, - _enable_internal_postprocess=True, _extra_postprocess=None, _use_deterministic_compute=False, - use_memory_efficient_gradient=False, run_symbolic_shape_infer=False, enable_adasum=False, - optimized_model_filepath=""): +class ORTTrainer: + def __init__( + self, + model, + loss_fn, + model_desc, + training_optimizer_name, + map_optimizer_attributes, + learning_rate_description, + device, + gradient_accumulation_steps=1, + world_rank=0, + world_size=1, + use_mixed_precision=False, + allreduce_post_accumulation=False, + global_step=0, + get_lr_this_step=None, + loss_scaler=None, + deepspeed_zero_stage=0, + enable_grad_norm_clip=True, + frozen_weights=[], + _opset_version=DEFAULT_OPSET_VERSION, + _enable_internal_postprocess=True, + _extra_postprocess=None, + _use_deterministic_compute=False, + use_memory_efficient_gradient=False, + run_symbolic_shape_infer=False, + enable_adasum=False, + optimized_model_filepath="", + ): super(ORTTrainer, self).__init__() """ Initialize ORTTrainer. @@ -624,7 +732,9 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti optimized_model_filepath: path to output the optimized training graph. Defaults to "" (no output). """ - warnings.warn('DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it') + warnings.warn( + "DISCLAIMER: This is an early version of an experimental training API and it is subject to change. DO NOT create production applications with it" + ) self.is_train = True self.torch_model_ = None @@ -690,7 +800,7 @@ def __init__(self, model, loss_fn, model_desc, training_optimizer_name, map_opti # use this special string to workaround a corner case that external loss_scale is passed into train_step as kwargs. # see prepare_input_and_fetches for more details. - self.loss_scale_input_name = 'default_loss_scale_input_name' + self.loss_scale_input_name = "default_loss_scale_input_name" self._init_session() @@ -701,36 +811,57 @@ def _init_session(self): self._verify_fully_optimized_model(self.onnx_model_) if self.run_symbolic_shape_infer: - self.onnx_model_ = SymbolicShapeInference.infer_shapes(self.onnx_model_, auto_merge=True, guess_output_rank=True) + self.onnx_model_ = SymbolicShapeInference.infer_shapes( + self.onnx_model_, auto_merge=True, guess_output_rank=True + ) # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. # for example, load_state_dict will be called before returing the function, and it calls _init_session again del self.session - self.session, self.train_io_binding, self.eval_io_binding, self.output_name, _, self.output_types = \ - create_ort_training_session_with_optimizer( - self.onnx_model_, self.device_, - self.training_optimizer_name_, self.learning_rate_description_.name_, self.map_optimizer_attributes_, - self.world_rank, self.world_size, - self.gradient_accumulation_steps, bind_parameters=False, - use_mixed_precision=self.use_mixed_precision, allreduce_post_accumulation=self.allreduce_post_accumulation_, - deepspeed_zero_stage=self.deepspeed_zero_stage_, - enable_grad_norm_clip=self.enable_grad_norm_clip_, - frozen_weights=self.frozen_weights_, opset_version=self.opset_version_, - use_deterministic_compute=self._use_deterministic_compute, - use_memory_efficient_gradient=self.use_memory_efficient_gradient, - enable_adasum=self.enable_adasum, - optimized_model_filepath=self.optimized_model_filepath) + ( + self.session, + self.train_io_binding, + self.eval_io_binding, + self.output_name, + _, + self.output_types, + ) = create_ort_training_session_with_optimizer( + self.onnx_model_, + self.device_, + self.training_optimizer_name_, + self.learning_rate_description_.name_, + self.map_optimizer_attributes_, + self.world_rank, + self.world_size, + self.gradient_accumulation_steps, + bind_parameters=False, + use_mixed_precision=self.use_mixed_precision, + allreduce_post_accumulation=self.allreduce_post_accumulation_, + deepspeed_zero_stage=self.deepspeed_zero_stage_, + enable_grad_norm_clip=self.enable_grad_norm_clip_, + frozen_weights=self.frozen_weights_, + opset_version=self.opset_version_, + use_deterministic_compute=self._use_deterministic_compute, + use_memory_efficient_gradient=self.use_memory_efficient_gradient, + enable_adasum=self.enable_adasum, + optimized_model_filepath=self.optimized_model_filepath, + ) self.loss_scale_input_name = self.session.loss_scale_input_name if self.use_mixed_precision: self.input_desc_with_lr_and_loss_scale = [ *self.input_desc_with_lr, - IODescription(self.loss_scale_input_name, [], torch.float32)] + IODescription(self.loss_scale_input_name, [], torch.float32), + ] # ORT backend has modified model output dtype from float32 to float16. for o_desc in self.model_desc_.outputs_: - if self.use_mixed_precision and o_desc.dtype_ == torch.float32 and not self.session.is_output_fp32_node(o_desc.name_): + if ( + self.use_mixed_precision + and o_desc.dtype_ == torch.float32 + and not self.session.is_output_fp32_node(o_desc.name_) + ): o_desc.eval_dtype_ = torch.float16 else: o_desc.eval_dtype_ = o_desc.dtype_ @@ -740,14 +871,16 @@ def _init_session(self): if self.gradient_accumulation_steps > 1: self.output_desc_with_group_accumulated_gradients = [ *self.model_desc_.outputs_, - IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool)] + IODescription(get_group_accumulated_gradients_output_node_arg_name(self.session), [1], torch.bool), + ] if self.use_mixed_precision: # when ready to use accumulated gradient with mixed precision, we need to fetch all_infinite to determine # if the gradient is usable. self.output_desc_with_all_fp_16_or_fp32_gradients_finite = [ *self.model_desc_.outputs_, - IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool)] + IODescription(get_all_gradients_finite_arg_name(self.session), [1], torch.bool), + ] if self.state_dict_: self.load_state_dict(self.state_dict_, self.strict_) @@ -764,7 +897,13 @@ def _init_onnx_model(self, inputs): torch_buffers = list(dict(self.torch_model_.named_buffers()).keys()) self.frozen_weights_ = self.frozen_weights_ + torch_buffers self.onnx_model_ = convert_model_loss_fn_to_onnx( - self.torch_model_, self.loss_fn_, self.model_desc_, torch.device('cpu'), inputs, opset_version=self.opset_version_) + self.torch_model_, + self.loss_fn_, + self.model_desc_, + torch.device("cpu"), + inputs, + opset_version=self.opset_version_, + ) if self._enable_internal_postprocess: postprocess.run_postprocess(self.onnx_model_) @@ -795,8 +934,10 @@ def _update_onnx_model_initializers(self, state_tensors): def state_dict(self, include_optimizer_state=True): if not self.session: - warnings.warn("ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict().") + warnings.warn( + "ONNXRuntime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling state_dict()." + ) return {} # extract trained weights @@ -841,13 +982,15 @@ def load_state_dict(self, state_dict, strict=False): self._init_session() # load training state - session_state = {name:state_dict[name].numpy() for name in state_dict} + session_state = {name: state_dict[name].numpy() for name in state_dict} self.session.load_state(session_state, strict) def save_as_onnx(self, path): if not self.session: - warnings.warn("ONNXRuntime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling save_as_onnx().") + warnings.warn( + "ONNXRuntime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling save_as_onnx()." + ) return state_tensors = self.session.get_state() self._update_onnx_model_initializers(state_tensors) @@ -855,7 +998,9 @@ def save_as_onnx(self, path): with open(path, "wb") as f: f.write(self.onnx_model_.SerializeToString()) - def _prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs): + def _prepare_input_and_fetches( + self, input_desc_with_, internal_learning_rate, internal_loss_scale, *args, **kwargs + ): fetches = None if type(args) == tuple and len(args) == 1 and type(args[0]) == list: input = tuple(args[0]) @@ -876,16 +1021,15 @@ def _prepare_input_and_fetches(self, input_desc_with_, internal_learning_rate, i # However, when first time train_step is called model.loss_scale_input_name is not set. # To workaround this problem, we use the special name 'default_loss_scale_input_name' to indicate # the loss_scale. - if 'default_loss_scale_input_name' in kwargs.keys(): - input = input + (kwargs['default_loss_scale_input_name'],) + if "default_loss_scale_input_name" in kwargs.keys(): + input = input + (kwargs["default_loss_scale_input_name"],) fetches = None - if 'fetches' in kwargs: - fetches = kwargs['fetches'] + if "fetches" in kwargs: + fetches = kwargs["fetches"] return input, fetches - def train_step(self, *args, **kwargs): """ inputs: model inputs, labels, learning rate, and, if in mixed_precision mode, loss_scale. @@ -907,7 +1051,6 @@ def train_step(self, *args, **kwargs): # localized arguments (*args) contains inputs to the ONNX model. # named arguments can contain both inputs, learning_rate and loss_scale, and the fetches - learning_rate, loss_scale = None, None if self.get_lr_this_step_ is not None: # $args, **kwargs contains inputs to the pytorch model @@ -917,18 +1060,19 @@ def train_step(self, *args, **kwargs): loss_scale = torch.tensor([self.loss_scaler_.loss_scale_]) if self.onnx_model_ is None: - sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, - None, None, *args, **kwargs) + sample_input, _ = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) self._init_onnx_model(sample_input) if self.use_mixed_precision: - input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr_and_loss_scale, - learning_rate, loss_scale, *args, **kwargs) + input, fetches = self._prepare_input_and_fetches( + self.input_desc_with_lr_and_loss_scale, learning_rate, loss_scale, *args, **kwargs + ) assert len(self.input_desc_with_lr_and_loss_scale) == len(input) input_descs = self.input_desc_with_lr_and_loss_scale else: - input, fetches = self._prepare_input_and_fetches(self.input_desc_with_lr, - learning_rate, loss_scale, *args, **kwargs) + input, fetches = self._prepare_input_and_fetches( + self.input_desc_with_lr, learning_rate, loss_scale, *args, **kwargs + ) assert len(self.input_desc_with_lr) == len(input) input_descs = self.input_desc_with_lr @@ -952,10 +1096,9 @@ def train_step(self, *args, **kwargs): if not isinstance(input, (list, tuple)): input = (input,) - session_run_results = ort_training_session_run_helper(self.session, self.train_io_binding, input, - input_descs, output_desc, - self.device_, - run_options) + session_run_results = ort_training_session_run_helper( + self.session, self.train_io_binding, input, input_descs, output_desc, self.device_, run_options + ) if has_if_all_finite: # After session run with all_fp32_gradients_finite, we need to clear the iobinding's output state. @@ -976,7 +1119,10 @@ def train_step(self, *args, **kwargs): results = [session_run_results[fetch] for fetch in fetches] elif has_if_all_finite and self.loss_scaler_ is None: # return descripted outputs plus the all_finite flag so that the training script can handle loss scaling. - results = [session_run_results[output_desc.name_] for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite] + results = [ + session_run_results[output_desc.name_] + for output_desc in self.output_desc_with_all_fp_16_or_fp32_gradients_finite + ] else: results = [session_run_results[output_desc.name_] for output_desc in self.model_desc_.outputs_] return results[0] if len(results) == 1 else results @@ -997,16 +1143,17 @@ def eval_step(self, *args, **kwargs): """ # with model_loss_cls, the last input is label, first output is loss - input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, - None, None, *args, **kwargs) + input, fetches = self._prepare_input_and_fetches(self.model_desc_.inputs_, None, None, *args, **kwargs) if self.onnx_model_ is None: if self.torch_model_ is not None: self._init_onnx_model(input) else: - raise RuntimeError("Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer.") + raise RuntimeError( + "Model is unintialized. Please ensure a valid ONNX model or PyTorch model is provided to this Trainer." + ) - input_desc = self.model_desc_.inputs_[0:len(input)] + input_desc = self.model_desc_.inputs_[0 : len(input)] if fetches is None: output_desc = self.model_desc_.outputs_ else: @@ -1019,11 +1166,9 @@ def eval_step(self, *args, **kwargs): run_options.only_execute_path_to_fetches = True run_options.training_mode = False - session_run_results = ort_training_session_run_helper(self.session, self.eval_io_binding, input, - input_desc, - output_desc, - self.device_, - run_options) + session_run_results = ort_training_session_run_helper( + self.session, self.eval_io_binding, input, input_desc, output_desc, self.device_, run_options + ) if len(session_run_results) == 1: return session_run_results[list(session_run_results.keys())[0]] @@ -1031,24 +1176,35 @@ def eval_step(self, *args, **kwargs): return [session_run_results[output_desc.name_] for output_desc in output_desc] def _verify_fully_optimized_model(self, model): - assert(len(model.graph.output) > 0) + assert len(model.graph.output) > 0 # model's first output must be the loss tensor - if model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT and\ - model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT16 and\ - model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().DOUBLE and\ - model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().COMPLEX64 and\ - model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().COMPLEX128 and\ - model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().BFLOAT16: - raise RuntimeError("the first output of a model to run with fully optimized ORT backend must be float types.") + if ( + model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT + and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().FLOAT16 + and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().DOUBLE + and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().COMPLEX64 + and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().COMPLEX128 + and model.graph.output[0].type.tensor_type.elem_type != onnx.TensorProto().BFLOAT16 + ): + raise RuntimeError( + "the first output of a model to run with fully optimized ORT backend must be float types." + ) if len(model.graph.output[0].type.tensor_type.shape.dim) != 0: raise RuntimeError( - "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar.") - -class LossScaler(): - def __init__(self, loss_scale_input_name, is_dynamic_scale, - loss_scale=float(1 << 16), - up_scale_window=2000, - min_loss_scale=1.0, max_loss_scale=float(1 << 24)): + "the first output of a model to run with fully optimized ORT backend assumed to be loss and must be a scalar." + ) + + +class LossScaler: + def __init__( + self, + loss_scale_input_name, + is_dynamic_scale, + loss_scale=float(1 << 16), + up_scale_window=2000, + min_loss_scale=1.0, + max_loss_scale=float(1 << 24), + ): super(LossScaler, self).__init__() self.loss_scale_input_name_ = loss_scale_input_name self.is_dynamic_scale_ = is_dynamic_scale diff --git a/orttraining/orttraining/python/pt_patch.py b/orttraining/orttraining/python/pt_patch.py index e77b86b5ccb68..b524a286c9de7 100644 --- a/orttraining/orttraining/python/pt_patch.py +++ b/orttraining/orttraining/python/pt_patch.py @@ -5,41 +5,45 @@ from torch.onnx.symbolic_helper import parse_args import torch.onnx.symbolic_helper as sym_help -@parse_args('v', 'v', 'v', 'v', 'i', 'none') -def nll_loss_10(g, self, target, weight=None, reduction='mean', ignore_index=-100): + +@parse_args("v", "v", "v", "v", "i", "none") +def nll_loss_10(g, self, target, weight=None, reduction="mean", ignore_index=-100): if not weight and not ignore_index: return g.op("nll_loss", self, target) elif ignore_index: ignore_index_ = g.op("Constant", value_t=torch.tensor(ignore_index, dtype=torch.int64)) eq_ = g.op("Equal", target, ignore_index_) not_eq_ = g.op("Not", eq_) - weight_ = g.op("Cast", not_eq_, to_i=1) # FLOAT = 1; // float - not_eq_int64_ = g.op("Cast", not_eq_, to_i=7) #INT64 = 7; // int64_t + weight_ = g.op("Cast", not_eq_, to_i=1) # FLOAT = 1; // float + not_eq_int64_ = g.op("Cast", not_eq_, to_i=7) # INT64 = 7; // int64_t target_ = g.op("Mul", target, not_eq_int64_) # if weight: # weight_ = g.op("Mul", weight_, weight) return g.op("nll_loss", self, target_, weight_) + symbolic_opset10.nll_loss = nll_loss_10 + def nll_loss_12(g, self, target, weight, reduction, ignore_index): # none reduction : onnx::Constant[value={0}] # mean reduction : onnx::Constant[value={1}] # sum reduction : onnx::Constant[value={2}] - reduction = sym_help._maybe_get_const(reduction, 'i') - reduction_vals = ['none', 'mean', 'sum'] + reduction = sym_help._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] reduction = reduction_vals[reduction] # in onnx NegativeLogLikelihoodLoss specification, ignore_index is optional without default value. # therefore we need to set ignore_index attribute even if it is not specified (e.g. ignore_index=-100). - ignore_index = sym_help._maybe_get_const(ignore_index, 'i') + ignore_index = sym_help._maybe_get_const(ignore_index, "i") if weight.node().mustBeNone(): nllloss = g.op("NegativeLogLikelihoodLoss", self, target, reduction_s=reduction, ignore_index_i=ignore_index) else: nllloss = g.op( - "NegativeLogLikelihoodLoss", self, - target, weight, reduction_s=reduction, ignore_index_i=ignore_index) + "NegativeLogLikelihoodLoss", self, target, weight, reduction_s=reduction, ignore_index_i=ignore_index + ) return nllloss + symbolic_opset12.nll_loss = nll_loss_12 diff --git a/orttraining/orttraining/python/training/_checkpoint_storage.py b/orttraining/orttraining/python/training/_checkpoint_storage.py index 75b8309bc7767..461daa57134c0 100644 --- a/orttraining/orttraining/python/training/_checkpoint_storage.py +++ b/orttraining/orttraining/python/training/_checkpoint_storage.py @@ -7,6 +7,7 @@ from collections.abc import Mapping import pickle + def _dfs_save(group, save_obj): """Recursively go over each level in the save_obj dictionary and save values to a hdf5 group""" @@ -17,6 +18,7 @@ def _dfs_save(group, save_obj): else: group[key] = value + def save(save_obj: dict, path): """Persists the input dictionary to a file specified by path. @@ -35,9 +37,10 @@ def save(save_obj: dict, path): if not isinstance(save_obj, Mapping): raise ValueError("Object to be saved must be a dictionary") - with h5py.File(path, 'w-') as f: + with h5py.File(path, "w-") as f: _dfs_save(f, save_obj) + def _dfs_load(group, load_obj): """Recursively go over each level in the hdf5 group and load the values into the given dictionary""" @@ -48,6 +51,7 @@ def _dfs_load(group, load_obj): else: load_obj[key] = group[key][()] + def load(path, key=None): """Loads the data stored in the binary file specified at the given path into a dictionary and returns it. @@ -72,7 +76,7 @@ def load(path, key=None): raise ValueError(f"{path} is not an hdf5 file or a python file-like object.") load_obj = {} - with h5py.File(path, 'r') as f: + with h5py.File(path, "r") as f: if key: f = f[key] if isinstance(f, h5py.Dataset): @@ -82,11 +86,13 @@ def load(path, key=None): return load_obj + def to_serialized_hex(user_dict): """Serialize the user_dict and convert the serialized bytes to a hex string and return""" return pickle.dumps(user_dict).hex() + def from_serialized_hex(serialized_hex): """Convert serialized_hex to bytes and deserialize it and return""" diff --git a/orttraining/orttraining/python/training/_utils.py b/orttraining/orttraining/python/training/_utils.py index 0f882279a034a..d42e1d9ef52f0 100644 --- a/orttraining/orttraining/python/training/_utils.py +++ b/orttraining/orttraining/python/training/_utils.py @@ -23,7 +23,7 @@ def get_device_index(device): def get_device_index_from_input(input): - '''Returns device index from a input PyTorch Tensor''' + """Returns device index from a input PyTorch Tensor""" if isinstance(input, (list, tuple)): device_index = get_device_index(input[0].device) @@ -35,40 +35,40 @@ def get_device_index_from_input(input): def get_device_str(device): if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - if device.find(':') == -1: - device += ':' + str(torch.cuda.current_device()) + if device.find(":") == -1: + device += ":" + str(torch.cuda.current_device()) elif isinstance(device, int): - device = 'cuda:' + str(device) + device = "cuda:" + str(device) elif isinstance(device, torch.device): if device.index is None: - device = device.type + ':' + str(torch.cuda.current_device()) + device = device.type + ":" + str(torch.cuda.current_device()) else: - device = device.type + ':' + str(device.index) + device = device.type + ":" + str(device.index) else: - raise RuntimeError('Unsupported device type') + raise RuntimeError("Unsupported device type") return device def get_all_gradients_finite_name_from_session(session): - '''Find all_gradients_finite node on Session graph and return its name''' + """Find all_gradients_finite node on Session graph and return its name""" - nodes = [x for x in session._outputs_meta if 'all_gradients_finite' in x.name] + nodes = [x for x in session._outputs_meta if "all_gradients_finite" in x.name] if len(nodes) != 1: raise RuntimeError("'all_gradients_finite' node not found within training session") return nodes[0].name def get_gradient_accumulation_name_from_session(session): - '''Find Group_Accumulated_Gradients node on Session graph and return its name''' + """Find Group_Accumulated_Gradients node on Session graph and return its name""" - nodes = [x for x in session._outputs_meta if 'Group_Accumulated_Gradients' in x.name] + nodes = [x for x in session._outputs_meta if "Group_Accumulated_Gradients" in x.name] if len(nodes) != 1: raise RuntimeError("'Group_Accumulated_Gradients' node not found within training session") return nodes[0].name def dtype_torch_to_numpy(torch_dtype): - '''Converts PyTorch types to Numpy types + """Converts PyTorch types to Numpy types Also must map to types accepted by: MLDataType NumpyTypeToOnnxRuntimeType(int numpy_type) @@ -76,7 +76,7 @@ def dtype_torch_to_numpy(torch_dtype): References: https://docs.scipy.org/doc/numpy-1.13.0/user/basics.types.html https://pytorch.org/docs/stable/tensors.html - ''' + """ if torch_dtype == torch.float64 or torch_dtype == torch.double: return np.float64 elif torch_dtype == torch.float32 or torch_dtype == torch.float: @@ -102,18 +102,34 @@ def dtype_torch_to_numpy(torch_dtype): elif torch_dtype == torch.bool: return np.bool_ else: - raise ValueError( - f'torch_dtype ({str(torch_dtype)}) type is not supported by Numpy') + raise ValueError(f"torch_dtype ({str(torch_dtype)}) type is not supported by Numpy") def dtype_onnx_to_torch(onnx_type): - '''Converts ONNX types to PyTorch types + """Converts ONNX types to PyTorch types Reference: https://github.com/onnx/onnx/blob/master/onnx/onnx.in.proto (enum DataType) https://pytorch.org/docs/stable/tensors.html - ''' - onnx_types = ['UNDEFINED', 'FLOAT', 'UINT8', 'INT8', 'UINT16', 'INT16', 'INT32', 'INT64', 'STRING', - 'BOOL', 'FLOAT16', 'DOUBLE', 'UINT32', 'UINT64', 'COMPLEX64', 'COMPLEX128', 'BFLOAT16'] + """ + onnx_types = [ + "UNDEFINED", + "FLOAT", + "UINT8", + "INT8", + "UINT16", + "INT16", + "INT32", + "INT64", + "STRING", + "BOOL", + "FLOAT16", + "DOUBLE", + "UINT32", + "UINT64", + "COMPLEX64", + "COMPLEX128", + "BFLOAT16", + ] if isinstance(onnx_type, int): assert onnx_type < len(onnx_types), "Invalid onnx_type integer" @@ -122,8 +138,7 @@ def dtype_onnx_to_torch(onnx_type): assert onnx_type in onnx_types, "Invalid onnx_type string" onnx_type = onnx_types.index(onnx_type) else: - raise ValueError( - "'onnx_type' must be an ONNX type represented by either a string or integer") + raise ValueError("'onnx_type' must be an ONNX type represented by either a string or integer") if onnx_type == 0: return None @@ -158,48 +173,50 @@ def dtype_onnx_to_torch(onnx_type): def static_vars(**kwargs): - r'''Decorator to add :py:attr:`kwargs` as static vars to 'func' - - Example: - - .. code-block:: python - - >>> @static_vars(counter=0) - ... def myfync(): - ... myfync.counter += 1 - ... return myfync.counter - ... - >>> print(myfunc()) - 1 - >>> print(myfunc()) - 2 - >>> print(myfunc()) - 3 - >>> myfunc.counter = 100 - >>> print(myfunc()) - 101 - ''' + r"""Decorator to add :py:attr:`kwargs` as static vars to 'func' + + Example: + + .. code-block:: python + + >>> @static_vars(counter=0) + ... def myfync(): + ... myfync.counter += 1 + ... return myfync.counter + ... + >>> print(myfunc()) + 1 + >>> print(myfunc()) + 2 + >>> print(myfunc()) + 3 + >>> myfunc.counter = 100 + >>> print(myfunc()) + 101 + """ + def decorate(func): for k in kwargs: setattr(func, k, kwargs[k]) return func + return decorate def import_module_from_file(file_path, module_name=None): - '''Import a Python module from a file into interpreter''' + """Import a Python module from a file into interpreter""" if not isinstance(file_path, str) or not os.path.exists(file_path): raise AssertionError( - "'file_path' must be a full path string with the python file to load. " - "file_path=%r." % (file_path, )) + "'file_path' must be a full path string with the python file to load. " "file_path=%r." % (file_path,) + ) if module_name is not None and (not isinstance(module_name, str) or not module_name): raise AssertionError( - "'module_name' must be a string with the python module name to load. " - "module_name=%r." % (module_name, )) + "'module_name' must be a string with the python module name to load. " "module_name=%r." % (module_name,) + ) if not module_name: - module_name = os.path.basename(file_path).split('.')[0] + module_name = os.path.basename(file_path).split(".")[0] spec = importlib.util.spec_from_file_location(module_name, file_path) module = importlib.util.module_from_spec(spec) @@ -211,106 +228,106 @@ def import_module_from_file(file_path, module_name=None): def state_dict_model_key(): """Returns the model key name in the state dictionary""" - return 'model' + return "model" def state_dict_optimizer_key(): """Returns the optimizer key name in the state dictionary""" - return 'optimizer' + return "optimizer" def state_dict_partition_info_key(): """Returns the partition info key name in the state dictionary""" - return 'partition_info' + return "partition_info" def state_dict_trainer_options_key(): """Returns the trainer options key name in the state dictionary""" - return 'trainer_options' + return "trainer_options" def state_dict_full_precision_key(): """Returns the full precision key name in the state dictionary""" - return 'full_precision' + return "full_precision" def state_dict_original_dimension_key(): """Returns the original dimension key name in the state dictionary""" - return 'original_dim' + return "original_dim" def state_dict_sharded_optimizer_keys(): """Returns the optimizer key names that can be sharded in the state dictionary""" - return { - 'Moment_1', - 'Moment_2' - } + return {"Moment_1", "Moment_2"} def state_dict_user_dict_key(): """Returns the user dict key name in the state dictionary""" - return 'user_dict' + return "user_dict" def state_dict_trainer_options_mixed_precision_key(): """Returns the trainer options mixed precision key name in the state dictionary""" - return 'mixed_precision' + return "mixed_precision" def state_dict_trainer_options_zero_stage_key(): """Returns the trainer options zero_stage key name in the state dictionary""" - return 'zero_stage' + return "zero_stage" def state_dict_trainer_options_world_rank_key(): """Returns the trainer options world_rank key name in the state dictionary""" - return 'world_rank' + return "world_rank" def state_dict_trainer_options_world_size_key(): """Returns the trainer options world_size key name in the state dictionary""" - return 'world_size' + return "world_size" def state_dict_trainer_options_data_parallel_size_key(): """Returns the trainer options data_parallel_size key name in the state dictionary""" - return 'data_parallel_size' + return "data_parallel_size" def state_dict_trainer_options_horizontal_parallel_size_key(): """Returns the trainer options horizontal_parallel_size key name in the state dictionary""" - return 'horizontal_parallel_size' + return "horizontal_parallel_size" def state_dict_trainer_options_optimizer_name_key(): """Returns the trainer options optimizer_name key name in the state dictionary""" - return 'optimizer_name' + return "optimizer_name" + def state_dict_train_step_info_key(): """Returns the train step info key name in the state dictionary""" - return 'train_step_info' + return "train_step_info" + def state_dict_train_step_info_optimization_step_key(): """Returns the train step info optimization step key name in the state dictionary""" - return 'optimization_step' + return "optimization_step" + def state_dict_train_step_info_step_key(): """Returns the train step info step key name in the state dictionary""" - return 'step' + return "step" diff --git a/orttraining/orttraining/python/training/amp/loss_scaler.py b/orttraining/orttraining/python/training/amp/loss_scaler.py index f64db9a3fa8ce..42d3d670a59ea 100644 --- a/orttraining/orttraining/python/training/amp/loss_scaler.py +++ b/orttraining/orttraining/python/training/amp/loss_scaler.py @@ -85,11 +85,14 @@ class DynamicLossScaler(LossScaler): print(f'Custom loss scale is {scaler2.loss_scale}') """ - def __init__(self, automatic_update=True, - loss_scale=float(1 << 16), - up_scale_window=2000, - min_loss_scale=1.0, - max_loss_scale=float(1 << 24)): + def __init__( + self, + automatic_update=True, + loss_scale=float(1 << 16), + up_scale_window=2000, + min_loss_scale=1.0, + max_loss_scale=float(1 << 24), + ): super().__init__(loss_scale) self.automatic_update = automatic_update self.up_scale_window = up_scale_window diff --git a/orttraining/orttraining/python/training/checkpoint.py b/orttraining/orttraining/python/training/checkpoint.py index 364eac64b0a8f..e4a2f1230b7a4 100644 --- a/orttraining/orttraining/python/training/checkpoint.py +++ b/orttraining/orttraining/python/training/checkpoint.py @@ -14,12 +14,16 @@ def experimental_state_dict(ort_trainer, include_optimizer_state=True): - warnings.warn("experimental_state_dict() will be deprecated soon. " - "Please use ORTTrainer.state_dict() instead.", DeprecationWarning) + warnings.warn( + "experimental_state_dict() will be deprecated soon. " "Please use ORTTrainer.state_dict() instead.", + DeprecationWarning, + ) if not ort_trainer._training_session: - warnings.warn("ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling state_dict().") + warnings.warn( + "ONNX Runtime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling state_dict()." + ) return ort_trainer._state_dict # extract trained weights @@ -40,8 +44,10 @@ def experimental_state_dict(ort_trainer, include_optimizer_state=True): def experimental_load_state_dict(ort_trainer, state_dict, strict=False): - warnings.warn("experimental_load_state_dict() will be deprecated soon. " - "Please use ORTTrainer.load_state_dict() instead.", DeprecationWarning) + warnings.warn( + "experimental_load_state_dict() will be deprecated soon. " "Please use ORTTrainer.load_state_dict() instead.", + DeprecationWarning, + ) # Note: It may happen ONNX model has not yet been initialized # In this case we cache a reference to desired state and delay the restore until after initialization @@ -68,25 +74,35 @@ def experimental_load_state_dict(ort_trainer, state_dict, strict=False): ort_trainer._init_session() # load training state - session_state = {name:state_dict[name].numpy() for name in state_dict} + session_state = {name: state_dict[name].numpy() for name in state_dict} ort_trainer._training_session.load_state(session_state, strict) -def experimental_save_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", checkpoint_state_dict=None, include_optimizer_state=True): - warnings.warn("experimental_save_checkpoint() will be deprecated soon. " - "Please use ORTTrainer.save_checkpoint() instead.", DeprecationWarning) +def experimental_save_checkpoint( + ort_trainer, + checkpoint_dir, + checkpoint_prefix="ORT_checkpoint", + checkpoint_state_dict=None, + include_optimizer_state=True, +): + warnings.warn( + "experimental_save_checkpoint() will be deprecated soon. " "Please use ORTTrainer.save_checkpoint() instead.", + DeprecationWarning, + ) if checkpoint_state_dict is None: - checkpoint_state_dict = {'model': experimental_state_dict(ort_trainer, include_optimizer_state)} + checkpoint_state_dict = {"model": experimental_state_dict(ort_trainer, include_optimizer_state)} else: - checkpoint_state_dict.update({'model': experimental_state_dict(ort_trainer, include_optimizer_state)}) + checkpoint_state_dict.update({"model": experimental_state_dict(ort_trainer, include_optimizer_state)}) assert os.path.exists(checkpoint_dir), f"checkpoint_dir ({checkpoint_dir}) directory doesn't exist" - checkpoint_name = _get_checkpoint_name(checkpoint_prefix, - ort_trainer.options.distributed.deepspeed_zero_optimization.stage, - ort_trainer.options.distributed.world_rank, - ort_trainer.options.distributed.world_size) + checkpoint_name = _get_checkpoint_name( + checkpoint_prefix, + ort_trainer.options.distributed.deepspeed_zero_optimization.stage, + ort_trainer.options.distributed.world_rank, + ort_trainer.options.distributed.world_size, + ) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if os.path.exists(checkpoint_file): msg = f"{checkpoint_file} already exists, overwriting." @@ -95,15 +111,18 @@ def experimental_save_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix= def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix="ORT_checkpoint", strict=False): - warnings.warn("experimental_load_checkpoint() will be deprecated soon. " - "Please use ORTTrainer.load_checkpoint() instead.", DeprecationWarning) + warnings.warn( + "experimental_load_checkpoint() will be deprecated soon. " "Please use ORTTrainer.load_checkpoint() instead.", + DeprecationWarning, + ) - checkpoint_files = _list_checkpoint_files( - checkpoint_dir, checkpoint_prefix) + checkpoint_files = _list_checkpoint_files(checkpoint_dir, checkpoint_prefix) is_partitioned = False if len(checkpoint_files) > 1: - msg = (f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." - " Attempting to load ZeRO checkpoint.") + msg = ( + f"Found more than one file with prefix {checkpoint_prefix} in directory {checkpoint_dir}." + " Attempting to load ZeRO checkpoint." + ) warnings.warn(msg) is_partitioned = True if (not ort_trainer.options.distributed.deepspeed_zero_optimization.stage) and is_partitioned: @@ -113,32 +132,40 @@ def experimental_load_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix= class _AGGREGATION_MODE(Enum): - Zero = 0 + Zero = 0 Megatron = 1 + def _order_paths(paths, D_groups, H_groups): """Reorders the given paths in order of aggregation of ranks for D and H parallellism respectively - and returns the ordered dict""" + and returns the ordered dict""" trainer_options_path_tuples = [] world_rank = _utils.state_dict_trainer_options_world_rank_key() for path in paths: - trainer_options_path_tuples.append((_checkpoint_storage.load(path, - key=_utils.state_dict_trainer_options_key()), path)) + trainer_options_path_tuples.append( + (_checkpoint_storage.load(path, key=_utils.state_dict_trainer_options_key()), path) + ) # sort paths according to rank - sorted_paths = [path for _, path in sorted(trainer_options_path_tuples, - key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank])] - + sorted_paths = [ + path + for _, path in sorted( + trainer_options_path_tuples, key=lambda trainer_options_path_pair: trainer_options_path_pair[0][world_rank] + ) + ] + ordered_paths = dict() - ordered_paths['D'] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] - ordered_paths['H'] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] + ordered_paths["D"] = [[sorted_paths[i] for i in D_groups[group_id]] for group_id in range(len(D_groups))] + ordered_paths["H"] = [[sorted_paths[i] for i in H_groups[group_id]] for group_id in range(len(H_groups))] return ordered_paths -def _add_or_update_sharded_key(state_key, state_value, state_sub_dict, - model_state_key, state_partition_info, sharded_states_original_dims, mode): + +def _add_or_update_sharded_key( + state_key, state_value, state_sub_dict, model_state_key, state_partition_info, sharded_states_original_dims, mode +): """Add or update the record for the sharded state_key in the state_sub_dict""" # record the original dimension for this state @@ -153,12 +180,12 @@ def _add_or_update_sharded_key(state_key, state_value, state_sub_dict, # state_dict already contains a record for this state # since this state is sharded, concatenate the state value to # the record in the state_dict - state_sub_dict[state_key] = \ - np.concatenate((state_sub_dict[state_key], state_value), axis) + state_sub_dict[state_key] = np.concatenate((state_sub_dict[state_key], state_value), axis) else: # create a new entry for this state in the state_dict state_sub_dict[state_key] = state_value + def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, mismatch_error_string): """Add or validate the record for the unsharded state_key in the state_sub_dict""" @@ -170,7 +197,10 @@ def _add_or_validate_unsharded_key(state_key, state_value, state_sub_dict, misma # create a new entry for this state in the state_sub_dict state_sub_dict[state_key] = state_value -def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode = _AGGREGATION_MODE.Zero): + +def _aggregate_model_states( + rank_state_dict, sharded_states_original_dims, state_dict, mixed_precision_enabled, mode=_AGGREGATION_MODE.Zero +): """Aggregates all model states from the rank_state_dict into state_dict""" model = _utils.state_dict_model_key() @@ -192,17 +222,30 @@ def _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state # ZERO: full precision model states are sharded only when they exist in the partition_info subdict and mixed # precision training was enabled. for full precision training, full precision model states are not sharded # MEGATRON : full precision model states are sharded when they exist in the partition_info subdict - if (model_state_key in rank_state_dict[partition_info]) and (mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled): + if (model_state_key in rank_state_dict[partition_info]) and ( + mode == _AGGREGATION_MODE.Megatron or mixed_precision_enabled + ): # this model state is sharded - _add_or_update_sharded_key(model_state_key, model_state_value, - state_dict[model][full_precision], model_state_key, - rank_state_dict[partition_info][model_state_key], sharded_states_original_dims, mode) + _add_or_update_sharded_key( + model_state_key, + model_state_value, + state_dict[model][full_precision], + model_state_key, + rank_state_dict[partition_info][model_state_key], + sharded_states_original_dims, + mode, + ) else: # this model state is not sharded since a record for it does not exist in the partition_info subdict - _add_or_validate_unsharded_key(model_state_key, model_state_value, - state_dict[model][full_precision], "Value mismatch for model state {}".format(model_state_key)) + _add_or_validate_unsharded_key( + model_state_key, + model_state_value, + state_dict[model][full_precision], + "Value mismatch for model state {}".format(model_state_key), + ) + -def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode = _AGGREGATION_MODE.Zero): +def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode=_AGGREGATION_MODE.Zero): """Aggregates all optimizer states from the rank_state_dict into state_dict""" optimizer = _utils.state_dict_optimizer_key() @@ -224,15 +267,25 @@ def _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, s if optimizer_key in sharded_optimizer_keys and model_state_key in rank_state_dict[partition_info]: # this optimizer state is sharded since a record exists in the partition_info subdict - _add_or_update_sharded_key(optimizer_key, optimizer_value, - state_dict[optimizer][model_state_key], model_state_key, - rank_state_dict[partition_info][model_state_key], sharded_states_original_dims, mode) + _add_or_update_sharded_key( + optimizer_key, + optimizer_value, + state_dict[optimizer][model_state_key], + model_state_key, + rank_state_dict[partition_info][model_state_key], + sharded_states_original_dims, + mode, + ) else: # this optimizer state is not sharded since a record for it does not exist in the partition_info subdict # or this optimizer key is not one of the sharded optimizer keys - _add_or_validate_unsharded_key(optimizer_key, optimizer_value, + _add_or_validate_unsharded_key( + optimizer_key, + optimizer_value, state_dict[optimizer][model_state_key], - "Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key)) + "Value mismatch for model state {} and optimizer state {}".format(model_state_key, optimizer_key), + ) + def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_enabled): """Reshape model and optimizer states in the state_dict according to dimensions in sharded_states_original_dims""" @@ -245,8 +298,9 @@ def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_en for sharded_state_key, original_dim in sharded_states_original_dims.items(): # reshape model states to original_dim only when mixed precision is enabled if mixed_precision_enabled and (model in state_dict): - state_dict[model][full_precision][sharded_state_key] = \ - state_dict[model][full_precision][sharded_state_key].reshape(original_dim) + state_dict[model][full_precision][sharded_state_key] = state_dict[model][full_precision][ + sharded_state_key + ].reshape(original_dim) # reshape optimizer states to original_dim if optimizer in state_dict: @@ -254,6 +308,7 @@ def _reshape_states(sharded_states_original_dims, state_dict, mixed_precision_en if optimizer_key in sharded_optimizer_keys: state_dict[optimizer][sharded_state_key][optimizer_key] = optimizer_value.reshape(original_dim) + def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation): """Extracts trainer options from rank_state_dict and loads them accordingly on state_dict""" trainer_options = _utils.state_dict_trainer_options_key() @@ -275,35 +330,39 @@ def _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) state_dict[trainer_options][D_size] = 1 state_dict[trainer_options][H_size] = 1 + def _aggregate_megatron_partition_info(rank_state_dict, state_dict): """Extracts partition_info from rank_state_dict and loads on state_dict for megatron-partitioned weights""" partition_info = _utils.state_dict_partition_info_key() if partition_info not in state_dict: state_dict[partition_info] = {} - + rank_partition_info = rank_state_dict[partition_info] for model_state_key, partition_info_dict in rank_partition_info.items(): if model_state_key not in state_dict[partition_info]: # add partition info only if weight is megatron partitioned - if (partition_info_dict["megatron_row_partition"] >= 0): + if partition_info_dict["megatron_row_partition"] >= 0: state_dict[partition_info][model_state_key] = partition_info_dict + def _to_pytorch_format(state_dict): """Convert ORT state dictionary schema (hierarchical structure) to PyTorch state dictionary schema (flat structure)""" pytorch_state_dict = {} - for model_state_key, model_state_value in \ - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()].items(): + for model_state_key, model_state_value in state_dict[_utils.state_dict_model_key()][ + _utils.state_dict_full_precision_key() + ].items(): # convert numpy array to a torch tensor pytorch_state_dict[model_state_key] = torch.tensor(model_state_value) return pytorch_state_dict + def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world_size): """Returns the D and H groups for the given sizes""" num_data_groups = world_size // data_parallel_size data_groups = [] for data_group_id in range(num_data_groups): - data_group_ranks=[] + data_group_ranks = [] for r in range(data_parallel_size): data_group_ranks.append(data_group_id + horizontal_parallel_size * r) data_groups.append(data_group_ranks) @@ -311,23 +370,31 @@ def _get_parallellism_groups(data_parallel_size, horizontal_parallel_size, world num_horizontal_groups = world_size // horizontal_parallel_size horizontal_groups = [] for hori_group_id in range(num_horizontal_groups): - hori_group_ranks=[] + hori_group_ranks = [] for r in range(horizontal_parallel_size): hori_group_ranks.append(hori_group_id * horizontal_parallel_size + r) horizontal_groups.append(hori_group_ranks) return data_groups, horizontal_groups -def _aggregate_over_ranks(ordered_paths, ranks, sharded_states_original_dims = None, mode = _AGGREGATION_MODE.Zero, partial_aggregation = False, pytorch_format=True): + +def _aggregate_over_ranks( + ordered_paths, + ranks, + sharded_states_original_dims=None, + mode=_AGGREGATION_MODE.Zero, + partial_aggregation=False, + pytorch_format=True, +): """Aggregate checkpoint files over set of ranks and return a single state dictionary Args: ordered_paths: list of paths in the order in which they must be aggregated ranks: list of ranks that are to be aggregated - sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over + sharded_states_original_dims: dict containing the original dims for sharded states that are persisted over multiple calls to _aggregate_over_ranks() mode: mode of aggregation: Zero or Megatron - partial_aggregation: boolean flag to indicate whether to produce a partially + partial_aggregation: boolean flag to indicate whether to produce a partially aggregated state which can be further aggregated over pytorch_format: boolean flag to select either ONNX Runtime or PyTorch state schema of the returned state_dict Returns: @@ -352,29 +419,35 @@ def _aggregate_over_ranks(ordered_paths, ranks, sharded_states_original_dims = N assert _utils.state_dict_partition_info_key() in rank_state_dict, "Missing information: partition_info" assert _utils.state_dict_trainer_options_key() in rank_state_dict, "Missing information: trainer_options" - assert ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank], \ - "Unexpected rank in file at path {}. Expected {}, got {}".\ - format(path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank]) + assert ( + ranks[i] == rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] + ), "Unexpected rank in file at path {}. Expected {}, got {}".format( + path, rank, rank_state_dict[_utils.state_dict_trainer_options_key()][world_rank] + ) if loaded_mixed_precision is None: loaded_mixed_precision = rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] else: - assert loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision], \ - "Mixed precision state mismatch among checkpoint files. File: {}".format(path) + assert ( + loaded_mixed_precision == rank_state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] + ), "Mixed precision state mismatch among checkpoint files. File: {}".format(path) if loaded_world_size is None: loaded_world_size = rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] else: - assert loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size], \ - "World size state mismatch among checkpoint files. File: {}".format(path) + assert ( + loaded_world_size == rank_state_dict[_utils.state_dict_trainer_options_key()][world_size] + ), "World size state mismatch among checkpoint files. File: {}".format(path) if loaded_zero_stage is None: loaded_zero_stage = rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] else: - assert loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage], \ - "Zero stage mismatch among checkpoint files. File: {}".format(path) + assert ( + loaded_zero_stage == rank_state_dict[_utils.state_dict_trainer_options_key()][zero_stage] + ), "Zero stage mismatch among checkpoint files. File: {}".format(path) if loaded_optimizer_name is None: loaded_optimizer_name = rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] else: - assert loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name], \ - "Optimizer name mismatch among checkpoint files. File: {}".format(path) + assert ( + loaded_optimizer_name == rank_state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] + ), "Optimizer name mismatch among checkpoint files. File: {}".format(path) # aggregate all model states _aggregate_model_states(rank_state_dict, sharded_states_original_dims, state_dict, loaded_mixed_precision, mode) @@ -383,9 +456,9 @@ def _aggregate_over_ranks(ordered_paths, ranks, sharded_states_original_dims = N # aggregate all optimizer states if pytorch_format is False _aggregate_optimizer_states(rank_state_dict, sharded_states_original_dims, state_dict, mode) - # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups + # for D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups # to aggregate over Zero, and another pass to aggregate Megatron partitioned - # states. Preserve the relevant partition info only for weights that are megatron partitioned for + # states. Preserve the relevant partition info only for weights that are megatron partitioned for # a partial aggregation call if partial_aggregation: _aggregate_megatron_partition_info(rank_state_dict, state_dict) @@ -395,8 +468,10 @@ def _aggregate_over_ranks(ordered_paths, ranks, sharded_states_original_dims = N _aggregate_trainer_options(rank_state_dict, state_dict, partial_aggregation) # entry for user_dict in the state_dict if not already present - if _utils.state_dict_user_dict_key() not in state_dict and \ - _utils.state_dict_user_dict_key() in rank_state_dict: + if ( + _utils.state_dict_user_dict_key() not in state_dict + and _utils.state_dict_user_dict_key() in rank_state_dict + ): state_dict[_utils.state_dict_user_dict_key()] = rank_state_dict[_utils.state_dict_user_dict_key()] # for a partial aggregation scenario, we might not have the entire tensor aggregated yet, thus skip reshape @@ -408,12 +483,13 @@ def _aggregate_over_ranks(ordered_paths, ranks, sharded_states_original_dims = N # else return the hierarchical structure for ORTTrainer return _to_pytorch_format(state_dict) if pytorch_format else state_dict + def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): """Aggregate checkpoint files and return a single state dictionary for the D+H - (Zero+Megatron) partitioning strategy. - For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups - to aggregate over Zero, and another pass over the previously aggregated states - to aggregate Megatron partitioned states. + (Zero+Megatron) partitioning strategy. + For D+H aggregation scenario, the first pass of aggregation(partial aggregation) is over D groups + to aggregate over Zero, and another pass over the previously aggregated states + to aggregate Megatron partitioned states. """ sharded_states_original_dims = {} aggregate_data_checkpoint_files = [] @@ -421,9 +497,15 @@ def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): # combine for Zero over data groups and save to temp file with tempfile.TemporaryDirectory() as save_dir: for group_id, d_group in enumerate(D_groups): - aggregate_state_dict = _aggregate_over_ranks(ordered_paths['D'][group_id], d_group, sharded_states_original_dims, partial_aggregation = True, pytorch_format=False) - - filename = 'ort.data_group.' + str(group_id) + '.ort.pt' + aggregate_state_dict = _aggregate_over_ranks( + ordered_paths["D"][group_id], + d_group, + sharded_states_original_dims, + partial_aggregation=True, + pytorch_format=False, + ) + + filename = "ort.data_group." + str(group_id) + ".ort.pt" filepath = os.path.join(save_dir, filename) _checkpoint_storage.save(aggregate_state_dict, filepath) aggregate_data_checkpoint_files.append(filepath) @@ -431,10 +513,17 @@ def _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format): assert len(aggregate_data_checkpoint_files) > 0 # combine for megatron: - aggregate_state = _aggregate_over_ranks(aggregate_data_checkpoint_files, H_groups[0], sharded_states_original_dims, mode = _AGGREGATION_MODE.Megatron, pytorch_format = pytorch_format) + aggregate_state = _aggregate_over_ranks( + aggregate_data_checkpoint_files, + H_groups[0], + sharded_states_original_dims, + mode=_AGGREGATION_MODE.Megatron, + pytorch_format=pytorch_format, + ) return aggregate_state + def aggregate_checkpoints(paths, pytorch_format=True): """Aggregate checkpoint files and return a single state dictionary @@ -463,7 +552,7 @@ def aggregate_checkpoints(paths, pytorch_format=True): combine_zero = loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 combine_megatron = len(H_groups[0]) > 1 - # order the paths in the order of groups in which they must be aggregated according to + # order the paths in the order of groups in which they must be aggregated according to # data-parallel groups and H-parallel groups obtained # eg: {'D': [[path_0, path_2],[path_1, path_3]], 'H': [[path_0, path_1],[path_2, path_3]]} ordered_paths = _order_paths(paths, D_groups, H_groups) @@ -472,12 +561,17 @@ def aggregate_checkpoints(paths, pytorch_format=True): if combine_zero and combine_megatron: aggregate_state = _aggregate_over_D_H(ordered_paths, D_groups, H_groups, pytorch_format) elif combine_zero: - aggregate_state = _aggregate_over_ranks(ordered_paths['D'][0], D_groups[0], mode = _AGGREGATION_MODE.Zero, pytorch_format = pytorch_format) + aggregate_state = _aggregate_over_ranks( + ordered_paths["D"][0], D_groups[0], mode=_AGGREGATION_MODE.Zero, pytorch_format=pytorch_format + ) elif combine_megatron: - aggregate_state = _aggregate_over_ranks(ordered_paths['H'][0], H_groups[0], mode = _AGGREGATION_MODE.Megatron, pytorch_format = pytorch_format) + aggregate_state = _aggregate_over_ranks( + ordered_paths["H"][0], H_groups[0], mode=_AGGREGATION_MODE.Megatron, pytorch_format=pytorch_format + ) return aggregate_state + ################################################################################ # Helper functions ################################################################################ @@ -485,20 +579,26 @@ def aggregate_checkpoints(paths, pytorch_format=True): def _load_single_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, is_partitioned, strict): checkpoint_name = _get_checkpoint_name( - checkpoint_prefix, is_partitioned, ort_trainer.options.distributed.world_rank, ort_trainer.options.distributed.world_size) + checkpoint_prefix, + is_partitioned, + ort_trainer.options.distributed.world_rank, + ort_trainer.options.distributed.world_size, + ) checkpoint_file = os.path.join(checkpoint_dir, checkpoint_name) if is_partitioned: - assert_msg = (f"Couldn't find checkpoint file {checkpoint_file}." - " Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists " - f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}") + assert_msg = ( + f"Couldn't find checkpoint file {checkpoint_file}." + " Optimizer partitioning is enabled using ZeRO. Please make sure the checkpoint file exists " + f"for rank {ort_trainer.options.distributed.world_rank} of {ort_trainer.options.distributed.world_size}" + ) else: assert_msg = f"Couldn't find checkpoint file {checkpoint_file}." assert os.path.exists(checkpoint_file), assert_msg - checkpoint_state = torch.load(checkpoint_file, map_location='cpu') - experimental_load_state_dict(ort_trainer, checkpoint_state['model'], strict=strict) - del(checkpoint_state['model']) + checkpoint_state = torch.load(checkpoint_file, map_location="cpu") + experimental_load_state_dict(ort_trainer, checkpoint_state["model"], strict=strict) + del checkpoint_state["model"] return checkpoint_state @@ -514,13 +614,13 @@ def _load_multi_checkpoint(ort_trainer, checkpoint_dir, checkpoint_prefix, stric # Values will be overwritten for matching keys among workers all_checkpoint_states = dict() for checkpoint_file in checkpoint_files: - checkpoint_state = torch.load(checkpoint_file, map_location='cpu') - del(checkpoint_state['model']) + checkpoint_state = torch.load(checkpoint_file, map_location="cpu") + del checkpoint_state["model"] all_checkpoint_states.update(checkpoint_state) return all_checkpoint_states -def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension='.ort.pt'): +def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension=".ort.pt"): ckpt_file_names = [f for f in os.listdir(checkpoint_dir) if f.startswith(checkpoint_prefix)] ckpt_file_names = [f for f in ckpt_file_names if f.endswith(extension)] ckpt_file_names = [os.path.join(checkpoint_dir, f) for f in ckpt_file_names] @@ -530,27 +630,29 @@ def _list_checkpoint_files(checkpoint_dir, checkpoint_prefix, extension='.ort.pt def _get_checkpoint_name(prefix, is_partitioned, world_rank=None, world_size=None): - SINGLE_CHECKPOINT_FILENAME = '{prefix}.ort.pt' - MULTIPLE_CHECKPOINT_FILENAME = '{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt' + SINGLE_CHECKPOINT_FILENAME = "{prefix}.ort.pt" + MULTIPLE_CHECKPOINT_FILENAME = "{prefix}.ZeRO.{world_rank}.{world_size}.ort.pt" if is_partitioned: - filename = MULTIPLE_CHECKPOINT_FILENAME.format(prefix=prefix, world_rank=world_rank, world_size=(world_size-1)) + filename = MULTIPLE_CHECKPOINT_FILENAME.format( + prefix=prefix, world_rank=world_rank, world_size=(world_size - 1) + ) else: filename = SINGLE_CHECKPOINT_FILENAME.format(prefix=prefix) return filename def _split_state_dict(state_dict): - optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step'] - split_sd = {'optimizer': {}, 'fp32_param': {}, 'fp16_param': {}} + optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] + split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} for k, v in state_dict.items(): - mode = 'fp32_param' + mode = "fp32_param" for optim_key in optimizer_keys: if k.startswith(optim_key): - mode = 'optimizer' + mode = "optimizer" break - if k.endswith('_fp16'): - mode = 'fp16_param' + if k.endswith("_fp16"): + mode = "fp16_param" split_sd[mode][k] = v return split_sd @@ -561,30 +663,30 @@ def __init__(self, checkpoint_files, clean_state_dict=None): assert len(checkpoint_files) > 0, "No checkpoint files passed" self.checkpoint_files = checkpoint_files self.clean_state_dict = clean_state_dict - self.world_size = int(self.checkpoint_files[0].split('ZeRO')[1].split('.')[2]) + 1 + self.world_size = int(self.checkpoint_files[0].split("ZeRO")[1].split(".")[2]) + 1 assert len(self.checkpoint_files) == self.world_size, f"Could not find {self.world_size} files" self.weight_shape_map = dict() self.sharded_params = set() def _split_name(self, name): - name_split = name.split('_view_') + name_split = name.split("_view_") view_num = None - if(len(name_split) > 1): + if len(name_split) > 1: view_num = int(name_split[1]) - optimizer_key = '' - mp_suffix = '' - if name_split[0].startswith('Moment_1'): - optimizer_key = 'Moment_1_' - elif name_split[0].startswith('Moment_2'): - optimizer_key = 'Moment_2_' - elif name_split[0].startswith('Update_Count'): - optimizer_key = 'Update_Count_' - elif name_split[0].endswith('_fp16'): - mp_suffix = '_fp16' + optimizer_key = "" + mp_suffix = "" + if name_split[0].startswith("Moment_1"): + optimizer_key = "Moment_1_" + elif name_split[0].startswith("Moment_2"): + optimizer_key = "Moment_2_" + elif name_split[0].startswith("Update_Count"): + optimizer_key = "Update_Count_" + elif name_split[0].endswith("_fp16"): + mp_suffix = "_fp16" param_name = name_split[0] - if optimizer_key != '': + if optimizer_key != "": param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split('_fp16')[0] + param_name = param_name.split("_fp16")[0] return param_name, optimizer_key, view_num, mp_suffix def _update_weight_statistics(self, name, value): @@ -604,7 +706,7 @@ def _aggregate(self, param_dict): # parameter is sharded param_name = optimizer_key + weight_name + mp_suffix - if param_name in self.aggregate_state_dict and optimizer_key not in ['Update_Count_']: + if param_name in self.aggregate_state_dict and optimizer_key not in ["Update_Count_"]: self.sharded_params.add(param_name) # Found a previous shard of the param, concatenate shards ordered by ranks self.aggregate_state_dict[param_name] = torch.cat((self.aggregate_state_dict[param_name], v)) @@ -618,25 +720,28 @@ def _aggregate(self, param_dict): self._update_weight_statistics(weight_name, v) def aggregate_checkpoints(self): - warnings.warn("_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. " - "Please use aggregate_checkpoints() instead.", DeprecationWarning) + warnings.warn( + "_CombineZeroCheckpoint.aggregate_checkpoints() will be deprecated soon. " + "Please use aggregate_checkpoints() instead.", + DeprecationWarning, + ) - checkpoint_prefix = self.checkpoint_files[0].split('.ZeRO')[0] + checkpoint_prefix = self.checkpoint_files[0].split(".ZeRO")[0] self.aggregate_state_dict = dict() for i in range(self.world_size): checkpoint_name = _get_checkpoint_name(checkpoint_prefix, True, i, self.world_size) rank_state_dict = torch.load(checkpoint_name, map_location=torch.device("cpu")) - if 'model' in rank_state_dict: - rank_state_dict = rank_state_dict['model'] + if "model" in rank_state_dict: + rank_state_dict = rank_state_dict["model"] if self.clean_state_dict: rank_state_dict = self.clean_state_dict(rank_state_dict) rank_state_dict = _split_state_dict(rank_state_dict) - self._aggregate(rank_state_dict['fp16_param']) - self._aggregate(rank_state_dict['fp32_param']) - self._aggregate(rank_state_dict['optimizer']) + self._aggregate(rank_state_dict["fp16_param"]) + self._aggregate(rank_state_dict["fp32_param"]) + self._aggregate(rank_state_dict["optimizer"]) for k in self.sharded_params: self._reshape_tensor(k) diff --git a/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py b/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py index b55d272ef67f7..91c656619f621 100644 --- a/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py +++ b/orttraining/orttraining/python/training/experimental/gradient_graph/_gradient_graph_tools.py @@ -10,12 +10,13 @@ def export_gradient_graph( - model: torch.nn.Module, - loss_fn: Callable[[Any, Any], Any], - example_input: torch.Tensor, - example_labels: torch.Tensor, - gradient_graph_path: Union[Path, str], - opset_version=12) -> None: + model: torch.nn.Module, + loss_fn: Callable[[Any, Any], Any], + example_input: torch.Tensor, + example_labels: torch.Tensor, + gradient_graph_path: Union[Path, str], + opset_version=12, +) -> None: r""" Build a gradient graph for `model` so that you can output gradients in an inference session when given specific input and corresponding labels. @@ -52,33 +53,37 @@ def forward(self, model_input, expected_labels, *model_params): wrapped_model = WrapperModule() dynamic_axes = { - 'input': {0: 'batch_size', }, - 'labels': {0: 'batch_size', }, - 'output': {0: 'batch_size', }, + "input": { + 0: "batch_size", + }, + "labels": { + 0: "batch_size", + }, + "output": { + 0: "batch_size", + }, } args = (example_input, example_labels, *tuple(model.parameters())) model_param_names = tuple(name for name, _ in model.named_parameters()) - input_names = ['input', 'labels', *model_param_names] - nodes_needing_gradients = set( - name for name, param in model.named_parameters() - if param.requires_grad) + input_names = ["input", "labels", *model_param_names] + nodes_needing_gradients = set(name for name, param in model.named_parameters() if param.requires_grad) f = io.BytesIO() torch.onnx.export( - wrapped_model, args, + wrapped_model, + args, f, export_params=True, - opset_version=opset_version, do_constant_folding=False, + opset_version=opset_version, + do_constant_folding=False, training=TrainingMode.TRAINING, input_names=input_names, - output_names=['output', 'loss'], - dynamic_axes=dynamic_axes) + output_names=["output", "loss"], + dynamic_axes=dynamic_axes, + ) exported_model = f.getvalue() - builder = GradientGraphBuilder(exported_model, - {'loss'}, - nodes_needing_gradients, - 'loss') + builder = GradientGraphBuilder(exported_model, {"loss"}, nodes_needing_gradients, "loss") builder.build() builder.save(gradient_graph_path) diff --git a/orttraining/orttraining/python/training/model_desc_validation.py b/orttraining/orttraining/python/training/model_desc_validation.py index d107b2a1a545f..e9181f732cb32 100644 --- a/orttraining/orttraining/python/training/model_desc_validation.py +++ b/orttraining/orttraining/python/training/model_desc_validation.py @@ -11,7 +11,6 @@ class _ORTTrainerModelDesc(object): - def __init__(self, model_desc): # Keep a copy of original input for debug self._original = dict(model_desc) @@ -29,23 +28,24 @@ def __init__(self, model_desc): validator = cerberus.Validator(MODEL_DESC_SCHEMA) self._validated = validator.validated(self._validated) if self._validated is None: - raise ValueError(f'Invalid model_desc: {validator.errors}') + raise ValueError(f"Invalid model_desc: {validator.errors}") # Normalize inputs to a list of namedtuple(name, shape) - self._InputDescription = namedtuple('InputDescription', ['name', 'shape']) - self._InputDescriptionTyped = namedtuple('InputDescriptionTyped', ['name', 'shape', 'dtype']) - for idx, input in enumerate(self._validated['inputs']): - self._validated['inputs'][idx] = self._InputDescription(*input) + self._InputDescription = namedtuple("InputDescription", ["name", "shape"]) + self._InputDescriptionTyped = namedtuple("InputDescriptionTyped", ["name", "shape", "dtype"]) + for idx, input in enumerate(self._validated["inputs"]): + self._validated["inputs"][idx] = self._InputDescription(*input) # Normalize outputs to a list of namedtuple(name, shape, is_loss) - self._OutputDescription = namedtuple('OutputDescription', ['name', 'shape', 'is_loss']) - self._OutputDescriptionTyped = namedtuple('OutputDescriptionTyped', - ['name', 'shape', 'is_loss', 'dtype', 'dtype_amp']) - for idx, output in enumerate(self._validated['outputs']): + self._OutputDescription = namedtuple("OutputDescription", ["name", "shape", "is_loss"]) + self._OutputDescriptionTyped = namedtuple( + "OutputDescriptionTyped", ["name", "shape", "is_loss", "dtype", "dtype_amp"] + ) + for idx, output in enumerate(self._validated["outputs"]): if len(output) == 2: - self._validated['outputs'][idx] = self._OutputDescription(*output, False) + self._validated["outputs"][idx] = self._OutputDescription(*output, False) else: - self._validated['outputs'][idx] = self._OutputDescription(*output) + self._validated["outputs"][idx] = self._OutputDescription(*output) # Hard-code learning rate, all_finite descriptors self.learning_rate = self._InputDescriptionTyped(LEARNING_RATE_IO_DESCRIPTION_NAME, [1], torch.float32) @@ -55,91 +55,96 @@ def __init__(self, model_desc): setattr(self, k, self._wrap(v)) def __repr__(self): - '''Pretty representation for a model description class''' + """Pretty representation for a model description class""" - pretty_msg = 'Model description:\n' + pretty_msg = "Model description:\n" # Inputs inputs = [] for i_desc in self.inputs: if isinstance(i_desc, self._InputDescription): - inputs.append(f'(name={i_desc.name}, shape={i_desc.shape})') + inputs.append(f"(name={i_desc.name}, shape={i_desc.shape})") elif isinstance(i_desc, self._InputDescriptionTyped): - inputs.append(f'(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})') + inputs.append(f"(name={i_desc.name}, shape={i_desc.shape}, dtype={i_desc.dtype})") else: - raise ValueError(f'Unexpected type {type(i_desc)} for input description') + raise ValueError(f"Unexpected type {type(i_desc)} for input description") - pretty_msg += '\nInputs:' + pretty_msg += "\nInputs:" for idx, item in enumerate(inputs): - pretty_msg += f'\n\t{idx}: {item}' + pretty_msg += f"\n\t{idx}: {item}" # Outputs outputs = [] for o_desc in self.outputs: if isinstance(o_desc, self._OutputDescription): - outputs.append(f'(name={o_desc.name}, shape={o_desc.shape})') + outputs.append(f"(name={o_desc.name}, shape={o_desc.shape})") elif isinstance(o_desc, self._OutputDescriptionTyped): - outputs.append(f'(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})') + outputs.append( + f"(name={o_desc.name}, shape={o_desc.shape}, dtype={o_desc.dtype}, dtype_amp={o_desc.dtype_amp})" + ) else: - raise ValueError(f'Unexpected type {type(o_desc)} for output description') - pretty_msg += '\nOutputs:' + raise ValueError(f"Unexpected type {type(o_desc)} for output description") + pretty_msg += "\nOutputs:" for idx, item in enumerate(outputs): - pretty_msg += f'\n\t{idx}: {item}' + pretty_msg += f"\n\t{idx}: {item}" # Learning rate if self.learning_rate: - pretty_msg += '\nLearning rate: ' - pretty_msg += f'(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})' + pretty_msg += "\nLearning rate: " + pretty_msg += ( + f"(name={self.learning_rate.name}, shape={self.learning_rate.shape}, dtype={self.learning_rate.dtype})" + ) # Mixed precision - if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None): - pretty_msg += '\nMixed Precision:' + if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None) or getattr( + self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None + ): + pretty_msg += "\nMixed Precision:" if getattr(self, ALL_FINITE_IO_DESCRIPTION_NAME, None): - pretty_msg += '\n\tis gradients finite: ' - pretty_msg += f'(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})' + pretty_msg += "\n\tis gradients finite: " + pretty_msg += ( + f"(name={self.all_finite.name}, shape={self.all_finite.shape}, dtype={self.all_finite.dtype})" + ) if getattr(self, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, None): - pretty_msg += '\n\tloss scale input name: ' - pretty_msg += f'(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})' + pretty_msg += "\n\tloss scale input name: " + pretty_msg += f"(name={self.loss_scale_input.name}, shape={self.loss_scale_input.shape}, dtype={self.loss_scale_input.dtype})" # Gradient Accumulation steps if self.gradient_accumulation: - pretty_msg += '\nGradient Accumulation: ' - pretty_msg += f'(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})' + pretty_msg += "\nGradient Accumulation: " + pretty_msg += f"(name={self.gradient_accumulation.name}, shape={self.gradient_accumulation.shape}, dtype={self.gradient_accumulation.dtype})" return pretty_msg def add_type_to_input_description(self, index, dtype): - '''Updates an existing input description at position 'index' with 'dtype' type information + """Updates an existing input description at position 'index' with 'dtype' type information Args: index (int): position within 'inputs' description dtype (torch.dtype): input data type - ''' + """ - assert isinstance(index, int) and index >= 0,\ - "input 'index' must be a positive int" - assert isinstance(dtype, torch.dtype),\ - "input 'dtype' must be a torch.dtype type" + assert isinstance(index, int) and index >= 0, "input 'index' must be a positive int" + assert isinstance(dtype, torch.dtype), "input 'dtype' must be a torch.dtype type" existing_values = (*self.inputs[index],) if isinstance(self.inputs[index], self._InputDescriptionTyped): existing_values = (*existing_values[:-1],) self.inputs[index] = self._InputDescriptionTyped(*existing_values, dtype) def add_type_to_output_description(self, index, dtype, dtype_amp=None): - '''Updates an existing output description at position 'index' with 'dtype' type information + """Updates an existing output description at position 'index' with 'dtype' type information Args: index (int): position within 'inputs' description dtype (torch.dtype): input data type dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision - ''' - - assert isinstance(index, int) and index >= 0,\ - "output 'index' must be a positive int" - assert isinstance(dtype, torch.dtype),\ - "output 'dtype' must be a torch.dtype type" - assert dtype_amp is None or isinstance(dtype_amp, torch.dtype),\ - "output 'dtype_amp' must be either None or torch.dtype type" + """ + + assert isinstance(index, int) and index >= 0, "output 'index' must be a positive int" + assert isinstance(dtype, torch.dtype), "output 'dtype' must be a torch.dtype type" + assert dtype_amp is None or isinstance( + dtype_amp, torch.dtype + ), "output 'dtype_amp' must be either None or torch.dtype type" existing_values = (*self.outputs[index],) if isinstance(self.outputs[index], self._OutputDescriptionTyped): existing_values = (*existing_values[:-2],) @@ -151,7 +156,9 @@ def gradient_accumulation(self): @gradient_accumulation.setter def gradient_accumulation(self, name): - self._add_output_description(self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True) + self._add_output_description( + self, name, [1], False, torch.bool, None, GRADIENT_ACCUMULATION_IO_DESCRIPTION_NAME, ignore_duplicate=True + ) @property def all_finite(self): @@ -159,7 +166,9 @@ def all_finite(self): @all_finite.setter def all_finite(self, name): - self._add_output_description(self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True) + self._add_output_description( + self, name, [1], False, torch.bool, None, ALL_FINITE_IO_DESCRIPTION_NAME, ignore_duplicate=True + ) @property def loss_scale_input(self): @@ -167,10 +176,12 @@ def loss_scale_input(self): @loss_scale_input.setter def loss_scale_input(self, name): - self._add_input_description(self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True) + self._add_input_description( + self, name, [], torch.float32, LOSS_SCALE_INPUT_IO_DESCRIPTION_NAME, ignore_duplicate=True + ) def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, ignore_duplicate=False): - '''Add a new input description into the node object + """Add a new input description into the node object If 'dtype' is specified, a typed input description namedtuple(name, shape, dtype) is created. Otherwise an untyped input description namedtuple(name, shape) is created instead. @@ -184,7 +195,7 @@ def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, dtype (torch.dtype): input data type attr_name (str, default is None): friendly name to allow direct access to the output description ignore_duplicate (bool, default is False): silently skips addition of duplicate inputs - ''' + """ assert isinstance(name, str) and len(name) > 0, "'name' is an invalid input name" not_found = True @@ -197,8 +208,9 @@ def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" elif not not_found: return - assert isinstance(shape, list) and all([(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0))\ - for dim in shape]), "'shape' must be a list of int or str with length at least 1" + assert isinstance(shape, list) and all( + [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] + ), "'shape' must be a list of int or str with length at least 1" assert dtype is None or isinstance(dtype, torch.dtype), "'dtype' must be either None or a torch.dtype type" if dtype: new_input_desc = self._InputDescriptionTyped(name, shape, dtype) @@ -211,8 +223,10 @@ def _add_input_description(self, node, name, shape, dtype=None, attr_name=None, assert isinstance(attr_name, str) and len(attr_name) > 0, "Invalid 'attr_name'" setattr(node, attr_name, new_input_desc) - def _add_output_description(self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False): - '''Add a new output description into the node object as a tuple + def _add_output_description( + self, node, name, shape, is_loss, dtype=None, dtype_amp=None, attr_name=None, ignore_duplicate=False + ): + """Add a new output description into the node object as a tuple When (name, shape, is_loss, dtype) is specified, a typed output description is created Otherwise an untyped output description (name, shape, is_loss) is created instead @@ -228,11 +242,12 @@ def _add_output_description(self, node, name, shape, is_loss, dtype=None, dtype_ dtype_amp (torch.dtype, default is None): input data type for evaluation with mixed precision. attr_name (str, default is None): friendly name to allow direct access to the output description ignore_duplicate (bool, default is False): silently skips addition of duplicate outputs - ''' + """ assert isinstance(name, str) and len(name) > 0, "'name' is an invalid output name" - assert isinstance(shape, list) and all([(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0))\ - for dim in shape]), "'shape' must be a list of int or str with length at least 1" + assert isinstance(shape, list) and all( + [(isinstance(dim, int) or (isinstance(dim, str) and len(dim) > 0)) for dim in shape] + ), "'shape' must be a list of int or str with length at least 1" assert isinstance(is_loss, bool), "'is_loss' must be a bool" not_found = True @@ -240,8 +255,9 @@ def _add_output_description(self, node, name, shape, is_loss, dtype=None, dtype_ if id(node) == id(self.outputs): not_found = all([name not in o_desc.name for o_desc in node]) assert not_found, f"'name' {name} already exists in the outputs description" - assert all([not o_desc.is_loss for o_desc in node]) if is_loss else True,\ - "Only one 'is_loss' is supported at outputs description" + assert ( + all([not o_desc.is_loss for o_desc in node]) if is_loss else True + ), "Only one 'is_loss' is supported at outputs description" else: not_found = attr_name not in dir(self) assert not_found, f"'attr_name' {attr_name} already exists in the 'node'" @@ -261,19 +277,28 @@ def _add_output_description(self, node, name, shape, is_loss, dtype=None, dtype_ setattr(node, attr_name, new_output_desc) def _wrap(self, v): - '''Add 'v' as self's attribute to allow direct access as self.v''' + """Add 'v' as self's attribute to allow direct access as self.v""" if isinstance(v, (list)): return type(v)([self._wrap(v) for v in v]) - elif isinstance(v, (self._InputDescription, self._InputDescriptionTyped, - self._OutputDescription, self._OutputDescriptionTyped)): + elif isinstance( + v, + ( + self._InputDescription, + self._InputDescriptionTyped, + self._OutputDescription, + self._OutputDescriptionTyped, + ), + ): return v elif isinstance(v, (tuple)): return type(v)([self._wrap(v) for v in v]) elif isinstance(v, (dict, int, float, bool, str)): return _ORTTrainerModelDescInternal(self._main_class_name, v) if isinstance(v, dict) else v else: - raise ValueError(f"Unsupported type for model_desc ({v})." - "Only int, float, bool, str, list, tuple and dict are supported") + raise ValueError( + f"Unsupported type for model_desc ({v})." + "Only int, float, bool, str, list, tuple and dict are supported" + ) class _ORTTrainerModelDescInternal(_ORTTrainerModelDesc): @@ -292,7 +317,7 @@ def __init__(self, main_class_name, model_desc): def _model_desc_inputs_validation(field, value, error): - r'''Cerberus custom check method for 'model_desc.inputs' + r"""Cerberus custom check method for 'model_desc.inputs' 'model_desc.inputs' is a list of tuples. The list has variable length, but each tuple has size 2 @@ -310,7 +335,7 @@ def _model_desc_inputs_validation(field, value, error): model_desc['inputs'] = [('input1', ['batch', 1024]), ('input2', []) ('input3', [512])] - ''' + """ if not isinstance(value, tuple) or len(value) != 2: error(field, "must be a tuple with size 2") @@ -326,7 +351,7 @@ def _model_desc_inputs_validation(field, value, error): @static_vars(loss_counter=0) def _model_desc_outputs_validation(field, value, error): - r'''Cerberus custom check method for 'model_desc.outputs' + r"""Cerberus custom check method for 'model_desc.outputs' 'model_desc.outputs' is a list of tuples with variable length. The first element of the tuple is a string which represents the output name @@ -344,7 +369,7 @@ def _model_desc_outputs_validation(field, value, error): model_desc['outputs'] = [('output1', ['batch', 1024], is_loss=True), ('output2', [], is_loss=False) ('output3', [512])] - ''' + """ if not isinstance(value, tuple) or len(value) < 2 or len(value) > 3: error(field, "must be a tuple with size 2 or 3") @@ -367,21 +392,16 @@ def _model_desc_outputs_validation(field, value, error): # Validation schema for model description dictionary MODEL_DESC_SCHEMA = { - 'inputs': { - 'type': 'list', - 'required': True, - 'minlength': 1, - 'schema': { - 'check_with': _model_desc_inputs_validation - }, + "inputs": { + "type": "list", + "required": True, + "minlength": 1, + "schema": {"check_with": _model_desc_inputs_validation}, + }, + "outputs": { + "type": "list", + "required": True, + "minlength": 1, + "schema": {"check_with": _model_desc_outputs_validation}, }, - 'outputs': { - 'type': 'list', - 'required': True, - 'minlength': 1, - 'schema': { - - 'check_with': _model_desc_outputs_validation - }, - } } diff --git a/orttraining/orttraining/python/training/optim/__init__.py b/orttraining/orttraining/python/training/optim/__init__.py index 291268307d9f0..f74fe08202397 100644 --- a/orttraining/orttraining/python/training/optim/__init__.py +++ b/orttraining/orttraining/python/training/optim/__init__.py @@ -1,6 +1,11 @@ from .config import _OptimizerConfig, AdamConfig, LambConfig, SGDConfig -from .lr_scheduler import _LRScheduler, ConstantWarmupLRScheduler, CosineWarmupLRScheduler,\ - LinearWarmupLRScheduler, PolyWarmupLRScheduler +from .lr_scheduler import ( + _LRScheduler, + ConstantWarmupLRScheduler, + CosineWarmupLRScheduler, + LinearWarmupLRScheduler, + PolyWarmupLRScheduler, +) from .fused_adam import FusedAdam, AdamWMode from .fp16_optimizer import FP16_Optimizer diff --git a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py index 3415a4ba2eced..1b91ec2bf3594 100644 --- a/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py +++ b/orttraining/orttraining/python/training/optim/_apex_amp_modifier.py @@ -10,19 +10,22 @@ import warnings from ._modifier import FP16OptimizerModifier + class ApexAMPModifier(FP16OptimizerModifier): def __init__(self, optimizer, **kwargs) -> None: super().__init__(optimizer) pass def can_be_modified(self): - return self.check_requirements(["_post_amp_backward", "zero_grad"], - require_apex=True, require_torch_non_finite_check=False) + return self.check_requirements( + ["_post_amp_backward", "zero_grad"], require_apex=True, require_torch_non_finite_check=False + ) def override_function(m_self): from apex import amp as apex_amp from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops - warnings.warn('Apex AMP fp16_optimizer functions are overrided with faster implementation.', UserWarning) + + warnings.warn("Apex AMP fp16_optimizer functions are overrided with faster implementation.", UserWarning) # Implementation adapted from https://github.com/NVIDIA/apex/blob/082f999a6e18a3d02306e27482cc7486dab71a50/apex/amp/_process_optimizer.py#L161 def post_backward_with_master_weights(self, scaler): @@ -36,8 +39,8 @@ def post_backward_with_master_weights(self, scaler): # new_fp32_grads = [] # fp16_grads_needing_unscale_with_stash = [] # preexisting_fp32_grads = [ - #i = 0 - #for fp16_param, fp32_param in zip(stash.all_fp16_params, + # i = 0 + # for fp16_param, fp32_param in zip(stash.all_fp16_params, # stash.all_fp32_from_fp16_params): # if fp16_param.grad is None and fp32_param.grad is not None: # continue @@ -68,30 +71,39 @@ def post_backward_with_master_weights(self, scaler): #### END OF THE ORIGINAL IMPLEMENTATION #### #### THIS IS THE FASTER IMPLEMENTATION #### - tensor_vector_exist = hasattr(stash, "all_fp16_params_tensor_vector") and \ - hasattr(stash, "all_fp32_from_fp16_params_tensor_vector") - tensor_vector_valid = tensor_vector_exist and \ - len(stash.all_fp16_params_tensor_vector) == len(stash.all_fp16_params) and \ - len(stash.all_fp32_from_fp16_params_tensor_vector) == len(stash.all_fp32_from_fp16_params) + tensor_vector_exist = hasattr(stash, "all_fp16_params_tensor_vector") and hasattr( + stash, "all_fp32_from_fp16_params_tensor_vector" + ) + tensor_vector_valid = ( + tensor_vector_exist + and len(stash.all_fp16_params_tensor_vector) == len(stash.all_fp16_params) + and len(stash.all_fp32_from_fp16_params_tensor_vector) == len(stash.all_fp32_from_fp16_params) + ) if not tensor_vector_valid: stash.all_fp16_params_tensor_vector = fused_ops.TorchTensorVector(stash.all_fp16_params) - stash.all_fp32_from_fp16_params_tensor_vector = fused_ops.TorchTensorVector(stash.all_fp32_from_fp16_params) - - fused_ops.unscale_fp16_grads_into_fp32_grads(stash.all_fp16_params_tensor_vector, - stash.all_fp32_from_fp16_params_tensor_vector, - scaler._overflow_buf, - scaler._loss_scale) + stash.all_fp32_from_fp16_params_tensor_vector = fused_ops.TorchTensorVector( + stash.all_fp32_from_fp16_params + ) + + fused_ops.unscale_fp16_grads_into_fp32_grads( + stash.all_fp16_params_tensor_vector, + stash.all_fp32_from_fp16_params_tensor_vector, + scaler._overflow_buf, + scaler._loss_scale, + ) #### END OF THE FASTER IMPLEMENTATION #### # fp32 params can be treated as they would be in the "no_master_weights" case. apex_amp._process_optimizer.post_backward_models_are_masters( - scaler, - stash.all_fp32_from_fp32_params, - stash.all_fp32_from_fp32_grad_stash) + scaler, stash.all_fp32_from_fp32_params, stash.all_fp32_from_fp32_grad_stash + ) from apex.optimizers import FusedSGD as FusedSGD + if not isinstance(m_self._optimizer, FusedSGD): - m_self._optimizer._post_amp_backward = types.MethodType(post_backward_with_master_weights, m_self._optimizer) + m_self._optimizer._post_amp_backward = types.MethodType( + post_backward_with_master_weights, m_self._optimizer + ) # Implementation adapted from https://github.com/NVIDIA/apex/blob/082f999a6e18a3d02306e27482cc7486dab71a50/apex/amp/_process_optimizer.py#L367 def _zero_grad(self, set_to_none=True): diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index e59cf52808fb4..d9515041f5dfa 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -18,8 +18,10 @@ from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads from ._multi_tensor_apply import MultiTensorApply + multi_tensor_applier = MultiTensorApply(2048 * 32) + class DeepSpeedZeROModifier(FP16OptimizerModifier): def __init__(self, optimizer, **kwargs) -> None: super().__init__(optimizer) @@ -27,34 +29,39 @@ def __init__(self, optimizer, **kwargs) -> None: def can_be_modified(self): try: import deepspeed + v = LooseVersion(deepspeed.__version__) if v > LooseVersion("0.5.4") or v < LooseVersion("0.4.0"): - warnings.warn('Unsupported DeepSpeed version to override, skipped.', UserWarning) + warnings.warn("Unsupported DeepSpeed version to override, skipped.", UserWarning) return False except Exception as _: return False - return self.check_requirements(["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"], - require_apex=True, require_torch_non_finite_check=True) + return self.check_requirements( + ["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"], + require_apex=True, + require_torch_non_finite_check=True, + ) def override_function(self): - warnings.warn('DeepSpeed fp16_optimizer functions are overrided with faster implementation.', UserWarning) + warnings.warn("DeepSpeed fp16_optimizer functions are overrided with faster implementation.", UserWarning) + def get_grad_norm_direct(target, gradients, params, norm_type=2): import amp_C + def is_model_parallel_parameter(p): - return hasattr(p, 'model_parallel') and p.model_parallel + return hasattr(p, "model_parallel") and p.model_parallel norm_type = float(norm_type) if norm_type == inf: total_norm = max(g.data.abs().max() for g in gradients) total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=target.dp_process_group) + torch.distributed.all_reduce( + total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=target.dp_process_group + ) # Take max across all GPUs. - target._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.MAX) + target._model_parallel_all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.MAX) total_norm = total_norm_cuda[0].item() else: total_norm = 0.0 @@ -88,30 +95,25 @@ def is_model_parallel_parameter(p): # Multi-tensor applier takes a function and a list of list # and performs the operation on that list all in one kernel. grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False # no per-parameter norm + amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm ) # Since we will be summing across data parallel groups, # we need the pow(norm-type). - total_norm_cuda = grad_norm ** norm_type + total_norm_cuda = grad_norm**norm_type else: total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) #### END OF THE FASTER IMPLEMENTATION #### # Sum across all model parallel GPUs. - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.SUM, - group=target.dp_process_group) + torch.distributed.all_reduce( + total_norm_cuda, op=torch.distributed.ReduceOp.SUM, group=target.dp_process_group + ) - target._model_parallel_all_reduce(tensor=total_norm_cuda, - op=torch.distributed.ReduceOp.SUM) + target._model_parallel_all_reduce(tensor=total_norm_cuda, op=torch.distributed.ReduceOp.SUM) - total_norm = total_norm_cuda[0].item()**(1. / norm_type) + total_norm = total_norm_cuda[0].item() ** (1.0 / norm_type) - if total_norm == float( - 'inf') or total_norm == -float('inf') or total_norm != total_norm: + if total_norm == float("inf") or total_norm == -float("inf") or total_norm != total_norm: total_norm = -1 return total_norm @@ -149,4 +151,6 @@ def has_overflow_partitioned_grads_serial(target): self._optimizer.has_overflow_serial = types.MethodType(has_overflow_serial, self._optimizer) self._optimizer.get_grad_norm_direct = types.MethodType(get_grad_norm_direct, self._optimizer) # zero1 should not call into following function, is this a deepspeed bug? - self._optimizer.has_overflow_partitioned_grads_serial = types.MethodType(has_overflow_partitioned_grads_serial, self._optimizer) + self._optimizer.has_overflow_partitioned_grads_serial = types.MethodType( + has_overflow_partitioned_grads_serial, self._optimizer + ) diff --git a/orttraining/orttraining/python/training/optim/_megatron_modifier.py b/orttraining/orttraining/python/training/optim/_megatron_modifier.py index ec7115c990ff4..545c66961542f 100644 --- a/orttraining/orttraining/python/training/optim/_megatron_modifier.py +++ b/orttraining/orttraining/python/training/optim/_megatron_modifier.py @@ -14,6 +14,7 @@ from numpy import inf from ._modifier import FP16OptimizerModifier, check_overflow, clip_grad_norm_fp32 + class LegacyMegatronLMModifier(FP16OptimizerModifier): def __init__(self, optimizer, **kwargs) -> None: super().__init__(optimizer) @@ -21,11 +22,13 @@ def __init__(self, optimizer, **kwargs) -> None: self.get_horizontal_model_parallel_group = kwargs.get("get_horizontal_model_parallel_group", None) def can_be_modified(self): - return self.check_requirements(["_check_overflow", "clip_master_grads"], - require_apex=True, require_torch_non_finite_check=True) + return self.check_requirements( + ["_check_overflow", "clip_master_grads"], require_apex=True, require_torch_non_finite_check=True + ) def override_function(self): - warnings.warn('Megatron-LM fp16_optimizer functions are overrided with faster implementation.', UserWarning) + warnings.warn("Megatron-LM fp16_optimizer functions are overrided with faster implementation.", UserWarning) + def clip_master_grads(target, max_norm, norm_type=2): """ Clips fp32 master gradients via ``torch.nn.utils.clip_grad_norm``. @@ -44,16 +47,20 @@ def clip_master_grads(target, max_norm, norm_type=2): if not target.overflow: fp32_params = [] for param_group in target.optimizer.param_groups: - for param in param_group['params']: + for param in param_group["params"]: fp32_params.append(param) #### THIS IS THE ORIGINAL IMPLEMENTATION #### - #return self.clip_grad_norm(fp32_params, max_norm, norm_type) + # return self.clip_grad_norm(fp32_params, max_norm, norm_type) #### END OF THE ORIGINAL IMPLEMENTATION #### #### THIS IS THE FASTER IMPLEMENTATION #### - return clip_grad_norm_fp32(fp32_params, max_norm, norm_type, - get_horizontal_model_parallel_rank=self.get_horizontal_model_parallel_rank, - get_horizontal_model_parallel_group=self.get_horizontal_model_parallel_group) + return clip_grad_norm_fp32( + fp32_params, + max_norm, + norm_type, + get_horizontal_model_parallel_rank=self.get_horizontal_model_parallel_rank, + get_horizontal_model_parallel_group=self.get_horizontal_model_parallel_group, + ) #### END OF THE FASTER IMPLEMENTATION #### else: return -1 diff --git a/orttraining/orttraining/python/training/optim/_modifier.py b/orttraining/orttraining/python/training/optim/_modifier.py index 26c5c72a7c8e2..9897ed41210e6 100644 --- a/orttraining/orttraining/python/training/optim/_modifier.py +++ b/orttraining/orttraining/python/training/optim/_modifier.py @@ -11,8 +11,10 @@ import torch from numpy import inf from ._multi_tensor_apply import MultiTensorApply + multi_tensor_applier = MultiTensorApply(2048 * 32) + class FP16OptimizerModifier(object): def __init__(self, optimizer) -> None: super().__init__() @@ -39,9 +41,11 @@ def check_requirements(self, required_funcs, require_apex=False, require_torch_n return False return True + def check_overflow(params): grad_data = [p.grad.data for p in params if p.grad is not None] - return check_overflow_for_grads(grad_data) + return check_overflow_for_grads(grad_data) + def check_overflow_for_grads(grad_data): found_inf = torch.cuda.FloatTensor([0.0]) @@ -50,11 +54,13 @@ def check_overflow_for_grads(grad_data): torch._amp_foreach_non_finite_check_and_unscale_(grad_data, found_inf, scaler) # Check for nan. - overflow = (found_inf.item() > 0) - return overflow + overflow = found_inf.item() > 0 + return overflow -def clip_grad_norm_fp32(parameters, max_norm, norm_type, - get_horizontal_model_parallel_rank=None, get_horizontal_model_parallel_group=None): + +def clip_grad_norm_fp32( + parameters, max_norm, norm_type, get_horizontal_model_parallel_rank=None, get_horizontal_model_parallel_group=None +): import amp_C horizontal_model_parallel_grad_norm_aggregation = False @@ -62,7 +68,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type, horizontal_model_parallel_grad_norm_aggregation = True def param_is_not_tensor_parallel_duplicate(param): - is_mp_tensor = hasattr(param, 'model_parallel') and param.model_parallel + is_mp_tensor = hasattr(param, "model_parallel") and param.model_parallel return is_mp_tensor or (get_horizontal_model_parallel_rank() == 0) if isinstance(parameters, torch.Tensor): @@ -77,7 +83,7 @@ def param_is_not_tensor_parallel_duplicate(param): grad = param.grad.detach() if grad_not_none: # Make sure the grads are in fp32 - assert param.grad.type() == 'torch.cuda.FloatTensor' + assert param.grad.type() == "torch.cuda.FloatTensor" if horizontal_model_parallel_grad_norm_aggregation: is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param) if grad_not_none and is_not_tp_duplicate: @@ -98,9 +104,9 @@ def param_is_not_tensor_parallel_duplicate(param): total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)]) # Take max across all model-parallel GPUs. - torch.distributed.all_reduce(total_norm_cuda, - op=torch.distributed.ReduceOp.MAX, - group=get_horizontal_model_parallel_group()) + torch.distributed.all_reduce( + total_norm_cuda, op=torch.distributed.ReduceOp.MAX, group=get_horizontal_model_parallel_group() + ) total_norm = total_norm_cuda[0].item() else: @@ -109,10 +115,7 @@ def param_is_not_tensor_parallel_duplicate(param): # Multi-tensor applier takes a function and a list of list # and performs the operation on that list all in one kernel. grad_norm, _ = multi_tensor_applier( - amp_C.multi_tensor_l2norm, - dummy_overflow_buf, - [grads_for_norm], - False # no per-parameter norm + amp_C.multi_tensor_l2norm, dummy_overflow_buf, [grads_for_norm], False # no per-parameter norm ) if not horizontal_model_parallel_grad_norm_aggregation: @@ -120,28 +123,23 @@ def param_is_not_tensor_parallel_duplicate(param): # Since we will be summing across data parallel groups, # we need the pow(norm-type). - total_norm = grad_norm ** norm_type + total_norm = grad_norm**norm_type else: for grad in grads_for_norm: grad_norm = torch.norm(grad, norm_type) - total_norm += grad_norm ** norm_type + total_norm += grad_norm**norm_type if horizontal_model_parallel_grad_norm_aggregation: # Sum across all model-parallel GPUs. - torch.distributed.all_reduce(total_norm, - op=torch.distributed.ReduceOp.SUM, - group=get_horizontal_model_parallel_group()) + torch.distributed.all_reduce( + total_norm, op=torch.distributed.ReduceOp.SUM, group=get_horizontal_model_parallel_group() + ) total_norm = total_norm.item() ** (1.0 / norm_type) clip_coef = max_norm / (total_norm + 1e-6) # Filter parameters with gradients. grads = [p.grad for p in parameters if p.grad is not None] if clip_coef < 1.0: - multi_tensor_applier( - amp_C.multi_tensor_scale, - dummy_overflow_buf, - [grads, grads], - clip_coef) + multi_tensor_applier(amp_C.multi_tensor_scale, dummy_overflow_buf, [grads, grads], clip_coef) return total_norm - diff --git a/orttraining/orttraining/python/training/optim/_modifier_registry.py b/orttraining/orttraining/python/training/optim/_modifier_registry.py index f087661778f5c..142999f3f72c7 100644 --- a/orttraining/orttraining/python/training/optim/_modifier_registry.py +++ b/orttraining/orttraining/python/training/optim/_modifier_registry.py @@ -13,6 +13,6 @@ OptimizerModifierTypeRegistry = { LEAGCY_MEGATRON_LM_OPTIMIZER_NAME: LegacyMegatronLMModifier, - DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME : DeepSpeedZeROModifier, - APEX_AMP_OPTIMIZER_NAME : ApexAMPModifier, + DEEPSPEED_ZERO1_AND_ZERO2_OPTIMIZER_NAME: DeepSpeedZeROModifier, + APEX_AMP_OPTIMIZER_NAME: ApexAMPModifier, } diff --git a/orttraining/orttraining/python/training/optim/_multi_tensor_apply.py b/orttraining/orttraining/python/training/optim/_multi_tensor_apply.py index 9a5451427f154..d1e837eb1a350 100644 --- a/orttraining/orttraining/python/training/optim/_multi_tensor_apply.py +++ b/orttraining/orttraining/python/training/optim/_multi_tensor_apply.py @@ -3,12 +3,12 @@ # multi_tensor_apply.py # This file has been adapted from microsoft/DeepSpeed -''' +""" Copyright 2020 The Microsoft DeepSpeed Team Copyright NVIDIA/apex This file is adapted from NVIDIA/apex, commit a109f85 -''' +""" class MultiTensorApply(object): diff --git a/orttraining/orttraining/python/training/optim/config.py b/orttraining/orttraining/python/training/optim/config.py index 8e97cf8cc37d5..91ff7bed112e2 100644 --- a/orttraining/orttraining/python/training/optim/config.py +++ b/orttraining/orttraining/python/training/optim/config.py @@ -39,24 +39,30 @@ class _OptimizerConfig(object): def __init__(self, name, params, defaults): assert isinstance(name, str), "'name' must be a string" - assert name in ['AdamOptimizer', 'LambOptimizer', 'SGDOptimizer'], \ - "'name' must be one of 'AdamOptimizer', 'LambOptimizer' or 'SGDOptimizer'" - assert isinstance(defaults, - dict), "'defaults' must be a dict" - assert 'lr' in defaults, "'defaults' must contain a {'lr' : positive number} entry" - assert (isinstance(defaults['lr'], float) or - isinstance(defaults['lr'], int)) and defaults['lr'] >= 0, "lr must be a positive number" + assert name in [ + "AdamOptimizer", + "LambOptimizer", + "SGDOptimizer", + ], "'name' must be one of 'AdamOptimizer', 'LambOptimizer' or 'SGDOptimizer'" + assert isinstance(defaults, dict), "'defaults' must be a dict" + assert "lr" in defaults, "'defaults' must contain a {'lr' : positive number} entry" + assert (isinstance(defaults["lr"], float) or isinstance(defaults["lr"], int)) and defaults[ + "lr" + ] >= 0, "lr must be a positive number" assert isinstance(params, list), "'params' must be a list" for group in params: - assert isinstance(group, dict) and len(group) > 1 and 'params' in group, \ - ("Each dict inside 'params' must contain a {'params' : [model parameter names]} entry" - " and additional entries for custom hyper parameter values") + assert isinstance(group, dict) and len(group) > 1 and "params" in group, ( + "Each dict inside 'params' must contain a {'params' : [model parameter names]} entry" + " and additional entries for custom hyper parameter values" + ) for k, _ in group.items(): - if k != 'params': - assert k in defaults or k.replace("_coef", "") in defaults, f"'params' has {k} hyper parameter not present at 'defaults'" + if k != "params": + assert ( + k in defaults or k.replace("_coef", "") in defaults + ), f"'params' has {k} hyper parameter not present at 'defaults'" self.name = name - self.lr = float(defaults['lr']) + self.lr = float(defaults["lr"]) self.defaults = defaults self.params = [] @@ -106,9 +112,7 @@ class SGDConfig(_OptimizerConfig): """ def __init__(self, params=[], lr=0.001): - super().__init__(name='SGDOptimizer', - params=params, - defaults={'lr': lr}) + super().__init__(name="SGDOptimizer", params=params, defaults={"lr": lr}) assert isinstance(params, list) and len(params) == 0, "'params' must be an empty list for SGD optimizer" @@ -145,11 +149,21 @@ class AdamConfig(_OptimizerConfig): @unique class DecayMode(IntEnum): - BEFORE_WEIGHT_UPDATE = 0, + BEFORE_WEIGHT_UPDATE = (0,) AFTER_WEIGHT_UPDATE = 1 - def __init__(self, params=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef=0.0, epsilon=1e-8, max_norm_clip=1.0, - do_bias_correction=True, weight_decay_mode=DecayMode.BEFORE_WEIGHT_UPDATE): + def __init__( + self, + params=[], + lr=0.001, + alpha=0.9, + beta=0.999, + lambda_coef=0.0, + epsilon=1e-8, + max_norm_clip=1.0, + do_bias_correction=True, + weight_decay_mode=DecayMode.BEFORE_WEIGHT_UPDATE, + ): assert lr >= 0, "'lr' must be a positive number" assert alpha >= 0, "'alpha' must be a positive number" assert beta >= 0, "'beta' must be a positive number" @@ -159,19 +173,19 @@ def __init__(self, params=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef=0.0, assert isinstance(do_bias_correction, bool), "'do_bias_correction' must be a boolean" assert isinstance(weight_decay_mode, AdamConfig.DecayMode), "'weight_decay_mode' must be a AdamConfig.DecayMode" for param in params: - assert 'lr' not in param, "'lr' is not supported inside params" - - defaults = {'lr': lr, - 'alpha': alpha, - 'beta': beta, - 'lambda': lambda_coef, - 'epsilon': epsilon, - 'max_norm_clip': max_norm_clip, - 'do_bias_correction': do_bias_correction, - 'weight_decay_mode': weight_decay_mode} - super().__init__(name='AdamOptimizer', - params=params, - defaults=defaults) + assert "lr" not in param, "'lr' is not supported inside params" + + defaults = { + "lr": lr, + "alpha": alpha, + "beta": beta, + "lambda": lambda_coef, + "epsilon": epsilon, + "max_norm_clip": max_norm_clip, + "do_bias_correction": do_bias_correction, + "weight_decay_mode": weight_decay_mode, + } + super().__init__(name="AdamOptimizer", params=params, defaults=defaults) self.alpha = alpha self.beta = beta self.lambda_coef = lambda_coef @@ -213,8 +227,19 @@ class LambConfig(_OptimizerConfig): lamb_optim2 = LambConfig([{'params':['fc1.weight','fc2.weight'], 'lr':0.005}], lr=0.01) """ - def __init__(self, params=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef=0.0, - ratio_min=float('-inf'), ratio_max=float('inf'), epsilon=1e-6, max_norm_clip=1.0, do_bias_correction=False): + def __init__( + self, + params=[], + lr=0.001, + alpha=0.9, + beta=0.999, + lambda_coef=0.0, + ratio_min=float("-inf"), + ratio_max=float("inf"), + epsilon=1e-6, + max_norm_clip=1.0, + do_bias_correction=False, + ): assert lr >= 0, "'lr' must be a positive number" assert alpha >= 0, "'alpha' must be a positive number" assert beta >= 0, "'beta' must be a positive number" @@ -225,20 +250,20 @@ def __init__(self, params=[], lr=0.001, alpha=0.9, beta=0.999, lambda_coef=0.0, assert max_norm_clip != 0, "'max_norm_clip' must not be 0" assert isinstance(do_bias_correction, bool), "'do_bias_correction' must be a boolean" for param in params: - assert 'lr' not in param, "'lr' is not supported inside params" - - defaults = {'lr': lr, - 'alpha': alpha, - 'beta': beta, - 'lambda': lambda_coef, - 'ratio_min': ratio_min, - 'ratio_max': ratio_max, - 'epsilon': epsilon, - 'max_norm_clip': max_norm_clip, - 'do_bias_correction': do_bias_correction} - super().__init__(name='LambOptimizer', - params=params, - defaults=defaults) + assert "lr" not in param, "'lr' is not supported inside params" + + defaults = { + "lr": lr, + "alpha": alpha, + "beta": beta, + "lambda": lambda_coef, + "ratio_min": ratio_min, + "ratio_max": ratio_max, + "epsilon": epsilon, + "max_norm_clip": max_norm_clip, + "do_bias_correction": do_bias_correction, + } + super().__init__(name="LambOptimizer", params=params, defaults=defaults) self.alpha = alpha self.beta = beta self.lambda_coef = lambda_coef diff --git a/orttraining/orttraining/python/training/optim/fp16_optimizer.py b/orttraining/orttraining/python/training/optim/fp16_optimizer.py index 49d18b5992df6..c4c353249f1ee 100644 --- a/orttraining/orttraining/python/training/optim/fp16_optimizer.py +++ b/orttraining/orttraining/python/training/optim/fp16_optimizer.py @@ -5,6 +5,7 @@ from ._modifier_registry import OptimizerModifierTypeRegistry + def FP16_Optimizer(optimizer, **kwargs): """ Simple wrapper to replace inefficient FP16_Optimizer function calls implemented by libraries for example @@ -76,15 +77,16 @@ def FP16_Optimizer(optimizer, **kwargs): The modified FP16_Optimizer instance """ + def get_full_qualified_type_name(o): if hasattr(optimizer, "_amp_stash"): return "apex.amp.optimizer.unique_name_as_id" klass = o.__class__ module = klass.__module__ - if module == 'builtins': + if module == "builtins": return klass.__qualname__ - return module + '.' + klass.__qualname__ + return module + "." + klass.__qualname__ optimizer_full_qualified_name = get_full_qualified_type_name(optimizer) if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry: diff --git a/orttraining/orttraining/python/training/optim/fused_adam.py b/orttraining/orttraining/python/training/optim/fused_adam.py index e655468e41618..30ebcf30e4844 100644 --- a/orttraining/orttraining/python/training/optim/fused_adam.py +++ b/orttraining/orttraining/python/training/optim/fused_adam.py @@ -3,12 +3,12 @@ # fused_adam.py # This file has been adapted from microsoft/DeepSpeed -''' +""" Copyright 2020 The Microsoft DeepSpeed Team Copyright NVIDIA/apex This file is adapted from fused adam in NVIDIA/apex, commit a109f85 -''' +""" import torch from ._multi_tensor_apply import MultiTensorApply @@ -16,9 +16,9 @@ class AdamWMode(IntEnum): - ADAM_L2_REGULARIZATION = 0 # Adam with L2 regularization - ADAMW_TRANSFORMERS = 1 # Adam with weight decay implemented to be equivalent to Transformers/AdamW - ADAMW_TORCH = 2 # Adam with weight decay implemented to be equivalent to torch/AdamW + ADAM_L2_REGULARIZATION = 0 # Adam with L2 regularization + ADAMW_TRANSFORMERS = 1 # Adam with weight decay implemented to be equivalent to Transformers/AdamW + ADAMW_TORCH = 2 # Adam with weight decay implemented to be equivalent to torch/AdamW class FusedAdam(torch.optim.Optimizer): @@ -60,25 +60,23 @@ class FusedAdam(torch.optim.Optimizer): .. _On the Convergence of Adam and Beyond: https://openreview.net/forum?id=ryQu7f-RZ """ - def __init__(self, - params, - lr=1e-3, - bias_correction=True, - betas=(0.9, - 0.999), - eps=1e-6, - adam_w_mode=AdamWMode.ADAMW_TRANSFORMERS, - weight_decay=0., - set_grad_none=True): + + def __init__( + self, + params, + lr=1e-3, + bias_correction=True, + betas=(0.9, 0.999), + eps=1e-6, + adam_w_mode=AdamWMode.ADAMW_TRANSFORMERS, + weight_decay=0.0, + set_grad_none=True, + ): # The FusedAdam implementation is mathematically equivalent to # transformers AdamW. The input arguments also have the same defaults. - defaults = dict(lr=lr, - bias_correction=bias_correction, - betas=betas, - eps=eps, - weight_decay=weight_decay) + defaults = dict(lr=lr, bias_correction=bias_correction, betas=betas, eps=eps, weight_decay=weight_decay) super(FusedAdam, self).__init__(params, defaults) self._adam_w_mode = adam_w_mode self._set_grad_none = set_grad_none @@ -87,6 +85,7 @@ def __init__(self, self._dummy_overflow_buf = torch.cuda.IntTensor([0]) from onnxruntime.training.ortmodule.torch_cpp_extensions import fused_ops + self._multi_tensor_adam = fused_ops.multi_tensor_adam self._multi_tensor_applier = MultiTensorApply(2048 * 32) self._TorchTensorVector = fused_ops.TorchTensorVector @@ -94,7 +93,7 @@ def __init__(self, def zero_grad(self): if self._set_grad_none: for group in self.param_groups: - for p in group['params']: + for p in group["params"]: p.grad = None else: super(FusedAdam, self).zero_grad() @@ -113,78 +112,86 @@ def step(self, closure=None): loss = closure() for group in self.param_groups: - bias_correction = 1 if group['bias_correction'] else 0 - beta1, beta2 = group['betas'] + bias_correction = 1 if group["bias_correction"] else 0 + beta1, beta2 = group["betas"] # assume same step across group now to simplify things # per parameter step can be easily support by making it tensor, or pass list into kernel - if 'step' in group: - group['step'] += 1 + if "step" in group: + group["step"] += 1 else: - group['step'] = 1 + group["step"] = 1 # create lists for multi-tensor apply g_16, p_16, m_16, v_16 = [], [], [], [] g_32, p_32, m_32, v_32 = [], [], [], [] - for p in group['params']: + for p in group["params"]: if p.grad is None: continue if p.grad.data.is_sparse: raise RuntimeError( - 'FusedAdam does not support sparse gradients, please consider SparseAdam instead' + "FusedAdam does not support sparse gradients, please consider SparseAdam instead" ) state = self.state[p] # State initialization if len(state) == 0: # Exponential moving average of gradient values - state['exp_avg'] = torch.zeros_like(p.data) + state["exp_avg"] = torch.zeros_like(p.data) # Exponential moving average of squared gradient values - state['exp_avg_sq'] = torch.zeros_like(p.data) + state["exp_avg_sq"] = torch.zeros_like(p.data) if p.dtype == torch.float16: g_16.append(p.grad.data) p_16.append(p.data) - m_16.append(state['exp_avg']) - v_16.append(state['exp_avg_sq']) + m_16.append(state["exp_avg"]) + v_16.append(state["exp_avg_sq"]) elif p.dtype == torch.float32: g_32.append(p.grad.data) p_32.append(p.data) - m_32.append(state['exp_avg']) - v_32.append(state['exp_avg_sq']) + m_32.append(state["exp_avg"]) + v_32.append(state["exp_avg_sq"]) else: - raise RuntimeError('FusedAdam only support fp16 and fp32.') - - if (len(g_16) > 0): - self._multi_tensor_applier(self._multi_tensor_adam, - self._dummy_overflow_buf, - [self._TorchTensorVector(g_16), - self._TorchTensorVector(p_16), - self._TorchTensorVector(m_16), - self._TorchTensorVector(v_16)], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self._adam_w_mode, - bias_correction, - group['weight_decay']) - if (len(g_32) > 0): - self._multi_tensor_applier(self._multi_tensor_adam, - self._dummy_overflow_buf, - [self._TorchTensorVector(g_32), - self._TorchTensorVector(p_32), - self._TorchTensorVector(m_32), - self._TorchTensorVector(v_32)], - group['lr'], - beta1, - beta2, - group['eps'], - group['step'], - self._adam_w_mode, - bias_correction, - group['weight_decay']) + raise RuntimeError("FusedAdam only support fp16 and fp32.") + + if len(g_16) > 0: + self._multi_tensor_applier( + self._multi_tensor_adam, + self._dummy_overflow_buf, + [ + self._TorchTensorVector(g_16), + self._TorchTensorVector(p_16), + self._TorchTensorVector(m_16), + self._TorchTensorVector(v_16), + ], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self._adam_w_mode, + bias_correction, + group["weight_decay"], + ) + if len(g_32) > 0: + self._multi_tensor_applier( + self._multi_tensor_adam, + self._dummy_overflow_buf, + [ + self._TorchTensorVector(g_32), + self._TorchTensorVector(p_32), + self._TorchTensorVector(m_32), + self._TorchTensorVector(v_32), + ], + group["lr"], + beta1, + beta2, + group["eps"], + group["step"], + self._adam_w_mode, + bias_correction, + group["weight_decay"], + ) return loss diff --git a/orttraining/orttraining/python/training/optim/lr_scheduler.py b/orttraining/orttraining/python/training/optim/lr_scheduler.py index 15d37efbc5cc3..cbe013d32f310 100644 --- a/orttraining/orttraining/python/training/optim/lr_scheduler.py +++ b/orttraining/orttraining/python/training/optim/lr_scheduler.py @@ -43,7 +43,7 @@ def get_lr(self, train_step_info): raise NotImplementedError def get_last_lr(self): - r""" Return last computed learning rate by LR Scheduler""" + r"""Return last computed learning rate by LR Scheduler""" return self._last_lr @@ -82,12 +82,9 @@ class ConstantWarmupLRScheduler(_LRScheduler): def __init__(self, total_steps, warmup=0.002): super().__init__() - assert isinstance(total_steps, int) and total_steps > 0,\ - "total_steps must be a strict positive number" - assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\ - "warmup must be a float between (0, 1]" - assert total_steps > warmup,\ - "total_steps must be greater than warmup" + assert isinstance(total_steps, int) and total_steps > 0, "total_steps must be a strict positive number" + assert isinstance(warmup, float) and warmup >= 0 and warmup < 1, "warmup must be a float between (0, 1]" + assert total_steps > warmup, "total_steps must be greater than warmup" self.total_steps = total_steps self.warmup = warmup @@ -141,14 +138,10 @@ class CosineWarmupLRScheduler(_LRScheduler): def __init__(self, total_steps, cycles=0.5, warmup=0.002): super().__init__() - assert isinstance(total_steps, int) and total_steps > 0,\ - "total_steps must be a strict positive number" - assert isinstance(cycles, float) and cycles > 0,\ - "cycles must be a positive float" - assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\ - "warmup must be a float between (0, 1]" - assert total_steps > warmup,\ - "total_steps must be greater than warmup" + assert isinstance(total_steps, int) and total_steps > 0, "total_steps must be a strict positive number" + assert isinstance(cycles, float) and cycles > 0, "cycles must be a positive float" + assert isinstance(warmup, float) and warmup >= 0 and warmup < 1, "warmup must be a float between (0, 1]" + assert total_steps > warmup, "total_steps must be greater than warmup" self.total_steps = total_steps self.cycles = cycles @@ -158,7 +151,9 @@ def __init__(self, total_steps, cycles=0.5, warmup=0.002): def _warmup_cosine(self, train_step_info): if train_step_info.optimization_step < self._num_warmup_steps: return float(train_step_info.optimization_step) / float(max(1, self._num_warmup_steps)) - progress = float(train_step_info.optimization_step - self._num_warmup_steps) / float(max(1, self.total_steps - self._num_warmup_steps)) + progress = float(train_step_info.optimization_step - self._num_warmup_steps) / float( + max(1, self.total_steps - self._num_warmup_steps) + ) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(self.cycles) * 2.0 * progress))) def get_lr(self, train_step_info): @@ -201,12 +196,9 @@ class LinearWarmupLRScheduler(_LRScheduler): def __init__(self, total_steps, warmup=0.002): super().__init__() - assert isinstance(total_steps, int) and total_steps > 0,\ - "total_steps must be a strict positive number" - assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\ - "warmup must be a float between (0, 1]" - assert total_steps > warmup,\ - "total_steps must be greater than warmup" + assert isinstance(total_steps, int) and total_steps > 0, "total_steps must be a strict positive number" + assert isinstance(warmup, float) and warmup >= 0 and warmup < 1, "warmup must be a float between (0, 1]" + assert total_steps > warmup, "total_steps must be greater than warmup" self.total_steps = total_steps self.warmup = warmup @@ -215,7 +207,11 @@ def __init__(self, total_steps, warmup=0.002): def _warmup_linear(self, train_step_info): if train_step_info.optimization_step < self._num_warmup_steps: return float(train_step_info.optimization_step) / float(max(1, self._num_warmup_steps)) - return max(0.0, float(self.total_steps - train_step_info.optimization_step) / float(max(1, self.total_steps - self._num_warmup_steps))) + return max( + 0.0, + float(self.total_steps - train_step_info.optimization_step) + / float(max(1, self.total_steps - self._num_warmup_steps)), + ) def get_lr(self, train_step_info): return [train_step_info.optimizer_config.lr * self._warmup_linear(train_step_info)] @@ -264,16 +260,11 @@ class PolyWarmupLRScheduler(_LRScheduler): def __init__(self, total_steps, lr_end=1e-7, power=1.0, warmup=0.002): super().__init__() - assert isinstance(total_steps, int) and total_steps > 0,\ - "total_steps must be a strict positive number" - assert isinstance(lr_end, float) and lr_end >= 0,\ - "lr_end must be a positive float" - assert isinstance(warmup, float) and warmup >= 0 and warmup < 1,\ - "warmup must be a float between (0, 1]" - assert isinstance(power, float) and power >= 0,\ - "power must be a positive float" - assert total_steps > warmup,\ - "total_steps must be greater than warmup" + assert isinstance(total_steps, int) and total_steps > 0, "total_steps must be a strict positive number" + assert isinstance(lr_end, float) and lr_end >= 0, "lr_end must be a positive float" + assert isinstance(warmup, float) and warmup >= 0 and warmup < 1, "warmup must be a float between (0, 1]" + assert isinstance(power, float) and power >= 0, "power must be a positive float" + assert total_steps > warmup, "total_steps must be greater than warmup" self.total_steps = total_steps self.lr_end = lr_end @@ -283,8 +274,9 @@ def __init__(self, total_steps, lr_end=1e-7, power=1.0, warmup=0.002): def _warmup_poly(self, train_step_info): - assert train_step_info.optimizer_config.lr > self.lr_end,\ - f"lr_end ({lr_end}) must be be smaller than initial lr ({train_step_info.optimizer_config.lr})" + assert ( + train_step_info.optimizer_config.lr > self.lr_end + ), f"lr_end ({lr_end}) must be be smaller than initial lr ({train_step_info.optimizer_config.lr})" if train_step_info.optimization_step < self._num_warmup_steps: return float(train_step_info.optimization_step) / float(max(1, self._num_warmup_steps)) @@ -294,9 +286,8 @@ def _warmup_poly(self, train_step_info): lr_range = train_step_info.optimizer_config.lr - self.lr_end decay_steps = self.total_steps - self._num_warmup_steps pct_remaining = 1 - (train_step_info.optimization_step - self._num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining ** self.power + self.lr_end + decay = lr_range * pct_remaining**self.power + self.lr_end return decay / train_step_info.optimizer_config.lr - def get_lr(self, train_step_info): return [train_step_info.optimizer_config.lr * self._warmup_poly(train_step_info)] diff --git a/orttraining/orttraining/python/training/ortmodule/__init__.py b/orttraining/orttraining/python/training/ortmodule/__init__.py index 6e88cc63e9702..4e6c3a21f4bb4 100644 --- a/orttraining/orttraining/python/training/ortmodule/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/__init__.py @@ -11,10 +11,7 @@ from onnxruntime import set_seed from onnxruntime.capi import build_and_package_info as ort_info -from ._fallback import (_FallbackPolicy, - ORTModuleFallbackException, - ORTModuleInitException, - wrap_exception) +from ._fallback import _FallbackPolicy, ORTModuleFallbackException, ORTModuleInitException, wrap_exception from .torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed @@ -26,8 +23,7 @@ def _defined_from_envvar(name, default_value, warn=True): new_value = type(default_value)(new_value) except (TypeError, ValueError) as e: if warn: - warnings.warn( - "Unable to overwrite constant %r due to %r." % (name, e)) + warnings.warn("Unable to overwrite constant %r due to %r." % (name, e)) return default_value return new_value @@ -38,45 +34,55 @@ def _defined_from_envvar(name, default_value, warn=True): # assign them new values. Importing them directly do not propagate changes. ################################################################################ ONNX_OPSET_VERSION = 14 -MINIMUM_RUNTIME_PYTORCH_VERSION_STR = '1.8.1' -ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), 'torch_cpp_extensions') +MINIMUM_RUNTIME_PYTORCH_VERSION_STR = "1.8.1" +ORTMODULE_TORCH_CPP_DIR = os.path.join(os.path.dirname(__file__), "torch_cpp_extensions") _FALLBACK_INIT_EXCEPTION = None -ORTMODULE_FALLBACK_POLICY = _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE |\ - _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA |\ - _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL |\ - _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL +ORTMODULE_FALLBACK_POLICY = ( + _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE + | _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA + | _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL + | _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL +) ORTMODULE_FALLBACK_RETRY = False ORTMODULE_IS_DETERMINISTIC = torch.are_deterministic_algorithms_enabled() -ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr(ort_info, 'cuda_version') else None -ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr(ort_info, 'rocm_version') else None +ONNXRUNTIME_CUDA_VERSION = ort_info.cuda_version if hasattr(ort_info, "cuda_version") else None +ONNXRUNTIME_ROCM_VERSION = ort_info.rocm_version if hasattr(ort_info, "rocm_version") else None # Verify minimum PyTorch version is installed before proceding to ONNX Runtime initialization try: import torch - runtime_pytorch_version = version.parse(torch.__version__.split('+')[0]) + + runtime_pytorch_version = version.parse(torch.__version__.split("+")[0]) minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR) if runtime_pytorch_version < minimum_runtime_pytorch_version: - raise wrap_exception(ORTModuleInitException, - RuntimeError( - 'ONNX Runtime ORTModule frontend requires PyTorch version greater' - f' or equal to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR},' - f' but version {torch.__version__} was found instead.')) + raise wrap_exception( + ORTModuleInitException, + RuntimeError( + "ONNX Runtime ORTModule frontend requires PyTorch version greater" + f" or equal to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}," + f" but version {torch.__version__} was found instead." + ), + ) except ORTModuleFallbackException as e: # Initialization fallback is handled at ORTModule.__init__ _FALLBACK_INIT_EXCEPTION = e except ImportError as e: - raise RuntimeError(f'PyTorch {MINIMUM_RUNTIME_PYTORCH_VERSION_STR} must be ' - 'installed in order to run ONNX Runtime ORTModule frontend!') from e + raise RuntimeError( + f"PyTorch {MINIMUM_RUNTIME_PYTORCH_VERSION_STR} must be " + "installed in order to run ONNX Runtime ORTModule frontend!" + ) from e # Verify whether PyTorch C++ extensions are already compiled # TODO: detect when installed extensions are outdated and need reinstallation. Hash? Version file? -if not is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR) and '-m' not in sys.argv: +if not is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR) and "-m" not in sys.argv: _FALLBACK_INIT_EXCEPTION = wrap_exception( ORTModuleInitException, RuntimeError( f"ORTModule's extensions were not detected at '{ORTMODULE_TORCH_CPP_DIR}' folder. " - "Run `python -m torch_ort.configure` before using `ORTModule` frontend.")) + "Run `python -m torch_ort.configure` before using `ORTModule` frontend." + ), + ) # Initalized ORT's random seed with pytorch's initial seed # in case user has set pytorch seed before importing ORTModule diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py index fee541a1e48a9..6f3c7c7d29b66 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- + class Enabler(object): def __init__(self): self._state = False @@ -18,10 +19,15 @@ def state(self, val): custom_autograd_function_enabler = Enabler() + def enable_custom_autograd_support(): # Initialize static objects needed to run custom autograd.Function's. - from onnxruntime.capi._pybind_state import register_forward_runner, register_backward_runner, unregister_python_functions + from onnxruntime.capi._pybind_state import ( + register_forward_runner, + register_backward_runner, + unregister_python_functions, + ) from torch.onnx import register_custom_op_symbolic from ._custom_autograd_function_exporter import _export from ._custom_autograd_function_runner import call_python_forward_function, call_python_backward_function @@ -40,9 +46,9 @@ def enable_custom_autograd_support(): try: # This is for the latest Pytorch nightly after this commit: # https://github.com/pytorch/pytorch/commit/11bc435622e6b7207bbf37ed1aafe999e1f296ec - register_custom_op_symbolic('prim::PythonOp', _export, 1) + register_custom_op_symbolic("prim::PythonOp", _export, 1) except: # This applies to Pytorch 1.9 and 1.9.1. - register_custom_op_symbolic('::prim_PythonOp', _export, 1) + register_custom_op_symbolic("::prim_PythonOp", _export, 1) custom_autograd_function_enabler.state = True diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py index 5a3057e8ec91f..63af43ce48eb7 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_exporter.py @@ -19,23 +19,24 @@ # for big models such as GPT-2. Exporting CheckpointFunction as PythonOp means # every transformer would be computed by Pytorch and ORT doesn't contribute # at all. -BANNED_AUTOGRAD_FUNCTION_NAMES = set( - [torch.utils.checkpoint.CheckpointFunction.__name__]) +BANNED_AUTOGRAD_FUNCTION_NAMES = set([torch.utils.checkpoint.CheckpointFunction.__name__]) def _export_pt_1_10(g, n, *args, **kwargs): - ''' + """ This function exports PythonOp (input: "n") into a graph node in "g". "args" and "kwargs" are inputs to that PythonOp. A PythonOp represents a call to autograd.Function. - ''' + """ try: - name = kwargs['name'] + name = kwargs["name"] if name in BANNED_AUTOGRAD_FUNCTION_NAMES: - raise Exception(f'The autograd.Function {name} should not be exported to ONNX. ' - 'Please replace ORTModule with HierarchalORTModule to only' - 'wrap exportable sub-nn.Module\'s as ORTModule.') - inplace = kwargs['inplace'] + raise Exception( + f"The autograd.Function {name} should not be exported to ONNX. " + "Please replace ORTModule with HierarchalORTModule to only" + "wrap exportable sub-nn.Module's as ORTModule." + ) + inplace = kwargs["inplace"] training_mode = symbolic_helper._training_mode cconv = n.cconv() input_tensor_types = [] @@ -62,18 +63,17 @@ def _export_pt_1_10(g, n, *args, **kwargs): tensor_args = [] # Encode inputs to autograd.Function. for i, arg, call_type in zip(range(len(args)), args, cconv): - if call_type == 'd': + if call_type == "d": # Got a tensor variable. tensor_args.append(arg) requires_grad = 1 if arg.requires_grad() else 0 input_requires_grads.append(requires_grad) - scalar_type = int(symbolic_helper.cast_pytorch_to_onnx[arg.type( - ).scalarType()]) + scalar_type = int(symbolic_helper.cast_pytorch_to_onnx[arg.type().scalarType()]) input_tensor_types.append(scalar_type) input_tensor_ranks.append(arg.type().dim()) - elif call_type == 'c': + elif call_type == "c": # Got a non-tensor variable. # Non-tensor can't have gradient. input_requires_grads.append(0) @@ -96,27 +96,28 @@ def _export_pt_1_10(g, n, *args, **kwargs): elif all(isinstance(ele, float) for ele in arg): # A tuple of floats. input_float_tuple_positions.append(i) - input_float_tuple_begins.append( - len(input_float_tuples)) + input_float_tuple_begins.append(len(input_float_tuples)) input_float_tuples.extend(list(arg)) else: - raise wrap_exception(ORTModuleONNXModelException, - Exception(f'Unknown argument type found: {type(arg)}.')) + raise wrap_exception( + ORTModuleONNXModelException, Exception(f"Unknown argument type found: {type(arg)}.") + ) else: # All other inputs are accessed via "pointers". input_pointer_scalar_positions.append(i) input_pointer_scalars.append(id(arg)) else: - raise wrap_exception(ORTModuleONNXModelException, - Exception(f'Unknown calling convention found: {i}. Only \'d\' and \'c\' are supported')) + raise wrap_exception( + ORTModuleONNXModelException, + Exception(f"Unknown calling convention found: {i}. Only 'd' and 'c' are supported"), + ) output_tensor_types = [] output_tensor_ranks = [] output_tensor_requires_grads = [] for arg in n.outputs(): # Type of tensor's elements. - scalar_type = int(symbolic_helper.cast_pytorch_to_onnx[arg.type( - ).scalarType()]) + scalar_type = int(symbolic_helper.cast_pytorch_to_onnx[arg.type().scalarType()]) output_tensor_types.append(scalar_type) output_tensor_ranks.append(arg.type().dim()) # If output has gradient. @@ -125,36 +126,36 @@ def _export_pt_1_10(g, n, *args, **kwargs): # TODO: add fully-qualified name. attrs = { - 'name_s': name, - 'inplace_i': inplace, - 'input_convention_s': cconv, - 'outputs': n.outputsSize(), - 'input_tensor_types_i': input_tensor_types, - 'input_tensor_ranks_i': input_tensor_ranks, - 'input_requires_grads_i': input_requires_grads, - 'output_tensor_types_i': output_tensor_types, - 'output_tensor_ranks_i': output_tensor_ranks, - 'output_tensor_requires_grads_i': output_tensor_requires_grads, - 'training_mode_i': 1 if training_mode else 0 + "name_s": name, + "inplace_i": inplace, + "input_convention_s": cconv, + "outputs": n.outputsSize(), + "input_tensor_types_i": input_tensor_types, + "input_tensor_ranks_i": input_tensor_ranks, + "input_requires_grads_i": input_requires_grads, + "output_tensor_types_i": output_tensor_types, + "output_tensor_ranks_i": output_tensor_ranks, + "output_tensor_requires_grads_i": output_tensor_requires_grads, + "training_mode_i": 1 if training_mode else 0, } if len(input_int_scalars) > 0: - attrs['input_int_scalars_i'] = input_int_scalars - attrs['input_int_scalar_positions_i'] = input_int_scalar_positions + attrs["input_int_scalars_i"] = input_int_scalars + attrs["input_int_scalar_positions_i"] = input_int_scalar_positions if len(input_float_scalars) > 0: - attrs['input_float_scalars_f'] = input_float_scalars - attrs['input_float_scalar_positions_i'] = input_float_scalar_positions + attrs["input_float_scalars_f"] = input_float_scalars + attrs["input_float_scalar_positions_i"] = input_float_scalar_positions if len(input_int_tuples) > 0: - attrs['input_int_tuples_i'] = input_int_tuples - attrs['input_int_tuple_positions_i'] = input_int_tuple_positions - attrs['input_int_tuple_begins_i'] = input_int_tuple_begins + attrs["input_int_tuples_i"] = input_int_tuples + attrs["input_int_tuple_positions_i"] = input_int_tuple_positions + attrs["input_int_tuple_begins_i"] = input_int_tuple_begins if len(input_float_tuples) > 0: - attrs['input_float_tuples_f'] = input_float_tuples - attrs['input_float_tuple_positions_i'] = input_float_tuple_positions - attrs['input_float_tuple_begins_i'] = input_float_tuple_begins + attrs["input_float_tuples_f"] = input_float_tuples + attrs["input_float_tuple_positions_i"] = input_float_tuple_positions + attrs["input_float_tuple_begins_i"] = input_float_tuple_begins if len(input_pointer_scalars) > 0: - attrs['input_pointer_scalars_i'] = input_pointer_scalars - attrs['input_pointer_scalar_positions_i'] = input_pointer_scalar_positions + attrs["input_pointer_scalars_i"] = input_pointer_scalars + attrs["input_pointer_scalar_positions_i"] = input_pointer_scalar_positions returned_args = g.op("com.microsoft::PythonOp", *tensor_args, **attrs) @@ -164,33 +165,39 @@ def _export_pt_1_10(g, n, *args, **kwargs): sys.stderr.flush() raise wrap_exception(ORTModuleONNXModelException, e) + # Starting from PyTorch 1.11, there has been a change to symbolic function signature # in terms of how additional context is accessed. More info at # https://github.com/pytorch/pytorch/blob/6b02648479d3615fa3260961e24f38dd0f22da94/torch/onnx/symbolic_helper.py#L48 # This code can be cleaned up once support for PyTorch version < 1.11 is dropped. try: from torch.onnx import SymbolicContext + def _export(ctx: SymbolicContext, g, *args, **kwargs): n = ctx.cur_node return _export_pt_1_10(g, n, *args, **kwargs) + except ImportError: _export = _export_pt_1_10 + def _post_process_after_export(exported_model, enable_custom_autograd_function, log_level): if enable_custom_autograd_function: return _post_process_enabling_autograd_fallback(exported_model) is_pythonop_needed = False for node in exported_model.graph.node: - if node.domain == 'com.microsoft' and node.op_type in ["PythonOp"]: + if node.domain == "com.microsoft" and node.op_type in ["PythonOp"]: is_pythonop_needed = True break if is_pythonop_needed and log_level <= _logger.LogLevel.WARNING: - warnings.warn('Detected autograd functions usage in current model, the run will fail \ - without enabling \'_enable_custom_autograd_function\'. Please enable it with: \ - \'module._execution_manager(is_training_mode)._enable_custom_autograd_function = True\'', - UserWarning) + warnings.warn( + "Detected autograd functions usage in current model, the run will fail \ + without enabling '_enable_custom_autograd_function'. Please enable it with: \ + 'module._execution_manager(is_training_mode)._enable_custom_autograd_function = True'", + UserWarning, + ) return exported_model @@ -201,7 +208,7 @@ def _post_process_enabling_autograd_fallback(exported_model): # Collect mapping of class names to full qualified class names. if kclass.__name__ not in registered_name_mappings: registered_name_mappings[kclass.__name__] = [] - full_qualified_name = kclass.__module__ + '.' + kclass.__qualname__ + full_qualified_name = kclass.__module__ + "." + kclass.__qualname__ registered_name_mappings[kclass.__name__].append(full_qualified_name) # Register function with class names. @@ -209,21 +216,25 @@ def _post_process_enabling_autograd_fallback(exported_model): index = 0 for node in exported_model.graph.node: - if node.domain == 'com.microsoft' and node.op_type in ["PythonOp"]: + if node.domain == "com.microsoft" and node.op_type in ["PythonOp"]: output_names = list(node.output) del node.output[:] - node.output.append(output_names[0] + '_ctx') + node.output.append(output_names[0] + "_ctx") node.output.extend(output_names) for attr in node.attribute: - if attr.name == 'name': - kclass_name = attr.s.decode('utf-8') if isinstance(attr.s, bytes) else attr.s + if attr.name == "name": + kclass_name = attr.s.decode("utf-8") if isinstance(attr.s, bytes) else attr.s # If the duplicated function is used in ONNX graph, we will fail in case of a wrong function call. # Todo: remove this trick once exporter can support fully qualified name for PythonOp. if kclass_name in registered_name_mappings and len(registered_name_mappings[kclass_name]) > 1: - error_msg = 'More than one torch.autograd.Function named {}, but probabbly in different namespace. ' \ - 'The conflicting autograd.Functions are: {}. Currently torch exporter cannot ' \ - 'differentiate them with full qualified name, so there is a risk exported PythonOp calls a ' \ - 'wrong autograd.Function.'.format(kclass_name, ','.join(registered_name_mappings[kclass_name])) + error_msg = ( + "More than one torch.autograd.Function named {}, but probabbly in different namespace. " + "The conflicting autograd.Functions are: {}. Currently torch exporter cannot " + "differentiate them with full qualified name, so there is a risk exported PythonOp calls a " + "wrong autograd.Function.".format( + kclass_name, ",".join(registered_name_mappings[kclass_name]) + ) + ) raise wrap_exception(ORTModuleONNXModelException, RuntimeError(error_msg)) break diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index 940f29edbbaf2..acbfdb8b0db4d 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -12,8 +12,9 @@ from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils + def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_flag, arg): - ''' + """ If the input is a DLPack tensor, we wrap it as a torch.Tensor and set up its attributes according to other input flags. Otherwise, we return the input as is. @@ -25,7 +26,7 @@ def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_fl training_mode_flag: indicate if the top-level model is running under training (or inference) mode. arg: a DLPack tensor or a normal Python object (e.g, a tuple of ints). - ''' + """ if tensor_flag: # Got a tensor. Assume it's a DLPack tensor # and convert it to Pytorch tensor. @@ -41,14 +42,11 @@ def wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace_flag, training_mode_fl # Use non-tensor as is. It's a PyObject*. return arg + def call_python_forward_function( - forward_function, - requires_grad_flags, - tensor_type_flags, - is_training_mode, - inplace, - *args): - ''' + forward_function, requires_grad_flags, tensor_type_flags, is_training_mode, inplace, *args +): + """ This function bridges the gap between ORT variables and autograd.Function.apply. It conducts basic casting from ORT to Pytorch (before calling "forward_function") and from Pytorch to ORT (after calling "forward_function"). It also enable autograd in Pytorch. It formats returned outputs, @@ -64,7 +62,7 @@ def call_python_forward_function( is_training_mode: indicates if this model is running under training mode. inplace: indicates if args can be modified inside the custom function. args: inputs to "backward_function". - ''' + """ def generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, is_inplace): if is_training_mode and tensor_flag and grad_flag and is_inplace: @@ -86,7 +84,7 @@ def register_context(result): # (https://github.com/pytorch/pytorch/blob/15532595209d2daf34d35e10f8d3d3b64966aea2/torch/csrc/autograd/custom_function.cpp#L267) first_tensor_output = None for arg in result: - if not isinstance(arg, torch.Tensor) or not hasattr(arg, 'grad_fn'): + if not isinstance(arg, torch.Tensor) or not hasattr(arg, "grad_fn"): continue # Use the first context we see because all of arg's # share the same one. @@ -143,17 +141,23 @@ def register_context(result): # are DLPack tensors. return wrapped else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following model output type {type(result)}.')) + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) try: - wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) - for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args)) + wrapped_args = list( + wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) + for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args) + ) with torch.set_grad_enabled(is_training_mode): # Another level of wrap to avoid requires_grad=True for leaf variables. - new_wrapped_args = list(generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, inplace) - for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args)) + new_wrapped_args = list( + generate_non_leaf_or_not(grad_flag, tensor_flag, arg, is_training_mode, inplace) + for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, wrapped_args) + ) # Run autograd.Function.apply(...). result = forward_function(*new_wrapped_args) @@ -164,20 +168,16 @@ def register_context(result): return tuple(unwrapped_values) except Exception as e: # Flush buffers. Otherwise, calling this from C++ may lose them. - print('Exception happens when running ', forward_function) + print("Exception happens when running ", forward_function) sys.stdout.flush() sys.stderr.flush() raise wrap_exception(ORTModuleFallbackException, e) def call_python_backward_function( - backward_function, - requires_grad_flags, - tensor_type_flags, - is_training_mode, - inplace, - *args): - ''' + backward_function, requires_grad_flags, tensor_type_flags, is_training_mode, inplace, *args +): + """ This function bridges the gap between ORT variables and autograd.Function.backward. It conducts basic casting from ORT to Pytorch (before calling "backward_function") and from Pytorch to ORT (after calling "backward_function"). It formats returned @@ -190,24 +190,29 @@ def call_python_backward_function( is_training_mode: indicates if this model is running under training mode. inplace: indicates if args can be modified inside the custom function. args: inputs to "backward_function". - ''' + """ with torch.no_grad(): + def wrap_all_outputs(result): if isinstance(result, torch.Tensor): return [to_dlpack(result)] elif isinstance(result, tuple) or isinstance(result, list): return [to_dlpack(value) if value is not None else None for value in result] else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following model output type {type(result)}.')) + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(result)}."), + ) try: # Backward inputs should not require gradients. assert all(grad_flag == 0 for grad_flag in requires_grad_flags) # Prepare inputs for calling Python function. - wrapped_args = list(wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) - for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args)) + wrapped_args = list( + wrap_as_dlpack_or_not(grad_flag, tensor_flag, inplace, is_training_mode, arg) + for grad_flag, tensor_flag, arg in zip(requires_grad_flags, tensor_type_flags, args) + ) # Call Python function. result = backward_function(*wrapped_args) @@ -221,7 +226,7 @@ def wrap_all_outputs(result): return tuple(wrapped_returned_args) except Exception as e: # Flush buffers. Otherwise, calling this from C++ may lose them. - print('Exception happens when running ', backward_function) + print("Exception happens when running ", backward_function) sys.stdout.flush() sys.stderr.flush() raise wrap_exception(ORTModuleFallbackException, e) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 1e9492ef420f8..b0c441cfd4b80 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -34,7 +34,7 @@ def _to_gradient_definition(gradient): node_def = C.GradientNodeDefinition() if isinstance(node[0], str): node_def.op_type = node[0] - node_def.domain = '' + node_def.domain = "" else: node_def.op_type = node[0][0] node_def.domain = node[0][1] @@ -45,9 +45,9 @@ def _to_gradient_definition(gradient): for key, value in node[3].items(): attr_def = C.GradientNodeAttributeDefinition() attr_def.name = key - attr_def.value_json = json.dumps(value['value']) - attr_def.dtype = value['dtype'] - attr_def.is_tensor = value['is_tensor'] if 'is_tensor' in value else False + attr_def.value_json = json.dumps(value["value"]) + attr_def.dtype = value["dtype"] + attr_def.is_tensor = value["is_tensor"] if "is_tensor" in value else False attributes.append(attr_def) node_def.attributes = attributes node_defs.append(node_def) @@ -60,12 +60,12 @@ class CustomGradientRegistry: @classmethod def register(cls, domain, name, attributes, fn): - key = '::'.join([domain, name] + list(attributes)) + key = "::".join([domain, name] + list(attributes)) cls._GRADIENTS[key] = _to_gradient_definition(fn()) @classmethod def register_custom_stop_gradient_edges(cls, edges, domain, name, *attributes): - key = '::'.join([domain, name] + list(attributes)) + key = "::".join([domain, name] + list(attributes)) cls._STOP_GRADIENT_EDGES[key] = set(edges) @classmethod @@ -75,77 +75,116 @@ def register_all(cls): for key, value in cls._STOP_GRADIENT_EDGES.items(): C.register_custom_stop_gradient_edges(key, value) + def register_gradient(domain, name, *attributes): def gradient_wrapper(fn): CustomGradientRegistry.register(domain, name, attributes, fn) return fn + return gradient_wrapper # For ATen op, we need to provide op_name and overload name. -@register_gradient('org.pytorch.aten', 'ATen', 'aten::embedding', '') +@register_gradient("org.pytorch.aten", "ATen", "aten::embedding", "") def embedding_gradient(): return [ - ('Constant', [], ['Const_0'], {'value': {'value': 0, 'dtype': 'int', 'is_tensor': True}}), - ('Shape', ['I(0)'], ['Shape_X']), - ('Gather', ['Shape_X', 'Const_0'], ['Gather_X_0'], {'axis': {'value': 0, 'dtype': 'int'}}), - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'I(1)', 'Gather_X_0', 'I(2)', 'I(3)', 'I(4)'], [ - 'GI(0)'], {'operator': {'value': 'aten::embedding_backward', 'dtype': 'string'}}), + ("Constant", [], ["Const_0"], {"value": {"value": 0, "dtype": "int", "is_tensor": True}}), + ("Shape", ["I(0)"], ["Shape_X"]), + ("Gather", ["Shape_X", "Const_0"], ["Gather_X_0"], {"axis": {"value": 0, "dtype": "int"}}), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "I(1)", "Gather_X_0", "I(2)", "I(3)", "I(4)"], + ["GI(0)"], + {"operator": {"value": "aten::embedding_backward", "dtype": "string"}}, + ), ] -@register_gradient('org.pytorch.aten', 'ATen', 'aten::diagonal', '') + +@register_gradient("org.pytorch.aten", "ATen", "aten::diagonal", "") def diagonal_gradient(): return [ - ('Shape', ['I(0)'], ['Shape_X']), - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'Shape_X', 'I(1)', 'I(2)', 'I(3)'], [ - 'GI(0)'], {'operator': {'value': 'aten::diagonal_backward', 'dtype': 'string'}}), + ("Shape", ["I(0)"], ["Shape_X"]), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "Shape_X", "I(1)", "I(2)", "I(3)"], + ["GI(0)"], + {"operator": {"value": "aten::diagonal_backward", "dtype": "string"}}, + ), ] -@register_gradient('org.pytorch.aten', 'ATen', 'aten::max_pool2d_with_indices', '') + +@register_gradient("org.pytorch.aten", "ATen", "aten::max_pool2d_with_indices", "") def max_pool2d_gradient(): return [ - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'I(0)', 'I(1)', 'I(2)', 'I(3)', 'I(4)', 'I(5)', 'O(1)'], [ - 'GI(0)'], {'operator': {'value': 'aten::max_pool2d_with_indices_backward', 'dtype': 'string'}}), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)", "I(5)", "O(1)"], + ["GI(0)"], + {"operator": {"value": "aten::max_pool2d_with_indices_backward", "dtype": "string"}}, + ), ] -@register_gradient('org.pytorch.aten', 'ATen', 'aten::unfold', '') +@register_gradient("org.pytorch.aten", "ATen", "aten::unfold", "") def unfold_gradient(): return [ - ('Shape', ['I(0)'], ['Shape_X']), - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'Shape_X', 'I(1)', 'I(2)', 'I(3)'], [ - 'GI(0)'], {'operator': {'value': 'aten::unfold_backward', 'dtype': 'string'}}), + ("Shape", ["I(0)"], ["Shape_X"]), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "Shape_X", "I(1)", "I(2)", "I(3)"], + ["GI(0)"], + {"operator": {"value": "aten::unfold_backward", "dtype": "string"}}, + ), ] -@register_gradient('org.pytorch.aten', 'ATen', 'aten::avg_pool2d', '') +@register_gradient("org.pytorch.aten", "ATen", "aten::avg_pool2d", "") def avg_pool2d_gradient(): return [ - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'I(0)', 'I(1)', 'I(2)', 'I(3)', 'I(4)', 'I(5)', 'I(6)'], [ - 'GI(0)'], {'operator': {'value': 'aten::avg_pool2d_backward', 'dtype': 'string'}}), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)", "I(5)", "I(6)"], + ["GI(0)"], + {"operator": {"value": "aten::avg_pool2d_backward", "dtype": "string"}}, + ), ] -@register_gradient('org.pytorch.aten', 'ATen', 'aten::_adaptive_avg_pool2d', '') +@register_gradient("org.pytorch.aten", "ATen", "aten::_adaptive_avg_pool2d", "") def adaptive_avg_pool2d_gradient(): return [ - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'I(0)'], [ - 'GI(0)'], {'operator': {'value': 'aten::_adaptive_avg_pool2d_backward', 'dtype': 'string'}}), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "I(0)"], + ["GI(0)"], + {"operator": {"value": "aten::_adaptive_avg_pool2d_backward", "dtype": "string"}}, + ), ] -CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'org.pytorch.aten', 'ATen', 'aten::argmax', '') -CustomGradientRegistry.register_custom_stop_gradient_edges([0], 'org.pytorch.aten', 'ATen', 'aten::multinomial', '') -@register_gradient('org.pytorch.aten', 'ATen', 'aten::binary_cross_entropy_with_logits', '') +CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "aten::argmax", "") +CustomGradientRegistry.register_custom_stop_gradient_edges([0], "org.pytorch.aten", "ATen", "aten::multinomial", "") + + +@register_gradient("org.pytorch.aten", "ATen", "aten::binary_cross_entropy_with_logits", "") def binary_cross_entropy_with_logits_gradient(): return [ - (('ATen', 'org.pytorch.aten'), ['GO(0)', 'I(0)', 'I(1)', 'I(2)', 'I(3)', 'I(4)'], [ - 'GI(0)'], {'operator': {'value': 'aten::binary_cross_entropy_with_logits_backward', 'dtype': 'string'}}), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)", "I(0)", "I(1)", "I(2)", "I(3)", "I(4)"], + ["GI(0)"], + {"operator": {"value": "aten::binary_cross_entropy_with_logits_backward", "dtype": "string"}}, + ), ] -@register_gradient('org.pytorch.aten', 'ATen', 'aten::numpy_T', '') + +@register_gradient("org.pytorch.aten", "ATen", "aten::numpy_T", "") def numpy_T_gradient(): return [ - (('ATen', 'org.pytorch.aten'), ['GO(0)'], [ - 'GI(0)'], {'operator': {'value': 'aten::numpy_T', 'dtype': 'string'}}), + ( + ("ATen", "org.pytorch.aten"), + ["GO(0)"], + ["GI(0)"], + {"operator": {"value": "aten::numpy_T", "dtype": "string"}}, + ), ] diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 725425e38e949..2cefdf24f6fc2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -14,7 +14,7 @@ class CustomOpSymbolicRegistry: @classmethod def register(cls, name, domain, fn): - cls._SYMBOLICS[domain + '::' + name] = fn + cls._SYMBOLICS[domain + "::" + name] = fn @classmethod def register_all(cls): @@ -23,120 +23,154 @@ def register_all(cls): register_custom_op_symbolic(name, fn, 1) -def register_symbolic(name, domain=''): +def register_symbolic(name, domain=""): def symbolic_wrapper(fn): CustomOpSymbolicRegistry.register(name, domain, fn) return fn + return symbolic_wrapper -@register_symbolic('cross_entropy_loss') -@parse_args('v', 'v', 'v', 'i', 'v', 'v') +@register_symbolic("cross_entropy_loss") +@parse_args("v", "v", "v", "i", "v", "v") def cross_entropy_loss(g, self, target, weight, reduction, ignore_index, label_smoothing=0.0): label_smoothing = sym_help._maybe_get_const(label_smoothing, "f") if label_smoothing > 0.0: raise RuntimeError("Unsupported: ONNX does not support label_smoothing") # reduction: 0->none, 1->mean, 2->sum - reduction = sym_help._maybe_get_const(reduction, 'i') - reduction_vals = ['none', 'mean', 'sum'] + reduction = sym_help._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] reduction = reduction_vals[reduction] - output, log_prob = g.op("com.microsoft::SoftmaxCrossEntropyLossInternal", - self, target, weight, ignore_index, - reduction_s=reduction, outputs=2) + output, log_prob = g.op( + "com.microsoft::SoftmaxCrossEntropyLossInternal", + self, + target, + weight, + ignore_index, + reduction_s=reduction, + outputs=2, + ) output.setType(self.type()) log_prob.setType(self.type()) return output -@register_symbolic('nll_loss') -@parse_args('v', 'v', 'v', 'i', 'v') +@register_symbolic("nll_loss") +@parse_args("v", "v", "v", "i", "v") def nll_loss(g, self, target, weight, reduction, ignore_index): # reduction: 0->none, 1->mean, 2->sum - reduction = sym_help._maybe_get_const(reduction, 'i') - reduction_vals = ['none', 'mean', 'sum'] + reduction = sym_help._maybe_get_const(reduction, "i") + reduction_vals = ["none", "mean", "sum"] reduction = reduction_vals[reduction] - output = g.op("com.microsoft::NegativeLogLikelihoodLossInternal", - self, target, weight, ignore_index, reduction_s=reduction) + output = g.op( + "com.microsoft::NegativeLogLikelihoodLossInternal", self, target, weight, ignore_index, reduction_s=reduction + ) output.setType(self.type()) return output -@register_symbolic('embedding') +@register_symbolic("embedding") def embedding(g, weight, indices, padding_idx, scale_grad_by_freq, sparse): - output = g.op("org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, - operator_s='aten::embedding') + output = g.op( + "org.pytorch.aten::ATen", weight, indices, padding_idx, scale_grad_by_freq, sparse, operator_s="aten::embedding" + ) indices_shape = _get_tensor_sizes(indices) - if indices_shape is not None and hasattr(weight.type(), 'with_sizes'): - output_type = weight.type().with_sizes( - indices_shape + [_get_tensor_dim_size(weight, 1)]) + if indices_shape is not None and hasattr(weight.type(), "with_sizes"): + output_type = weight.type().with_sizes(indices_shape + [_get_tensor_dim_size(weight, 1)]) output.setType(output_type) return output -@register_symbolic('bitwise_or') + +@register_symbolic("bitwise_or") def bitwise_or(g, self, other): - return g.op("org.pytorch.aten::ATen", self, other, - operator_s='aten::bitwise_or', overload_name_s='Tensor') + return g.op("org.pytorch.aten::ATen", self, other, operator_s="aten::bitwise_or", overload_name_s="Tensor") + -@register_symbolic('diagonal') +@register_symbolic("diagonal") def diagonal(g, self, offset, dim1, dim2): - return g.op("org.pytorch.aten::ATen", self, offset, dim1, dim2, - operator_s='aten::diagonal') + return g.op("org.pytorch.aten::ATen", self, offset, dim1, dim2, operator_s="aten::diagonal") -@register_symbolic('multinomial') +@register_symbolic("multinomial") def multinomial(g, self, num_samples, replacement=False, generator=None): if generator is not None and not sym_help._is_none(generator): raise RuntimeError("Unsupported: ONNX does not support generator for multinomial") - return g.op("org.pytorch.aten::ATen", self, num_samples, replacement, generator, - operator_s='aten::multinomial') + return g.op("org.pytorch.aten::ATen", self, num_samples, replacement, generator, operator_s="aten::multinomial") -@register_symbolic('max_pool2d') +@register_symbolic("max_pool2d") def max_pool2d(g, self, kernel_size, stride, padding, dilation, ceil_mode): - stride_val = sym_help._maybe_get_const(stride, 'is') + stride_val = sym_help._maybe_get_const(stride, "is") if not stride_val: stride = kernel_size - return g.op("org.pytorch.aten::ATen", self, kernel_size, stride, padding, dilation, ceil_mode, - operator_s='aten::max_pool2d_with_indices', outputs=2)[0] - - -@register_symbolic('unfold') + return g.op( + "org.pytorch.aten::ATen", + self, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + operator_s="aten::max_pool2d_with_indices", + outputs=2, + )[0] + + +@register_symbolic("unfold") def unfold(g, input, dimension, size, step): - return g.op("org.pytorch.aten::ATen", input, dimension, size, step, operator_s='aten::unfold') + return g.op("org.pytorch.aten::ATen", input, dimension, size, step, operator_s="aten::unfold") -@register_symbolic('argmax') +@register_symbolic("argmax") def argmax(g, input, dim, keepdim): - return g.op("org.pytorch.aten::ATen", input, dim, keepdim, operator_s='aten::argmax') + return g.op("org.pytorch.aten::ATen", input, dim, keepdim, operator_s="aten::argmax") -@register_symbolic('avg_pool2d') +@register_symbolic("avg_pool2d") def avg_pool2d(g, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override): - stride_val = sym_help._maybe_get_const(stride, 'is') + stride_val = sym_help._maybe_get_const(stride, "is") if not stride_val: stride = kernel_size - return g.op("org.pytorch.aten::ATen", self, kernel_size, stride, padding, ceil_mode, - count_include_pad, divisor_override, operator_s='aten::avg_pool2d') - - -@register_symbolic('adaptive_avg_pool2d') + return g.op( + "org.pytorch.aten::ATen", + self, + kernel_size, + stride, + padding, + ceil_mode, + count_include_pad, + divisor_override, + operator_s="aten::avg_pool2d", + ) + + +@register_symbolic("adaptive_avg_pool2d") def adaptive_avg_pool2d(g, self, output_size): - return g.op("org.pytorch.aten::ATen", self, output_size, operator_s='aten::_adaptive_avg_pool2d') + return g.op("org.pytorch.aten::ATen", self, output_size, operator_s="aten::_adaptive_avg_pool2d") -@register_symbolic('binary_cross_entropy_with_logits') +@register_symbolic("binary_cross_entropy_with_logits") def binary_cross_entropy_with_logits(g, self, target, weight, pos_weight, reduction): # If weight is not None, we need to check if it requires grad and add gradient graph accordingly. # But current custom_gradient_registry doesn't support such None checking, # So doesn't support non-None weight for now. if weight is None or sym_help._is_none(weight): - return g.op("org.pytorch.aten::ATen", self, target, weight, pos_weight, reduction, - operator_s='aten::binary_cross_entropy_with_logits') + return g.op( + "org.pytorch.aten::ATen", + self, + target, + weight, + pos_weight, + reduction, + operator_s="aten::binary_cross_entropy_with_logits", + ) from torch.onnx.symbolic_opset12 import binary_cross_entropy_with_logits as bce + return bce(g, self, target, weight, pos_weight, reduction) -@register_symbolic('numpy_T') + +@register_symbolic("numpy_T") def numpy_T(g, self): # Numpy-style `a.T`: returns the tensor # with dims reversed @@ -147,28 +181,29 @@ def numpy_T(g, self): else: # if we don't have dim information we cannot # output a permute so use ATen instead - return g.op("com.microsoft::ATenOp", self, name_s='aten::numpy_T') + return g.op("com.microsoft::ATenOp", self, name_s="aten::numpy_T") -@register_symbolic('squeeze') +@register_symbolic("squeeze") def squeeze(g, self, dim=None): # Current _infer_If does not correctly infer shapes from its then- and else- branches, and will # cause error in shape inference of following nodes, here we choose to export it as `Squeeze.` from torch.onnx.symbolic_opset11 import squeeze as squeeze_with_if + if dim is None: return squeeze_with_if(g, self, dim) - squeeze_dim = sym_help._get_const(dim, 'i', 'dim') + squeeze_dim = sym_help._get_const(dim, "i", "dim") return sym_help._squeeze_helper(g, self, axes_i=[squeeze_dim]) # For torch.einsum. def parse_equation(equation): - pos_comma = equation.find(',') - pos_arrow = equation.find('->') + pos_comma = equation.find(",") + pos_arrow = equation.find("->") assert pos_comma != -1 and pos_arrow > pos_comma - lhs_labels = [label for label in equation[:pos_comma] if label != ' '] - rhs_labels = [label for label in equation[pos_comma + 1:pos_arrow] if label != ' '] - result_labels = [label for label in equation[pos_arrow + 2:] if label != ' '] + lhs_labels = [label for label in equation[:pos_comma] if label != " "] + rhs_labels = [label for label in equation[pos_comma + 1 : pos_arrow] if label != " "] + result_labels = [label for label in equation[pos_arrow + 2 :] if label != " "] # Two operands and result are not empty, and are all alpha characters. assert lhs_labels and rhs_labels and result_labels assert all(label.isalpha() for label in lhs_labels + rhs_labels + result_labels) @@ -177,9 +212,11 @@ def parse_equation(equation): assert all(label in lhs_labels or label in rhs_labels for label in result_labels) return lhs_labels, rhs_labels, result_labels + def need_permute(perm): return any(idx != axis for idx, axis in enumerate(perm)) + def map_labels_to_output(input_labels, label_perm_map): output_len = len(label_perm_map) perm = [-1] * output_len @@ -199,6 +236,7 @@ def map_labels_to_output(input_labels, label_perm_map): return perm, unsqueeze_axes + def unsqueeze_and_permute_for_mul(g, tensor, unsqueeze_axes, perm): # If perm is sorted after removing unsqueeze axes, then permute is not needed. # For example, a.unsqueeze(2).permute([0, 2, 1]) is same as a.unsqueeze(1). @@ -214,6 +252,7 @@ def unsqueeze_and_permute_for_mul(g, tensor, unsqueeze_axes, perm): tensor = g.op("Transpose", tensor, perm_i=perm) return tensor + def combine_unsqueeze_and_permute_for_matmul(unsqueeze_axes, perm1, perm2): # When going here, the unsqueeze axes must be some axes at the end. # We can combine two permutes and remove unsqueeze axes, because we will reshape it after this. @@ -223,25 +262,42 @@ def combine_unsqueeze_and_permute_for_matmul(unsqueeze_axes, perm1, perm2): new_perm = [axis for axis in new_perm if axis not in unsqueeze_axes] return new_perm + def is_axes_contiguous(axes): return len(axes) < 2 or all(axes[axis] + 1 == axes[axis + 1] for axis in range(len(axes) - 1)) + def get_shape_tensor_by_axes(g, input, input_shape, axes, need_numel_shape): if input_shape is None: input_shape = g.op("Shape", input) - shape_tensor = g.op("Gather", input_shape, g.op("Constant", value_t=torch.tensor(axes, dtype=torch.int64)), axis_i=0) + shape_tensor = g.op( + "Gather", input_shape, g.op("Constant", value_t=torch.tensor(axes, dtype=torch.int64)), axis_i=0 + ) numel_shape_tensor = None if need_numel_shape: assert len(axes) > 1 numel_shape_tensor = g.op("ReduceProd", shape_tensor) return shape_tensor, numel_shape_tensor, input_shape + def reshape_tensor(g, input, shape_tensors): shape_tensor = g.op("Concat", *shape_tensors, axis_i=0) if len(shape_tensors) > 1 else shape_tensors[0] return g.op("Reshape", input, shape_tensor) -def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes, contraction_axes, - batch_length, matmul_output_numel_tensor, contraction_numel_tensor, shape_tensor): + +def permute_and_reshape_tensor( + g, + tensor, + is_lhs, + rank, + perm, + matmul_output_axes, + contraction_axes, + batch_length, + matmul_output_numel_tensor, + contraction_numel_tensor, + shape_tensor, +): # If matmul_output_axes and contraction_axes are contiguous in input tensor, # we can move Reshape to before Transpose, so it's possible that the Transpoase is fused to MatMul. # Otherwise, we have to Transpose first to move those axes together and then Reshape. @@ -263,7 +319,7 @@ def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes else: new_tensor = sym_help._unsqueeze_helper(g, tensor, [batch_length if is_lhs else -1]) else: - axes_to_remove = contraction_axes[1:] # contraction_axes can't be empty. + axes_to_remove = contraction_axes[1:] # contraction_axes can't be empty. if len(matmul_output_axes) > 1: axes_to_remove = axes_to_remove + matmul_output_axes[1:] remaining_axes = [axis for axis in range(rank) if axis not in axes_to_remove] @@ -287,7 +343,8 @@ def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes has_neg_one_dim = True else: single_axis_shape_tensor, _, shape_tensor = get_shape_tensor_by_axes( - g, tensor, shape_tensor, [axis], False) + g, tensor, shape_tensor, [axis], False + ) shape_tensors.append(single_axis_shape_tensor) if not has_neg_one_dim and last_zero_dim >= 0: shape_tensors[last_zero_dim] = g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64)) @@ -312,19 +369,24 @@ def permute_and_reshape_tensor(g, tensor, is_lhs, rank, perm, matmul_output_axes shape_tensors = [g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64))] * batch_length if is_lhs: if matmul_output_numel_tensor is None: - matmul_output_numel_tensor = g.op("Constant", value_t=torch.tensor([1 - len(matmul_output_axes)], dtype=torch.int64)) + matmul_output_numel_tensor = g.op( + "Constant", value_t=torch.tensor([1 - len(matmul_output_axes)], dtype=torch.int64) + ) shape_tensors.append(matmul_output_numel_tensor) shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) else: - if contraction_numel_tensor is None: # contraction_axes can't be empty, None here means only one contraction axis. + if ( + contraction_numel_tensor is None + ): # contraction_axes can't be empty, None here means only one contraction axis. contraction_numel_tensor = g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)) shape_tensors.append(contraction_numel_tensor) shape_tensors.append(g.op("Constant", value_t=torch.tensor([-1], dtype=torch.int64))) new_tensor = reshape_tensor(g, new_tensor, shape_tensors) return new_tensor, shape_tensor -@register_symbolic('einsum') -@parse_args('s', 'v') + +@register_symbolic("einsum") +@parse_args("s", "v") def einsum(g, equation, tensor_list): tensors = sym_help._unpack_list(tensor_list) num_ops = len(tensors) @@ -332,7 +394,7 @@ def einsum(g, equation, tensor_list): # Doesn't support implicit output is ellipsis or more than 2 oprands for now. # Doesn't support ellipsis ('...') for now as not easy to get sizes of oprands. - if num_ops != 2 or equation.find('->') == -1 or '.' in equation: + if num_ops != 2 or equation.find("->") == -1 or "." in equation: return g.op("Einsum", *tensors, equation_s=equation) # Take "ks,ksm->sm" as example. After prcoess inputs, @@ -409,8 +471,12 @@ def einsum(g, equation, tensor_list): # i.e., a.unsqueeze([2]).permute([1,2,0]).permute([0,1,2]) = [s,1,k] for the example. # For rhs input, the new permute is batched_axes + contraction_axes + matmul_output_axes: [0, 2, 1]. # i.e., b.unsqueeze([]).permute([1,2,0]).permute([0,2,1]) = [s,k,m] for the example. - lhs_perm = combine_unsqueeze_and_permute_for_matmul(lhs_unsqueeze_axes, lhs_perm, batched_axes + matmul_output_axes + contraction_axes) - rhs_perm = combine_unsqueeze_and_permute_for_matmul(rhs_unsqueeze_axes, rhs_perm, batched_axes + contraction_axes + matmul_output_axes) + lhs_perm = combine_unsqueeze_and_permute_for_matmul( + lhs_unsqueeze_axes, lhs_perm, batched_axes + matmul_output_axes + contraction_axes + ) + rhs_perm = combine_unsqueeze_and_permute_for_matmul( + rhs_unsqueeze_axes, rhs_perm, batched_axes + contraction_axes + matmul_output_axes + ) # Need to Reshape two input tensors before the BatchedMatMul and Reshape result to output shape. # Reshape lhs input to [[batched_shapes], Mul(lhs_matmul_output_shapes), Mul(contraction_shapes)]. @@ -419,8 +485,12 @@ def einsum(g, equation, tensor_list): # lhs_contraction_axes = [0], rhs_contraction_axes = [0], lhs_matmul_output_axes = [], rhs_matmul_output_axes = [2] for the example. lhs_contraction_axes = [lhs_labels.index(label) for label in contraction_labels] rhs_contraction_axes = [rhs_labels.index(label) for label in contraction_labels] - lhs_matmul_output_axes = [lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in lhs_labels] - rhs_matmul_output_axes = [rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in rhs_labels] + lhs_matmul_output_axes = [ + lhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in lhs_labels + ] + rhs_matmul_output_axes = [ + rhs_labels.index(result_labels[axis]) for axis in matmul_output_axes if result_labels[axis] in rhs_labels + ] # Caches of input shape tensors to avoid generating duplicated graph. lhs_shape_tensor = None @@ -430,7 +500,8 @@ def einsum(g, equation, tensor_list): contraction_numel_tensor = None if len(lhs_contraction_axes) > 1: _, contraction_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes( - g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True) + g, lhs_tensor, lhs_shape_tensor, lhs_contraction_axes, True + ) # Prepare some shape tensors for Reshape if needed. # Both lhs_matmul_output_shape_tensor and lhs_matmul_output_numel_tensor is None for the example. @@ -438,22 +509,34 @@ def einsum(g, equation, tensor_list): lhs_matmul_output_numel_tensor = None if len(lhs_matmul_output_axes) > 1: lhs_matmul_output_shape_tensor, lhs_matmul_output_numel_tensor, lhs_shape_tensor = get_shape_tensor_by_axes( - g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True) + g, lhs_tensor, lhs_shape_tensor, lhs_matmul_output_axes, True + ) # Both rhs_matmul_output_shape_tensor and rhs_matmul_output_numel_tensor is None for the example. rhs_matmul_output_shape_tensor = None rhs_matmul_output_numel_tensor = None if len(rhs_matmul_output_axes) > 1: rhs_matmul_output_shape_tensor, rhs_matmul_output_numel_tensor, rhs_shape_tensor = get_shape_tensor_by_axes( - g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True) + g, rhs_tensor, rhs_shape_tensor, rhs_matmul_output_axes, True + ) new_lhs_tensor = lhs_tensor # Need to Reshape lhs_tensor if lhs_matmul_output_axes or lhs_contraction_axes is not 1, otherwise permute it directly. # Need to Reshape the lhs_tensor for the example, the new shape is [size(s), 1, size(k)]. if len(lhs_matmul_output_axes) != 1 or len(lhs_contraction_axes) != 1: new_lhs_tensor, lhs_shape_tensor = permute_and_reshape_tensor( - g, lhs_tensor, True, len(lhs_labels), lhs_perm, lhs_matmul_output_axes, lhs_contraction_axes, - len(batched_axes), lhs_matmul_output_numel_tensor, contraction_numel_tensor, lhs_shape_tensor) + g, + lhs_tensor, + True, + len(lhs_labels), + lhs_perm, + lhs_matmul_output_axes, + lhs_contraction_axes, + len(batched_axes), + lhs_matmul_output_numel_tensor, + contraction_numel_tensor, + lhs_shape_tensor, + ) else: if need_permute(lhs_perm): new_lhs_tensor = g.op("Transpose", lhs_tensor, perm_i=lhs_perm) @@ -463,8 +546,18 @@ def einsum(g, equation, tensor_list): new_rhs_tensor = rhs_tensor if len(rhs_matmul_output_axes) != 1 or len(rhs_contraction_axes) != 1: new_rhs_tensor, rhs_shape_tensor = permute_and_reshape_tensor( - g, rhs_tensor, False, len(rhs_labels), rhs_perm, rhs_matmul_output_axes, rhs_contraction_axes, - len(batched_axes), rhs_matmul_output_numel_tensor, contraction_numel_tensor, rhs_shape_tensor) + g, + rhs_tensor, + False, + len(rhs_labels), + rhs_perm, + rhs_matmul_output_axes, + rhs_contraction_axes, + len(batched_axes), + rhs_matmul_output_numel_tensor, + contraction_numel_tensor, + rhs_shape_tensor, + ) else: if need_permute(rhs_perm): new_rhs_tensor = g.op("Transpose", rhs_tensor, perm_i=rhs_perm) @@ -496,8 +589,11 @@ def einsum(g, equation, tensor_list): # Now output axes is ordered by [batched_axes, lhs_matmul_output_axes, rhs_matmut_output_axes], # if this is not same as output, need one permute. - labels = [result_labels[axis] for axis in batched_axes] + [ - lhs_labels[axis] for axis in lhs_matmul_output_axes] + [rhs_labels[axis] for axis in rhs_matmul_output_axes] + labels = ( + [result_labels[axis] for axis in batched_axes] + + [lhs_labels[axis] for axis in lhs_matmul_output_axes] + + [rhs_labels[axis] for axis in rhs_matmul_output_axes] + ) assert len(labels) == out_size output_perm = [labels.index(label) for label in result_labels] assert all(axis in output_perm for axis in range(out_size)) @@ -505,4 +601,6 @@ def einsum(g, equation, tensor_list): result = g.op("Transpose", result, perm_i=output_perm) return result + + # End of torch.einsum. diff --git a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py index a5eef94e0585c..bce27ea94408f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_execution_agent.py +++ b/orttraining/orttraining/python/training/ortmodule/_execution_agent.py @@ -52,8 +52,9 @@ def __init__(self, path_or_bytes, session_options=None, providers=None, provider self.create_inference_agent(path_or_bytes, session_options, providers, provider_options) def create_inference_agent(self, path_or_bytes, session_options, providers, provider_options): - self._inference_session = onnxruntime.InferenceSession(path_or_bytes, session_options, - providers, provider_options) + self._inference_session = onnxruntime.InferenceSession( + path_or_bytes, session_options, providers, provider_options + ) def io_binding(self): """Return an onnxruntime.IOBinding object`.""" @@ -62,9 +63,9 @@ def io_binding(self): def run_forward(self, iobinding, run_options): """ - Compute the forward graph. - :param iobinding: the iobinding object that has graph inputs/outputs bind. - :param run_options: See :class:`onnxruntime.RunOptions`. + Compute the forward graph. + :param iobinding: the iobinding object that has graph inputs/outputs bind. + :param run_options: See :class:`onnxruntime.RunOptions`. """ self._inference_session.run_with_iobinding(iobinding, run_options) @@ -77,9 +78,17 @@ class TrainingAgent(object): This is the main class used to run an ORTModule model training. """ - def __init__(self, path_or_bytes, fw_feed_names, fw_outputs_device_info, - bw_fetches_names, bw_outputs_device_info, session_options=None, - providers=None, provider_options=None): + def __init__( + self, + path_or_bytes, + fw_feed_names, + fw_outputs_device_info, + bw_fetches_names, + bw_outputs_device_info, + session_options=None, + providers=None, + provider_options=None, + ): """ :param path_or_bytes: filename or serialized ONNX or ORT format model in a byte string :param fw_feed_names: Feed names for foward pass. @@ -110,27 +119,33 @@ def __init__(self, path_or_bytes, fw_feed_names, fw_outputs_device_info, means execute a node using CUDAExecutionProvider if capable, otherwise execute using CPUExecutionProvider. """ - self._inference_session = onnxruntime.InferenceSession(path_or_bytes, session_options, - providers, provider_options) + self._inference_session = onnxruntime.InferenceSession( + path_or_bytes, session_options, providers, provider_options + ) - self._training_agent = C_TrainingAgent(self._inference_session._sess, fw_feed_names, fw_outputs_device_info, - bw_fetches_names, bw_outputs_device_info) + self._training_agent = C_TrainingAgent( + self._inference_session._sess, + fw_feed_names, + fw_outputs_device_info, + bw_fetches_names, + bw_outputs_device_info, + ) def run_forward(self, feeds, fetches, state, cache=None): """ - Compute the forward subgraph for given feeds and fetches. - :param feeds: Inputs to the graph run. - :param fetches: Outputs of the graph run. - :param state: State of the graph that is used for executing partial graph runs. - :param cache: Cache to store stashed OrtValues for intermediate activations. + Compute the forward subgraph for given feeds and fetches. + :param feeds: Inputs to the graph run. + :param fetches: Outputs of the graph run. + :param state: State of the graph that is used for executing partial graph runs. + :param cache: Cache to store stashed OrtValues for intermediate activations. """ self._training_agent.run_forward(feeds, fetches, state, cache) def run_backward(self, feeds, fetches, state): """ - Compute the backward subgraph for given feeds and fetches. - :param feeds: Inputs to the graph run. - :param fetches: Outputs of the graph run. - :param state: State of the graph that is used for executing partial graph runs. + Compute the backward subgraph for given feeds and fetches. + :param feeds: Inputs to the graph run. + :param fetches: Outputs of the graph run. + :param state: State of the graph that is used for executing partial graph runs. """ self._training_agent.run_backward(feeds, fetches, state) diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback.py b/orttraining/orttraining/python/training/ortmodule/_fallback.py index 832c40526d656..7129e522b8c49 100644 --- a/orttraining/orttraining/python/training/ortmodule/_fallback.py +++ b/orttraining/orttraining/python/training/ortmodule/_fallback.py @@ -11,21 +11,23 @@ from enum import IntFlag from typing import Optional -from ._fallback_exceptions import (ORTModuleFallbackException, - ORTModuleInitException, - ORTModuleDeviceException, - ORTModuleIOError, - ORTModuleTorchModelException, - ORTModuleONNXModelException, - wrap_exception) +from ._fallback_exceptions import ( + ORTModuleFallbackException, + ORTModuleInitException, + ORTModuleDeviceException, + ORTModuleIOError, + ORTModuleTorchModelException, + ORTModuleONNXModelException, + wrap_exception, +) from . import _utils class _FallbackPolicy(IntFlag): - '''Policy to trigger fallback from ONNX Runtime engine to PyTorch + """Policy to trigger fallback from ONNX Runtime engine to PyTorch Each policy can be combined with the others (using |) in order to aggregate them - ''' + """ FALLBACK_DISABLE = 1 FALLBACK_FORCE_TORCH_FORWARD = 2 @@ -37,21 +39,21 @@ class _FallbackPolicy(IntFlag): FALLBACK_BAD_INITIALIZATION = 128 def is_set(self, policy): - '''Check whether `policy` is set on the `_FallbackPolicy` instance + """Check whether `policy` is set on the `_FallbackPolicy` instance FALLBACK_DISABLE implies the check will always fail and return False - ''' + """ return not self.is_disabled() and policy in self def is_disabled(self): - '''Check whether `_FallbackPolicy.FALLBACK_DEVICE` is set on the `_FallbackPolicy` instance''' + """Check whether `_FallbackPolicy.FALLBACK_DEVICE` is set on the `_FallbackPolicy` instance""" return _FallbackPolicy.FALLBACK_DISABLE in self class _FallbackManager(object): - '''Manages fallbacks based on incoming exceptions and specified policies + """Manages fallbacks based on incoming exceptions and specified policies The basic algorithm is based on a dictionary whose keys are the supported fallback policies and and values are a set of Exception that must be detected. @@ -64,49 +66,49 @@ class _FallbackManager(object): On the other hand, when the exception doesn't match any enabled policy, the exception will be raised to the user, terminating execution - ''' + """ - def __init__(self, - pytorch_module: torch.nn.Module, - policy: _FallbackPolicy, - retry: bool): + def __init__(self, pytorch_module: torch.nn.Module, policy: _FallbackPolicy, retry: bool): self._original_module = pytorch_module # Read policy from environment variable for testing purposes - policy = os.getenv('ORTMODULE_FALLBACK_POLICY', policy) + policy = os.getenv("ORTMODULE_FALLBACK_POLICY", policy) if isinstance(policy, str): policy = _FallbackPolicy[policy] # Read retry from environment variable for testing purposes - retry = os.getenv('ORTMODULE_FALLBACK_RETRY', str(retry)).lower() in ['true', '1', 'yes'] - - self._policy_exception_map = {_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: {ORTModuleFallbackException, - ORTModuleDeviceException, - ORTModuleIOError, - ORTModuleTorchModelException, - ORTModuleONNXModelException}, - _FallbackPolicy.FALLBACK_FORCE_TORCH_BACKWARD.value: {ORTModuleFallbackException, - ORTModuleDeviceException, - ORTModuleIOError, - ORTModuleTorchModelException, - ORTModuleONNXModelException}, - _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE.value: {ORTModuleDeviceException}, - _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA.value: {ORTModuleIOError}, - _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL.value: {ORTModuleTorchModelException}, - _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL.value: {ORTModuleONNXModelException}, - _FallbackPolicy.FALLBACK_BAD_INITIALIZATION.value: {ORTModuleInitException, - ORTModuleTorchModelException}, - } + retry = os.getenv("ORTMODULE_FALLBACK_RETRY", str(retry)).lower() in ["true", "1", "yes"] + + self._policy_exception_map = { + _FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD.value: { + ORTModuleFallbackException, + ORTModuleDeviceException, + ORTModuleIOError, + ORTModuleTorchModelException, + ORTModuleONNXModelException, + }, + _FallbackPolicy.FALLBACK_FORCE_TORCH_BACKWARD.value: { + ORTModuleFallbackException, + ORTModuleDeviceException, + ORTModuleIOError, + ORTModuleTorchModelException, + ORTModuleONNXModelException, + }, + _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE.value: {ORTModuleDeviceException}, + _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA.value: {ORTModuleIOError}, + _FallbackPolicy.FALLBACK_UNSUPPORTED_TORCH_MODEL.value: {ORTModuleTorchModelException}, + _FallbackPolicy.FALLBACK_UNSUPPORTED_ONNX_MODEL.value: {ORTModuleONNXModelException}, + _FallbackPolicy.FALLBACK_BAD_INITIALIZATION.value: {ORTModuleInitException, ORTModuleTorchModelException}, + } self.policy = policy self.retry = retry self._exception = None - def handle_exception(self, - exception: Exception, - log_level: _logger.LogLevel, - override_policy: Optional[_FallbackPolicy] = None) -> None: - '''Process incoming `exception` based on the selected `policy` + def handle_exception( + self, exception: Exception, log_level: _logger.LogLevel, override_policy: Optional[_FallbackPolicy] = None + ) -> None: + """Process incoming `exception` based on the selected `policy` If the incoming `exception` is handled by the specified policy, `_FallbackManager` saves the exception so that ORTModule can track the pending fallback @@ -120,23 +122,32 @@ def handle_exception(self, Raises: `exception`: Original exception is raised when there is no matching policy for it - ''' + """ + def _set_exception(policy: _FallbackPolicy, exception: Exception, log_level: _logger.LogLevel): - if policy is not _FallbackPolicy.FALLBACK_DISABLE and \ - self.policy.is_set(policy) and \ - (policy.value in self._policy_exception_map and type(exception) in self._policy_exception_map[policy.value]): + if ( + policy is not _FallbackPolicy.FALLBACK_DISABLE + and self.policy.is_set(policy) + and ( + policy.value in self._policy_exception_map + and type(exception) in self._policy_exception_map[policy.value] + ) + ): if log_level <= _logger.LogLevel.INFO: - warnings.warn( - f'Fallback for policy {policy.name} is pending.', UserWarning) + warnings.warn(f"Fallback for policy {policy.name} is pending.", UserWarning) # ORTModuleInitException exceptions do not call `fallback()` through `GraphExecutionManager`, # Instead, it fallbacks to PyTorch implicitly through `ORTModule._torch_module = TorchModulePytorch(module)` if log_level <= _logger.LogLevel.WARNING and policy == _FallbackPolicy.FALLBACK_BAD_INITIALIZATION: warnings.warn( - (f'Fallback to PyTorch due to exception {type(exception)} was triggered. ' - 'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. ' - f'See details below:\n\n{_utils.get_exception_as_string(exception)}'), UserWarning) + ( + f"Fallback to PyTorch due to exception {type(exception)} was triggered. " + "Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. " + f"See details below:\n\n{_utils.get_exception_as_string(exception)}" + ), + UserWarning, + ) self._exception = exception @@ -151,23 +162,27 @@ def _set_exception(policy: _FallbackPolicy, exception: Exception, log_level: _lo raise exception def is_pending(self) -> bool: - '''Returns True when a fallback is pending + """Returns True when a fallback is pending ORTModule must execute fallback to PyTorch engine when a pending fallback is detected - ''' + """ return self._exception is not None def fallback(self, log_level: _logger.LogLevel, *inputs, **kwargs): - '''Executes user PyTorch `model` using the provided inputs and return the result''' + """Executes user PyTorch `model` using the provided inputs and return the result""" - assert self.is_pending(), '`fallback` can only be called when there is a pending fallback' + assert self.is_pending(), "`fallback` can only be called when there is a pending fallback" if log_level <= _logger.LogLevel.WARNING: warnings.warn( - (f'Fallback to PyTorch due to exception {type(self._exception)} was triggered. ' - 'Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. ' - f'See details below:\n\n{_utils.get_exception_as_string(self._exception)}'), UserWarning) + ( + f"Fallback to PyTorch due to exception {type(self._exception)} was triggered. " + "Report this issue with a minimal repro at https://www.github.com/microsoft/onnxruntime. " + f"See details below:\n\n{_utils.get_exception_as_string(self._exception)}" + ), + UserWarning, + ) # Pending fallbacks are resetted to enforce retries if self.retry: diff --git a/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py b/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py index 476317f92891a..3bb88cdebee18 100644 --- a/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py +++ b/orttraining/orttraining/python/training/ortmodule/_fallback_exceptions.py @@ -4,68 +4,70 @@ class ORTModuleFallbackException(Exception): - '''Base exception class for fallback + """Base exception class for fallback Although it must be specialized for specific scenarios, it can also be used for generic exception that require fallback - ''' + """ pass class ORTModuleInitException(ORTModuleFallbackException): - '''Trigger fallback for ORTModule initialization related exceptions + """Trigger fallback for ORTModule initialization related exceptions This exception is triggered when an incompatible or missing requirements for ORTModule are detected, including PyTorch version, missing ORTModule's PyTorch C++ extension binaries, etc. - ''' + """ pass class ORTModuleDeviceException(ORTModuleFallbackException): - '''Trigger fallback for device related exceptions + """Trigger fallback for device related exceptions NOTE: This exception is raised during device validation within ORTModule frontend. Some device related exceptions can only be detected during PyTorch ONNX exporter execution. This exception does not capture these scenarios. - ''' + """ pass class ORTModuleIOError(ORTModuleFallbackException): - '''Trigger fallback for I/O related exceptions + """Trigger fallback for I/O related exceptions NOTE: This exception is raised during I/O validation within ORTModule Frontend. Some I/O related exceptions can only be detected during PyTorch ONNX exporter execution. This exception does not capture these scenarios. - ''' + """ pass class ORTModuleTorchModelException(ORTModuleFallbackException): - '''Trigger fallback for PyTorch modules related exceptions + """Trigger fallback for PyTorch modules related exceptions This exception is raised during model validation within ORTModule frontend and is based on checking type(model) over a hardcoded list of incompatible models. - ''' + """ pass class ORTModuleONNXModelException(ORTModuleFallbackException): - '''Trigger fallback for ONNX model related exceptions + """Trigger fallback for ONNX model related exceptions This exception is raised during model conversion to ONNX and post-processing validation within ORTModule frontend. - ''' + """ pass -def wrap_exception(new_exception: ORTModuleFallbackException, raised_exception: Exception) -> ORTModuleFallbackException: - '''Wraps `raised_exception` exception as cause for the returned `new_exception` exception''' +def wrap_exception( + new_exception: ORTModuleFallbackException, raised_exception: Exception +) -> ORTModuleFallbackException: + """Wraps `raised_exception` exception as cause for the returned `new_exception` exception""" exception = None try: diff --git a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py index adf1e5c9145e5..8d9f1ea2595f2 100644 --- a/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_gradient_accumulation_manager.py @@ -11,6 +11,7 @@ class GradientAccumulationManager(object): This feature must be enabled once before training and cannot be turned off within a training run. """ + # TODO: enable switching the feature on/off in the middle of the training def __init__(self): @@ -35,8 +36,7 @@ def initialize(self, enabled, module, graph_info) -> None: # Since named_parameters() is a generator function, need to avoid overhead and # populate the params in memory to avoid generating the param map every # step. This will not work if the user adds or removes params between steps - self._param_name_value_map = { - name: param for name, param in module.named_parameters()} + self._param_name_value_map = {name: param for name, param in module.named_parameters()} self._param_version_map = dict() self._frontier_node_arg_map = graph_info.frontier_node_arg_map self._cached_node_arg_names = graph_info.cached_node_arg_names @@ -44,8 +44,7 @@ def initialize(self, enabled, module, graph_info) -> None: @property def enabled(self): - """Indicates whether gradient accumulation optimization is enabled. - """ + """Indicates whether gradient accumulation optimization is enabled.""" return self._enabled def extract_outputs_and_maybe_update_cache(self, forward_outputs, device): @@ -58,15 +57,14 @@ def extract_outputs_and_maybe_update_cache(self, forward_outputs, device): return _utils._ortvalues_to_torch_tensor(forward_outputs, device) if self._update_cache: for i in range(self._cache_start, len(forward_outputs)): - self.cache.insert( - self._cached_node_arg_names[i-self._cache_start], forward_outputs[i]) + self.cache.insert(self._cached_node_arg_names[i - self._cache_start], forward_outputs[i]) self._update_cache = False return _utils._ortvalues_to_torch_tensor_list( - [forward_outputs[i] for i in range(self._cache_start)], device, c_class=True) + [forward_outputs[i] for i in range(self._cache_start)], device, c_class=True + ) def maybe_update_cache_before_run(self): - """Update cache when model parameters are modified and optimization is enabled. - """ + """Update cache when model parameters are modified and optimization is enabled.""" # The current implementation relies on param._version, which might not be # updated in all cases(eg. inplace update) # TODO: Make detection of parameter update robust diff --git a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py index 65a99e4c50f62..9c62fbebf1234 100644 --- a/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_graph_execution_manager.py @@ -4,20 +4,18 @@ # -------------------------------------------------------------------------- from .debug_options import DebugOptions, LogLevel -from . import (_utils, - _io, - _logger, - _onnx_models, - _are_deterministic_algorithms_enabled) +from . import _utils, _io, _logger, _onnx_models, _are_deterministic_algorithms_enabled from .torch_cpp_extensions.cpu.aten_op_executor import load_aten_op_executor_cpp_extension from ._custom_autograd_function import custom_autograd_function_enabler from ._custom_autograd_function_exporter import _post_process_after_export from ._graph_execution_interface import GraphExecutionInterface -from ._fallback import (_FallbackManager, - ORTModuleDeviceException, - ORTModuleONNXModelException, - ORTModuleTorchModelException, - wrap_exception) +from ._fallback import ( + _FallbackManager, + ORTModuleDeviceException, + ORTModuleONNXModelException, + ORTModuleTorchModelException, + wrap_exception, +) from ._gradient_accumulation_manager import GradientAccumulationManager from onnxruntime.training import ortmodule @@ -96,18 +94,22 @@ def __init__(self, module, debug_options: DebugOptions, fallback_manager: _Fallb # Update constant ONNX_OPSET_VERSION with env var ORTMODULE_ONNX_OPSET_VERSION # if defined. ortmodule.ONNX_OPSET_VERSION = ortmodule._defined_from_envvar( - 'ORTMODULE_ONNX_OPSET_VERSION', ortmodule.ONNX_OPSET_VERSION, warn=True) + "ORTMODULE_ONNX_OPSET_VERSION", ortmodule.ONNX_OPSET_VERSION, warn=True + ) # TrainingAgent or InferenceAgent self._execution_agent = None # indicators of some logic have been executed previously thus could be skipped for faster training # default is enabled, if not define in os env - self._skip_check = _SkipCheck(_SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT) - if os.getenv('ORTMODULE_SKIPCHECK_POLICY') is not None: - self._skip_check = reduce(lambda x, y: x | y, - [_SkipCheck[name] for name in - _utils.parse_os_env_skip_check_flags('ORTMODULE_SKIPCHECK_POLICY')]) + self._skip_check = _SkipCheck( + _SkipCheck.SKIP_CHECK_DEVICE | _SkipCheck.SKIP_CHECK_BUILD_GRADIENT | _SkipCheck.SKIP_CHECK_EXECUTION_AGENT + ) + if os.getenv("ORTMODULE_SKIPCHECK_POLICY") is not None: + self._skip_check = reduce( + lambda x, y: x | y, + [_SkipCheck[name] for name in _utils.parse_os_env_skip_check_flags("ORTMODULE_SKIPCHECK_POLICY")], + ) self._first_skip_check_warning = True # Graph transformer config @@ -151,18 +153,17 @@ def __init__(self, module, debug_options: DebugOptions, fallback_manager: _Fallb self._module_output_schema = None self._device = _utils.get_device_from_module(module) - self._module_parameters = list(inspect.signature( - self._original_module.forward).parameters.values()) + self._module_parameters = list(inspect.signature(self._original_module.forward).parameters.values()) # TODO: remove after PyTorch ONNX exporter supports VAR_KEYWORD parameters. for input_parameter in self._module_parameters: if input_parameter.kind == inspect.Parameter.VAR_KEYWORD: if self._debug_options.logging.log_level <= LogLevel.WARNING: - warnings.warn("The model's forward method has **kwargs parameter which has EXPERIMENTAL support!", - UserWarning) + warnings.warn( + "The model's forward method has **kwargs parameter which has EXPERIMENTAL support!", UserWarning + ) - self.is_rocm_pytorch = (True if ( - (torch.version.hip is not None) and (ROCM_HOME is not None)) else False) + self.is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False self._use_external_gpu_allocator = True # assign self._torch_alloc and self._torch_free if self._use_external_gpu_allocator is True @@ -186,6 +187,7 @@ def _get_torch_gpu_allocator_function_addresses(self): if self._use_external_gpu_allocator and torch.cuda.is_available(): # CPP extension to get torch GPU allocator's alloc and free function addresses from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_gpu_allocator + self._torch_alloc = torch_gpu_allocator.gpu_caching_allocator_raw_alloc_address() self._torch_free = torch_gpu_allocator.gpu_caching_allocator_raw_delete_address() self._torch_empty_cache = torch_gpu_allocator.gpu_caching_allocator_empty_cache_address() @@ -194,14 +196,20 @@ def _validate_module_type(self, module): """Raises ORTModuleTorchModelException if the module is not a torch.nn.Module""" if not isinstance(module, torch.nn.Module): - raise wrap_exception(ORTModuleTorchModelException, - TypeError(f"ORTModule only supports torch.nn.Module as input. {type(module)} is not supported.")) + raise wrap_exception( + ORTModuleTorchModelException, + TypeError(f"ORTModule only supports torch.nn.Module as input. {type(module)} is not supported."), + ) # Hard-coded list of unsupported torch.nn.Module goes here for fallback if isinstance(module, torch.nn.DataParallel): - raise wrap_exception(ORTModuleTorchModelException, - TypeError("ORTModule is not compatible with torch.nn.DataParallel. " - "Please use torch.nn.parallel.DistributedDataParallel instead.")) + raise wrap_exception( + ORTModuleTorchModelException, + TypeError( + "ORTModule is not compatible with torch.nn.DataParallel. " + "Please use torch.nn.parallel.DistributedDataParallel instead." + ), + ) @staticmethod def execution_session_run_forward(execution_session, onnx_model, device, *inputs): @@ -238,11 +246,11 @@ def _build_graph(self): else: self._graph_builder.build() - self._onnx_models.optimized_model = onnx.load_model_from_string( - self._graph_builder.get_model()) + self._onnx_models.optimized_model = onnx.load_model_from_string(self._graph_builder.get_model()) self._onnx_models.optimized_pre_grad_model = onnx.load_model_from_string( - self._graph_builder.get_inference_optimized_model()) + self._graph_builder.get_inference_optimized_model() + ) self._graph_info = self._graph_builder.get_graph_info() @@ -271,15 +279,15 @@ def _get_session_config(self): if _are_deterministic_algorithms_enabled(): if self._debug_options.logging.log_level <= _logger.LogLevel.INFO: - warnings.warn("ORTModule's determinism will be enabled because PyTorch's determinism is enabled.", - UserWarning) + warnings.warn( + "ORTModule's determinism will be enabled because PyTorch's determinism is enabled.", UserWarning + ) providers = None provider_options = None - if self._device.type == 'cuda': + if self._device.type == "cuda": # Configure the InferenceSessions to use the specific GPU on which the model is placed. - providers = (["ROCMExecutionProvider"] if self.is_rocm_pytorch else [ - "CUDAExecutionProvider"]) + providers = ["ROCMExecutionProvider"] if self.is_rocm_pytorch else ["CUDAExecutionProvider"] providers.append("CPUExecutionProvider") provider_option_map = {"device_id": str(self._device.index)} if not self.is_rocm_pytorch: @@ -292,10 +300,10 @@ def _get_session_config(self): provider_option_map["gpu_external_free"] = str(self._torch_free) provider_option_map["gpu_external_empty_cache"] = str(self._torch_empty_cache) provider_options = [provider_option_map, {}] - elif self._device.type == 'cpu': + elif self._device.type == "cpu": providers = ["CPUExecutionProvider"] provider_options = [{}] - elif self._device.type == 'ort': + elif self._device.type == "ort": provider_info = C.get_ort_device_provider_info(self._device.index) assert len(provider_info.keys()) == 1 providers = list(provider_info.keys()) @@ -308,15 +316,15 @@ def _get_session_config(self): # default to PRIORITY_BASED execution order session_options.execution_order = onnxruntime.ExecutionOrder.PRIORITY_BASED # 0:Verbose, 1:Info, 2:Warning. 3:Error, 4:Fatal. Default is 2. - session_options.log_severity_level = int( - self._debug_options.logging.log_level) + session_options.log_severity_level = int(self._debug_options.logging.log_level) if self._debug_options.save_onnx_models.save: - session_options.optimized_model_filepath = \ - os.path.join(self._debug_options.save_onnx_models.path, - _onnx_models._get_onnx_file_name( - self._debug_options.save_onnx_models.name_prefix, - 'execution_model', self._export_mode)) + session_options.optimized_model_filepath = os.path.join( + self._debug_options.save_onnx_models.path, + _onnx_models._get_onnx_file_name( + self._debug_options.save_onnx_models.name_prefix, "execution_model", self._export_mode + ), + ) return session_options, providers, provider_options @@ -336,23 +344,28 @@ def _export_model(self, *inputs, **kwargs): # e.g., some sympy functions in symbolic_shape_infer will change Python's random state. random_states = _utils.get_random_states() - schema = _io._extract_schema( - {'args': copy.copy(inputs), 'kwargs': copy.copy(kwargs)}) - if self._onnx_models.exported_model and schema == self._input_info.schema and not self._original_model_has_changed: + schema = _io._extract_schema({"args": copy.copy(inputs), "kwargs": copy.copy(kwargs)}) + if ( + self._onnx_models.exported_model + and schema == self._input_info.schema + and not self._original_model_has_changed + ): # All required models have already been exported previously return False self._set_device_from_module(inputs, kwargs) - self._onnx_models.exported_model = self._get_exported_model( - schema, *inputs, **kwargs) + self._onnx_models.exported_model = self._get_exported_model(schema, *inputs, **kwargs) if self._debug_options.save_onnx_models.save: - self._onnx_models.save_exported_model(self._debug_options.save_onnx_models.path, - self._debug_options.save_onnx_models.name_prefix, - self._export_mode) + self._onnx_models.save_exported_model( + self._debug_options.save_onnx_models.path, + self._debug_options.save_onnx_models.name_prefix, + self._export_mode, + ) if self._run_symbolic_shape_infer: - self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes(self._onnx_models.exported_model, - auto_merge=True, guess_output_rank=True) + self._onnx_models.exported_model = SymbolicShapeInference.infer_shapes( + self._onnx_models.exported_model, auto_merge=True, guess_output_rank=True + ) # Restore the recorded random states _utils.set_random_states(random_states) @@ -360,20 +373,18 @@ def _export_model(self, *inputs, **kwargs): return True def _get_exported_model(self, input_schema, *inputs, **kwargs): - '''Exports PyTorch `self._flattened_module` to ONNX for inferencing or training, using `*inputs` and `**kwargs` as input + """Exports PyTorch `self._flattened_module` to ONNX for inferencing or training, using `*inputs` and `**kwargs` as input TODO: How to support dynamic axes? Dimensions are determined by samples - ''' + """ # Setup dynamic axes for onnx model - self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, - None, - input_schema, - inputs, - kwargs) - output_names, output_dynamic_axes, self._module_output_schema = \ - _io.parse_outputs_for_onnx_export_and_extract_schema( - self._original_module, inputs, kwargs) + self._input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, None, input_schema, inputs, kwargs) + ( + output_names, + output_dynamic_axes, + self._module_output_schema, + ) = _io.parse_outputs_for_onnx_export_and_extract_schema(self._original_module, inputs, kwargs) self._input_info.dynamic_axes.update(output_dynamic_axes) # FlattenedModule needs _InputInfo to expand user input from *args to *args + **kwargs @@ -385,58 +396,66 @@ def _get_exported_model(self, input_schema, *inputs, **kwargs): # Deepcopy inputs, since input values may change after model run. # NOTE: Inputs may contain tensors that have attributes preventing their deepcopy (example grad_fn). # Therefore, deepcopy only the data component of the input tensors for export. - sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input( - *inputs, **kwargs) + sample_inputs_copy, sample_kwargs_copy = _io.deepcopy_model_input(*inputs, **kwargs) # NOTE: Flattening the input will change the 'input schema', resulting in a re-export - sample_inputs_as_tuple = tuple(self._input_info.flatten( - sample_inputs_copy, sample_kwargs_copy, self._device)) + sample_inputs_as_tuple = tuple(self._input_info.flatten(sample_inputs_copy, sample_kwargs_copy, self._device)) # Ops behaving differently under train/eval mode need to exported with the # correct training flag to reflect the expected behavior. # For example, the Dropout node in a model is dropped under eval mode. assert self._export_mode is not None, "Please use a concrete instance of ExecutionManager" try: - with torch.set_grad_enabled(self._enable_custom_autograd_function), \ - _logger.suppress_os_stream_output(log_level=self._debug_options.logging.log_level): - required_export_kwargs = {'input_names': self._input_info.names, - 'output_names': output_names, - 'opset_version': ortmodule.ONNX_OPSET_VERSION, - 'do_constant_folding': False, - 'training': self._export_mode, - 'dynamic_axes': self._input_info.dynamic_axes, - 'verbose': self._debug_options.logging.log_level < LogLevel.WARNING, - 'export_params': False, - 'keep_initializers_as_inputs': True} + with torch.set_grad_enabled(self._enable_custom_autograd_function), _logger.suppress_os_stream_output( + log_level=self._debug_options.logging.log_level + ): + required_export_kwargs = { + "input_names": self._input_info.names, + "output_names": output_names, + "opset_version": ortmodule.ONNX_OPSET_VERSION, + "do_constant_folding": False, + "training": self._export_mode, + "dynamic_axes": self._input_info.dynamic_axes, + "verbose": self._debug_options.logging.log_level < LogLevel.WARNING, + "export_params": False, + "keep_initializers_as_inputs": True, + } invalid_args = self._export_extra_kwargs.keys() & required_export_kwargs.keys() - assert len(invalid_args) == 0,\ - f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." - torch.onnx.export(self._flattened_module, - sample_inputs_as_tuple, - f, - **required_export_kwargs, - **self._export_extra_kwargs) + assert ( + len(invalid_args) == 0 + ), f"The following PyTorch exporter arguments cannot be specified: '{invalid_args}'." + torch.onnx.export( + self._flattened_module, + sample_inputs_as_tuple, + f, + **required_export_kwargs, + **self._export_extra_kwargs, + ) except Exception as e: - raise wrap_exception(ORTModuleONNXModelException, - RuntimeError(f'There was an error while exporting the PyTorch model to ONNX: ' - f'\n\n{_utils.get_exception_as_string(e)}')) + raise wrap_exception( + ORTModuleONNXModelException, + RuntimeError( + f"There was an error while exporting the PyTorch model to ONNX: " + f"\n\n{_utils.get_exception_as_string(e)}" + ), + ) exported_model = onnx.load_model_from_string(f.getvalue()) - exported_model = _post_process_after_export(exported_model, - self._enable_custom_autograd_function, - self._debug_options.logging.log_level) + exported_model = _post_process_after_export( + exported_model, self._enable_custom_autograd_function, self._debug_options.logging.log_level + ) return exported_model def _set_device_from_module(self, inputs, kwargs): """Get the device from the module and save it to self._device""" - device = _utils.get_device_from_module(self._original_module) or \ - _utils.get_device_from_inputs(inputs, kwargs) + device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs(inputs, kwargs) if not self._device or self._device != device: self._device = device if not self._device: - raise wrap_exception(ORTModuleDeviceException, - RuntimeError('A device must be specified in the model or inputs!')) + raise wrap_exception( + ORTModuleDeviceException, RuntimeError("A device must be specified in the model or inputs!") + ) def _get_graph_transformer_config(self): graph_transformer_config = C.TrainingGraphTransformerConfiguration() @@ -452,14 +471,17 @@ def _initialize_graph_builder(self, training): # All initializer names along with user inputs are a part of the onnx graph inputs # since the onnx model was exported with the flag keep_initializers_as_inputs=True - onnx_initializer_names = { - p.name for p in self._onnx_models.exported_model.graph.input} + onnx_initializer_names = {p.name for p in self._onnx_models.exported_model.graph.input} # TODO: PyTorch exporter bug: changes the initializer order in ONNX model - initializer_names = [name for name, _ in self._flattened_module.named_parameters() - if name in onnx_initializer_names] - initializer_names_to_train = [name for name, param in self._flattened_module.named_parameters() - if param.requires_grad and name in onnx_initializer_names] + initializer_names = [ + name for name, _ in self._flattened_module.named_parameters() if name in onnx_initializer_names + ] + initializer_names_to_train = [ + name + for name, param in self._flattened_module.named_parameters() + if param.requires_grad and name in onnx_initializer_names + ] # Build and optimize the full graph grad_builder_config = C.OrtModuleGraphBuilderConfiguration() @@ -469,25 +491,26 @@ def _initialize_graph_builder(self, training): grad_builder_config.build_gradient_graph = training grad_builder_config.graph_transformer_config = self._get_graph_transformer_config() grad_builder_config.enable_caching = self._enable_grad_acc_optimization - grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel(self._debug_options.logging.log_level) + grad_builder_config.loglevel = _logger.ortmodule_loglevel_to_onnxruntime_c_loglevel( + self._debug_options.logging.log_level + ) grad_builder_config.use_memory_efficient_gradient = self._use_memory_efficient_gradient self._graph_builder = C.OrtModuleGraphBuilder() # It is assumed here that the order and names of the inputs and outputs are not modified by the backend in any way # and are kept as they appear in the exported onnx model. - self._graph_builder.initialize( - self._onnx_models.exported_model.SerializeToString(), grad_builder_config) + self._graph_builder.initialize(self._onnx_models.exported_model.SerializeToString(), grad_builder_config) # TODO: Explore ways to make self._graph_info.initializer_names and self._graph_info.initializer_names_to_train # a set (unordered_set in the backend) that does not require a copy on each reference. self._graph_initializer_names = set(initializer_names) - self._graph_initializer_names_to_train = set( - initializer_names_to_train) + self._graph_initializer_names_to_train = set(initializer_names_to_train) # Initializers can be cached and used since they are expected not to be re-instantiated # between forward calls. - self._graph_initializers = [param for name, param in self._flattened_module.named_parameters() - if name in self._graph_initializer_names] + self._graph_initializers = [ + param for name, param in self._flattened_module.named_parameters() if name in self._graph_initializer_names + ] def signal_model_changed(self): """Signals the execution manager to re-export the model on the next forward call""" @@ -503,7 +526,7 @@ def __getstate__(self): "_execution_agent", "_torch_alloc", "_torch_free", - "_torch_empty_cache" + "_torch_empty_cache", ] for attribute_name in serialization_deny_list: del state[attribute_name] diff --git a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py index e95e27e0097de..d8614f5bfe966 100644 --- a/orttraining/orttraining/python/training/ortmodule/_inference_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_inference_manager.py @@ -3,14 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from . import (_utils, - _io, - _logger, - _are_deterministic_algorithms_enabled, - _use_deterministic_algorithms) -from ._graph_execution_manager import (GraphExecutionManager, - _RunStateInfo, - _SkipCheck) +from . import _utils, _io, _logger, _are_deterministic_algorithms_enabled, _use_deterministic_algorithms +from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck from ._execution_agent import InferenceAgent from .debug_options import DebugOptions from ._fallback import ORTModuleFallbackException, _FallbackPolicy, _FallbackManager @@ -54,19 +48,18 @@ def execution_session_run_forward(execution_session, onnx_model, device, *inputs user_outputs = _utils._ortvalues_to_torch_tensor_list(forward_outputs, device) state = None - output_info = [(output.shape, output.device, output.dtype) - for output in user_outputs] + output_info = [(output.shape, output.device, output.dtype) for output in user_outputs] run_info = _RunStateInfo(state, output_info) # Return user outputs and forward run information return user_outputs, run_info def forward(self, *inputs, **kwargs): - '''Forward pass of the inference model + """Forward pass of the inference model ONNX model is exported the first time this method is executed. Next, we build an optimized inference graph with module_graph_builder. Finally, we instantiate the ONNX Runtime InferenceSession through the InferenceAgent. - ''' + """ # Fallback to PyTorch due to failures *external* to forward(), # typically from initialization @@ -75,19 +68,27 @@ def forward(self, *inputs, **kwargs): try: # Issue at most one warning message about fast path - if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False \ - and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING: + if ( + self._first_skip_check_warning is True + and self._skip_check.is_disabled() is False + and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING + ): self._first_skip_check_warning = False - warnings.warn(f"Fast path enabled - skipping checks." - f"rebuild gradient graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)}," - f"execution agent recreation: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)}," - f"device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}", UserWarning) + warnings.warn( + f"Fast path enabled - skipping checks." + f"rebuild gradient graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)}," + f"execution agent recreation: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)}," + f"device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}", + UserWarning, + ) # If exporting module to ONNX for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. build_graph = False - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or \ - not self._onnx_models.exported_model: + if ( + self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False + or not self._onnx_models.exported_model + ): # Exporting module to ONNX for the first time build_graph = self._export_model(*inputs, **kwargs) if build_graph: @@ -101,14 +102,14 @@ def forward(self, *inputs, **kwargs): # If creating the execution agent for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. create_execution_session = False - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or \ - not self._execution_agent: - module_device = _utils.get_device_from_module( - self._original_module) - - create_execution_session = (build_graph or self._device != module_device or - torch.are_deterministic_algorithms_enabled() is not - _are_deterministic_algorithms_enabled()) + if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent: + module_device = _utils.get_device_from_module(self._original_module) + + create_execution_session = ( + build_graph + or self._device != module_device + or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled() + ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != module_device: @@ -122,29 +123,32 @@ def forward(self, *inputs, **kwargs): # Assert that the input and model device match _utils._check_same_device(self._device, "Input argument to forward", *inputs) - user_outputs, _ = InferenceManager.execution_session_run_forward(self._execution_agent, - self._onnx_models.optimized_model, - self._device, - *_io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device)) - - return _io.unflatten_user_output(self._module_output_schema, - user_outputs) + user_outputs, _ = InferenceManager.execution_session_run_forward( + self._execution_agent, + self._onnx_models.optimized_model, + self._device, + *_io._combine_input_buffers_initializers( + self._graph_initializers, + self._graph_info.user_input_names, + self._input_info, + self._flattened_module.named_buffers(), + inputs, + kwargs, + self._device, + ), + ) + + return _io.unflatten_user_output(self._module_output_schema, user_outputs) except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here - self._fallback_manager.handle_exception(exception=e, - log_level=self._debug_options.logging.log_level) + self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level) except Exception as e: # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here - self._fallback_manager.handle_exception(exception=e, - log_level=self._debug_options.logging.log_level, - override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD) + self._fallback_manager.handle_exception( + exception=e, + log_level=self._debug_options.logging.log_level, + override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD, + ) # Fallback to PyTorch due to failures *during* forward(), # (e.g. export, model/input post-processing, forward, output processing, etc) if self._fallback_manager.is_pending(): @@ -155,13 +159,16 @@ def _build_graph(self): super()._build_graph() if self._debug_options.save_onnx_models.save: - self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path, - self._debug_options.save_onnx_models.name_prefix, - self._export_mode) + self._onnx_models.save_optimized_model( + self._debug_options.save_onnx_models.path, + self._debug_options.save_onnx_models.name_prefix, + self._export_mode, + ) def _create_execution_agent(self): """Creates an InferenceAgent that can run forward graph on an inference model""" session_options, providers, provider_options = self._get_session_config() - self._execution_agent = InferenceAgent(self._onnx_models.optimized_model.SerializeToString(), - session_options, providers, provider_options) + self._execution_agent = InferenceAgent( + self._onnx_models.optimized_model.SerializeToString(), session_options, providers, provider_options + ) diff --git a/orttraining/orttraining/python/training/ortmodule/_io.py b/orttraining/orttraining/python/training/ortmodule/_io.py index 1a62d5ab1495a..d99fd16b33e6b 100644 --- a/orttraining/orttraining/python/training/ortmodule/_io.py +++ b/orttraining/orttraining/python/training/ortmodule/_io.py @@ -13,8 +13,9 @@ from ._fallback import _FallbackManager, ORTModuleIOError, ORTModuleONNXModelException, wrap_exception from ._utils import warn_of_constant_inputs + class _OutputIdentityOp(torch.autograd.Function): - '''Internal class used to prepend Identity ops in model's outputs + """Internal class used to prepend Identity ops in model's outputs This class is required to support ONNX models which passthrough [some of] the models's inputs directly to the graph output. This is an issue because ONNX Runtime cannot build proper @@ -49,19 +50,24 @@ def forward(self, input1, passthrough_input): onnx subgraph for this example would be `passthrough_input -> Identity -> output2`. TODO: Remove once PyTorch 1.8.2 or newer is released - ''' + """ + @staticmethod def forward(ctx, input): return torch.nn.Identity()(input) + @staticmethod def backward(ctx, grad_output): return grad_output + @staticmethod def symbolic(g, self): return g.op("Identity", self) + class _PrimitiveType(object): _primitive_types = {int, bool, float} + @staticmethod def is_primitive_type(value): return type(value) in _PrimitiveType._primitive_types @@ -77,16 +83,19 @@ def get_primitive_dtype(value): # and the model will be re-exported. return f"{str(type(value))}_{value}" if isinstance(value, bool) else str(type(value)) + class _InputInfo(object): - def __init__(self, - names, - shape, - require_grad_names=None, - dynamic_axes=None, - schema=None, - num_positionals=0, - num_expanded_positionals_non_none=0, - keyword_names=None): + def __init__( + self, + names, + shape, + require_grad_names=None, + dynamic_axes=None, + schema=None, + num_positionals=0, + num_expanded_positionals_non_none=0, + keyword_names=None, + ): self.names = names self.shape = shape self.require_grad_names = require_grad_names if require_grad_names else [] @@ -97,7 +106,7 @@ def __init__(self, self.keyword_names = keyword_names def __repr__(self) -> str: - return f'''_InputInfo class: + return f"""_InputInfo class: \tNames: {self.names} \tShape: {self.shape} \tRequire gradient: {self.require_grad_names} @@ -105,14 +114,19 @@ def __repr__(self) -> str: \tSchema: {self.schema} \t#Positionals (total): {self.num_positionals} \t#Expanded Positionals (non-None): {self.num_expanded_positionals_non_none} - \tKeyword names: {self.keyword_names}''' + \tKeyword names: {self.keyword_names}""" def flatten(self, args, kwargs, device): - '''Flatten args and kwargs in a single tuple of tensors with strict ordering''' + """Flatten args and kwargs in a single tuple of tensors with strict ordering""" ret = [_PrimitiveType.get_tensor(arg, device) if _PrimitiveType.is_primitive_type(arg) else arg for arg in args] - ret += [_PrimitiveType.get_tensor(kwargs[name], device) if _PrimitiveType.is_primitive_type(kwargs[name]) - else kwargs[name] for name in self.names if name in kwargs] + ret += [ + _PrimitiveType.get_tensor(kwargs[name], device) + if _PrimitiveType.is_primitive_type(kwargs[name]) + else kwargs[name] + for name in self.names + if name in kwargs + ] # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise. @@ -122,20 +136,26 @@ def flatten(self, args, kwargs, device): return ret def unflatten(self, flat_args): - '''Unflatten tuple of tensors into args and kwargs''' - - args = tuple(flat_args[:self.num_positionals]) - kwargs = {name: arg for name, arg in zip(self.names[self.num_expanded_positionals_non_none:], flat_args[self.num_positionals:]) \ - if name in self.keyword_names} + """Unflatten tuple of tensors into args and kwargs""" + + args = tuple(flat_args[: self.num_positionals]) + kwargs = { + name: arg + for name, arg in zip( + self.names[self.num_expanded_positionals_non_none :], flat_args[self.num_positionals :] + ) + if name in self.keyword_names + } return args, kwargs + def _combine_input_buffers_initializers(params, onnx_input_names, input_info, buffer_names, inputs, kwargs, device): - '''Creates forward `*inputs` list from user input and PyTorch initializers + """Creates forward `*inputs` list from user input and PyTorch initializers ONNX Runtime forward requires an ordered list of: * User input: computed from forward InferenceSession * Initializers: computed from original PyTorch model parameters - ''' + """ def _expand_inputs(current_input, non_none_inputs): # The exporter handles input lists by expanding them so that each @@ -158,7 +178,6 @@ def _expand_inputs(current_input, non_none_inputs): # else just collect all the non none inputs within non_none_inputs non_none_inputs.append(current_input) - # User inputs non_none_inputs = [] _expand_inputs(inputs, non_none_inputs) @@ -197,8 +216,9 @@ def _expand_inputs(current_input, non_none_inputs): inp = _PrimitiveType.get_tensor(inp, device) result.append(inp) else: - raise wrap_exception(ORTModuleONNXModelException, - RuntimeError(f'Input is present in ONNX graph but not provided: {name}.')) + raise wrap_exception( + ORTModuleONNXModelException, RuntimeError(f"Input is present in ONNX graph but not provided: {name}.") + ) # params is a list of all initializers known to the onnx graph result.extend(params) @@ -215,6 +235,7 @@ def extract_tensor(value): return value.data else: return value + sample_inputs_copy = [extract_tensor(value) for value in inputs] sample_inputs_copy = copy.deepcopy(tuple(sample_inputs_copy)) @@ -227,9 +248,9 @@ def extract_tensor(value): class _TensorStub(object): - '''Tensor stub class used to represent model's input or output''' + """Tensor stub class used to represent model's input or output""" - __slots__ = ['name', 'dtype', 'shape', 'shape_dims'] + __slots__ = ["name", "dtype", "shape", "shape_dims"] def __init__(self, name=None, dtype=None, shape=None, shape_dims=None): self.name = name @@ -238,29 +259,29 @@ def __init__(self, name=None, dtype=None, shape=None, shape_dims=None): self.shape_dims = shape_dims def __repr__(self) -> str: - result = '_TensorStub(' + result = "_TensorStub(" if self.name is not None: - result += f'name={self.name}' + result += f"name={self.name}" if self.dtype is not None: - if result[-1] != '(': - result += ', ' - result += f'dtype={self.dtype}' + if result[-1] != "(": + result += ", " + result += f"dtype={self.dtype}" if self.shape is not None: - if result[-1] != '(': - result += ', ' - result += f'shape={self.shape}' + if result[-1] != "(": + result += ", " + result += f"shape={self.shape}" if self.shape_dims is not None: - if result[-1] != '(': - result += ', ' - result += f'shape_dims={self.shape_dims}' - result += ')' + if result[-1] != "(": + result += ", " + result += f"shape_dims={self.shape_dims}" + result += ")" return result def __eq__(self, other): if not other: return False elif not isinstance(other, _TensorStub): - raise NotImplemented('_TensorStub must only be compared to another _TensorStub instance!') + raise NotImplemented("_TensorStub must only be compared to another _TensorStub instance!") elif self.name != other.name: return False elif self.dtype != other.dtype: @@ -288,23 +309,25 @@ def _replace_stub_with_tensor_value(user_output, outputs, output_idx): if isinstance(user_output, abc.Sequence): sequence_type = type(user_output) - if hasattr(sequence_type, '_make'): # namedtuple + if hasattr(sequence_type, "_make"): # namedtuple sequence_type = type(user_output) user_output = sequence_type._make( - _replace_stub_with_tensor_value(uo, outputs, output_idx) - for uo in user_output) + _replace_stub_with_tensor_value(uo, outputs, output_idx) for uo in user_output + ) else: user_output = sequence_type( - _replace_stub_with_tensor_value(uo, outputs, output_idx) - for uo in user_output) + _replace_stub_with_tensor_value(uo, outputs, output_idx) for uo in user_output + ) elif isinstance(user_output, abc.Mapping): new_user_output = copy.copy(user_output) for key in sorted(user_output): new_user_output[key] = _replace_stub_with_tensor_value(new_user_output[key], outputs, output_idx) user_output = new_user_output else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following model output type {type(user_output)}.')) + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(user_output)}."), + ) return user_output @@ -348,8 +371,9 @@ def _extract_schema(data): stubbed_schema = {key: _extract_schema(data[key]) for key in data} stubbed_schema = dict_type(**stubbed_schema) else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following model data type {type(data)}')) + raise wrap_exception( + ORTModuleIOError, TypeError(f"ORTModule does not support the following model data type {type(data)}") + ) return stubbed_schema @@ -364,12 +388,12 @@ def _populate_output_names_and_dynamic_axes(output, output_names, output_dynamic elif isinstance(output, torch.Tensor): # Naming the outputs with a hyphen ensures that there can be no input with the same # name, preventing collisions with other NodeArgs (for example an input to forward called output0) - output_name = f'output-{output_idx[0]}' + output_name = f"output-{output_idx[0]}" output_idx[0] += 1 output_names.append(output_name) output_dynamic_axes[output_name] = {} for dim_idx in range(len(output.shape)): - output_dynamic_axes[output_name].update({dim_idx: f'{output_name}_dim{dim_idx}'}) + output_dynamic_axes[output_name].update({dim_idx: f"{output_name}_dim{dim_idx}"}) return if isinstance(output, abc.Sequence): @@ -379,8 +403,10 @@ def _populate_output_names_and_dynamic_axes(output, output_names, output_dynamic for _, value in sorted(output.items()): _populate_output_names_and_dynamic_axes(value, output_names, output_dynamic_axes, output_idx) else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following model output type {type(output)}')) + raise wrap_exception( + ORTModuleIOError, + TypeError(f"ORTModule does not support the following model output type {type(output)}"), + ) output_names = [] output_dynamic_axes = {} @@ -408,8 +434,9 @@ def _flatten_data(data, flat_data): for _, value in sorted(data.items()): _flatten_data(value, flat_data) else: - raise wrap_exception(ORTModuleIOError, - TypeError(f'ORTModule does not support the following data type {type(data)}.')) + raise wrap_exception( + ORTModuleIOError, TypeError(f"ORTModule does not support the following data type {type(data)}.") + ) flat_data = [] _flatten_data(data, flat_data) @@ -431,11 +458,10 @@ def forward(self, *args): def parse_inputs_for_onnx_export(all_input_parameters, onnx_graph, schema, inputs, kwargs): - def _add_dynamic_shape(name, input): dynamic_axes[name] = {} for dim_idx in range(len(input.shape)): - dynamic_axes[name].update({dim_idx: f'{name}_dim{dim_idx}'}) + dynamic_axes[name].update({dim_idx: f"{name}_dim{dim_idx}"}) return dynamic_axes def _add_input(name, input, onnx_graph, onnx_graph_input_names): @@ -452,8 +478,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names): for i, val in enumerate(input): # Name each input with the index appended to the original name of the # argument. - num_expanded_non_none_inputs += \ - _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names) + num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val, onnx_graph, onnx_graph_input_names) # Return here since the list by itself is not a valid input. # All the elements of the list have already been added as inputs individually. @@ -462,8 +487,7 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names): # If the input is a mapping (like a dict), expand the dict so that # each element of the dict is an input by itself. for key, val in input.items(): - num_expanded_non_none_inputs += \ - _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names) + num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val, onnx_graph, onnx_graph_input_names) # Return here since the dict by itself is not a valid input. # All the elements of the dict have already been added as inputs individually. @@ -500,14 +524,15 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names): # VAR_POSITIONAL parameter carries all *args parameters from original forward method for args_i in range(input_idx, len(inputs)): - name = f'{input_parameter.name}_{var_positional_idx}' + name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 inp = inputs[args_i] - num_expanded_non_none_positional_inputs += \ - _add_input(name, inp, onnx_graph, onnx_graph_input_names) - elif input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or\ - input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD or\ - input_parameter.kind == inspect.Parameter.KEYWORD_ONLY: + num_expanded_non_none_positional_inputs += _add_input(name, inp, onnx_graph, onnx_graph_input_names) + elif ( + input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY + or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY + ): # All positional non-*args and non-**kwargs are processed here name = input_parameter.name inp = None @@ -518,27 +543,27 @@ def _add_input(name, input, onnx_graph, onnx_graph_input_names): elif name in kwargs and kwargs[name] is not None: inp = kwargs[name] is_positional = False - num_expanded_non_none_inputs_local = \ - _add_input(name, inp, onnx_graph, onnx_graph_input_names) + num_expanded_non_none_inputs_local = _add_input(name, inp, onnx_graph, onnx_graph_input_names) if is_positional: num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD: # **kwargs is always the last argument of forward() - for name,inp in kwargs.items(): + for name, inp in kwargs.items(): if name not in input_names: _add_input(name, inp, onnx_graph, onnx_graph_input_names) - # input_names have been expanded so to get the correct number of non none # positional names, we need to collect the num_expanded_non_none_positional_inputs. - return _InputInfo(names=input_names, - shape=input_shape, - require_grad_names=input_names_require_grad, - dynamic_axes=dynamic_axes, - schema=schema, - num_positionals=len(inputs), - num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs, - keyword_names=list(kwargs.keys())) + return _InputInfo( + names=input_names, + shape=input_shape, + require_grad_names=input_names_require_grad, + dynamic_axes=dynamic_axes, + schema=schema, + num_positionals=len(inputs), + num_expanded_positionals_non_none=num_expanded_non_none_positional_inputs, + keyword_names=list(kwargs.keys()), + ) def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs): @@ -555,9 +580,11 @@ def parse_outputs_for_onnx_export_and_extract_schema(module, inputs, kwargs): is_deepcopy = True except Exception: model_copy = module - warnings.warn("This model cannot be deep copied (or pickled), " - "which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!") + warnings.warn( + "This model cannot be deep copied (or pickled), " + "which is a required step for stateful models to be properly exported to ONNX." + " Compute will continue, but unexpected results may occur!" + ) sample_outputs = model_copy(*sample_inputs_copy, **sample_kwargs_copy) diff --git a/orttraining/orttraining/python/training/ortmodule/_logger.py b/orttraining/orttraining/python/training/ortmodule/_logger.py index c503df765f7d9..66e1cb556538f 100644 --- a/orttraining/orttraining/python/training/ortmodule/_logger.py +++ b/orttraining/orttraining/python/training/ortmodule/_logger.py @@ -49,12 +49,18 @@ def suppress_os_stream_output(suppress_stdout=True, suppress_stderr=True, log_le if fo.tell() > 0 and suppress_logs: # If anything was captured in fo, raise a single user warning letting users know that there was # some warning or error that was raised - warnings.warn("There were one or more warnings or errors raised while exporting the PyTorch " - "model. Please enable INFO level logging to view all warnings and errors.", UserWarning) + warnings.warn( + "There were one or more warnings or errors raised while exporting the PyTorch " + "model. Please enable INFO level logging to view all warnings and errors.", + UserWarning, + ) + def ortmodule_loglevel_to_onnxruntime_c_loglevel(loglevel): - return {LogLevel.VERBOSE: Severity.VERBOSE, - LogLevel.INFO: Severity.INFO, - LogLevel.WARNING: Severity.WARNING, - LogLevel.ERROR: Severity.ERROR, - LogLevel.FATAL: Severity.FATAL}.get(loglevel, Severity.WARNING) + return { + LogLevel.VERBOSE: Severity.VERBOSE, + LogLevel.INFO: Severity.INFO, + LogLevel.WARNING: Severity.WARNING, + LogLevel.ERROR: Severity.ERROR, + LogLevel.FATAL: Severity.FATAL, + }.get(loglevel, Severity.WARNING) diff --git a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py index f583f7efa361a..3f2dc04e26c14 100644 --- a/orttraining/orttraining/python/training/ortmodule/_onnx_models.py +++ b/orttraining/orttraining/python/training/ortmodule/_onnx_models.py @@ -7,13 +7,16 @@ import os import torch + def _get_onnx_file_name(name_prefix, name, export_mode): - suffix = 'training' if export_mode == torch.onnx.TrainingMode.TRAINING else 'inference' + suffix = "training" if export_mode == torch.onnx.TrainingMode.TRAINING else "inference" return f"{name_prefix}_{name}_{suffix}.onnx" + def _save_model(model: onnx.ModelProto, file_path: str): onnx.save(model, file_path) + @dataclass class ONNXModels: """Encapsulates all ORTModule onnx models. @@ -32,12 +35,16 @@ class ONNXModels: def save_exported_model(self, path, name_prefix, export_mode): # save the ortmodule exported model - _save_model(self.exported_model, - os.path.join(path, _get_onnx_file_name(name_prefix, 'torch_exported', export_mode))) + _save_model( + self.exported_model, os.path.join(path, _get_onnx_file_name(name_prefix, "torch_exported", export_mode)) + ) def save_optimized_model(self, path, name_prefix, export_mode): # save the ortmodule optimized model - _save_model(self.optimized_model, - os.path.join(path, _get_onnx_file_name(name_prefix, 'optimized', export_mode))) - _save_model(self.optimized_pre_grad_model, - os.path.join(path, _get_onnx_file_name(name_prefix, 'optimized_pre_grad', export_mode))) + _save_model( + self.optimized_model, os.path.join(path, _get_onnx_file_name(name_prefix, "optimized", export_mode)) + ) + _save_model( + self.optimized_pre_grad_model, + os.path.join(path, _get_onnx_file_name(name_prefix, "optimized_pre_grad", export_mode)), + ) diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py index 769f44affb47e..6d7a9db2433a0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_interface.py @@ -7,7 +7,7 @@ from typing import Iterator, Optional, Tuple, TypeVar, Callable -T = TypeVar('T', bound='torch.nn.Module') +T = TypeVar("T", bound="torch.nn.Module") class TorchModuleInterface: @@ -52,11 +52,10 @@ def is_training(self): def train(self: T, mode: bool = True) -> T: raise NotImplementedError(f"train is not implemented for {type(self)}.") - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): raise NotImplementedError(f"state_dict is not implemented for {type(self)}.") - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', - strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): raise NotImplementedError(f"load_state_dict is not implemented for {type(self)}.") def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: @@ -74,17 +73,18 @@ def get_buffer(self, target: str) -> torch.Tensor: def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: raise NotImplementedError(f"parameters is not implemented for {type(self)}.") - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: raise NotImplementedError(f"named_parameters is not implemented for {type(self)}.") def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: raise NotImplementedError(f"buffers is not implemented for {type(self)}.") - def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: raise NotImplementedError(f"named_buffers is not implemented for {type(self)}.") - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): raise NotImplementedError(f"_load_from_state_dict is not implemented for {type(self)}.") def named_children(self) -> Iterator[Tuple[str, T]]: @@ -99,5 +99,5 @@ def named_modules(self, *args, **kwargs): def _replicate_for_data_parallel(self): raise NotImplementedError(f"_replicate_for_data_parallel is not implemented for {type(self)}.") - def add_module(self, name: str, module: Optional['Module']) -> None: + def add_module(self, name: str, module: Optional["Module"]) -> None: raise NotImplementedError(f"add_module is not implemented for {type(self)}.") diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py index b64821e3996ba..ea0676b12587c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_ort.py @@ -12,7 +12,7 @@ from typing import Iterator, Optional, Tuple, TypeVar, Callable -T = TypeVar('T', bound='torch.nn.Module') +T = TypeVar("T", bound="torch.nn.Module") class TorchModuleORT(TorchModuleInterface): @@ -51,22 +51,19 @@ def train(self: T, mode: bool = True) -> T: self._flattened_module.train(mode) return self - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): """Override original method to delegate execution to the original PyTorch user module""" # Override the state_dict() method so that the state dict key names # do not contain the flattened_module._original_module prefix - return self._original_module.state_dict( - destination=destination, prefix=prefix, keep_vars=keep_vars) + return self._original_module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', - strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): """Override original method to delegate execution to the original PyTorch user module""" # Override the load_state_dict() method so that the loaded state dict # key names does not need to contain the _module.flattened_module._original_module prefix - return self._original_module.load_state_dict( - state_dict, strict=strict) + return self._original_module.load_state_dict(state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: """Override original method to delegate execution to the original PyTorch user module""" @@ -93,7 +90,7 @@ def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: yield from self._original_module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: """Override original method to delegate execution to the original PyTorch user module""" yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse) @@ -103,13 +100,14 @@ def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: yield from self._original_module.buffers(recurse=recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: """Override original method to delegate execution to the original PyTorch user module""" yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): """Override original method to delegate execution to the original PyTorch user module""" # PyTorch load_state_dict implementation does not recursively call load_state_dict on its sub-modules. @@ -117,8 +115,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # For the scenario where an ORTModule is a sub-module of another module, loading of the state # dictionary requires the _load_from_state_dict to be overridden to prevent an error. - self._original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + self._original_module._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def named_children(self) -> Iterator[Tuple[str, T]]: """Override original method to delegate execution to the original PyTorch user module""" @@ -138,13 +137,18 @@ def named_modules(self, *args, **kwargs): yield from self._original_module.named_modules(*args, **kwargs) def _replicate_for_data_parallel(self): - raise wrap_exception(ORTModuleTorchModelException, - NotImplementedError("ORTModule is not compatible with torch.nn.DataParallel. " - "Please use torch.nn.parallel.DistributedDataParallel instead.")) - - def add_module(self, name: str, module: Optional['Module']) -> None: - raise wrap_exception(ORTModuleTorchModelException, - NotImplementedError("ORTModule does not support adding modules to it.")) + raise wrap_exception( + ORTModuleTorchModelException, + NotImplementedError( + "ORTModule is not compatible with torch.nn.DataParallel. " + "Please use torch.nn.parallel.DistributedDataParallel instead." + ), + ) + + def add_module(self, name: str, module: Optional["Module"]) -> None: + raise wrap_exception( + ORTModuleTorchModelException, NotImplementedError("ORTModule does not support adding modules to it.") + ) @TorchModuleInterface.module.getter def module(self): diff --git a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py index 29066738f3dee..44a43b2429e1c 100644 --- a/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py +++ b/orttraining/orttraining/python/training/ortmodule/_torch_module_pytorch.py @@ -9,11 +9,10 @@ from typing import Iterator, Optional, Tuple, TypeVar, Callable -T = TypeVar('T', bound='torch.nn.Module') +T = TypeVar("T", bound="torch.nn.Module") class TorchModulePytorch(TorchModuleInterface): - def __init__(self, module: torch.nn.Module): super().__init__(module) self._original_module = module @@ -33,10 +32,10 @@ def train(self: T, mode: bool = True) -> T: self._original_module.train(mode) return self - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): return self._original_module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - def load_state_dict(self, state_dict: 'OrderedDict[str, torch.Tensor]', strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, torch.Tensor]", strict: bool = True): return self._original_module.load_state_dict(state_dict, strict=strict) def register_buffer(self, name: str, tensor: Optional[torch.Tensor], persistent: bool = True) -> None: @@ -54,19 +53,21 @@ def get_buffer(self, target: str) -> torch.Tensor: def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: yield from self._original_module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: yield from self._original_module.named_parameters(prefix=prefix, recurse=recurse) def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: yield from self._original_module.buffers(recurse=recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: yield from self._original_module.named_buffers(prefix=prefix, recurse=recurse) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - self._original_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + self._original_module._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) def named_children(self) -> Iterator[Tuple[str, T]]: yield from self._original_module.named_children() @@ -82,7 +83,7 @@ def named_modules(self, *args, **kwargs): def _replicate_for_data_parallel(self): return self._original_module._replicate_for_data_parallel() - def add_module(self, name: str, module: Optional['Module']) -> None: + def add_module(self, name: str, module: Optional["Module"]) -> None: self._original_module.add_module(name, module) @TorchModuleInterface.module.getter diff --git a/orttraining/orttraining/python/training/ortmodule/_training_manager.py b/orttraining/orttraining/python/training/ortmodule/_training_manager.py index 513bee9342a07..17fa40326bed0 100644 --- a/orttraining/orttraining/python/training/ortmodule/_training_manager.py +++ b/orttraining/orttraining/python/training/ortmodule/_training_manager.py @@ -3,19 +3,11 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from . import (_utils, - _io, - _logger, - _are_deterministic_algorithms_enabled, - _use_deterministic_algorithms) -from ._graph_execution_manager import (GraphExecutionManager, - _RunStateInfo, - _SkipCheck) +from . import _utils, _io, _logger, _are_deterministic_algorithms_enabled, _use_deterministic_algorithms +from ._graph_execution_manager import GraphExecutionManager, _RunStateInfo, _SkipCheck from ._execution_agent import TrainingAgent from .debug_options import DebugOptions -from ._fallback import (ORTModuleFallbackException, - _FallbackPolicy, - _FallbackManager) +from ._fallback import ORTModuleFallbackException, _FallbackPolicy, _FallbackManager from onnxruntime.capi import _pybind_state as C from onnxruntime.capi.onnxruntime_inference_collection import get_ort_device_type @@ -50,7 +42,7 @@ def execution_session_run_forward(execution_session, onnx_model, device, gradien # TODO: Non-contiguous tensor input in execution_session_run_forward, need tensor copy. if not input.is_contiguous(): input = input.contiguous() - if input.device.type == 'ort': + if input.device.type == "ort": forward_inputs.push_back(C.aten_ort_tensor_to_ort_value(input)) else: valid_ort_tensor = _utils._torch_tensor_to_dlpack(input) @@ -68,21 +60,20 @@ def execution_session_run_forward(execution_session, onnx_model, device, gradien return user_outputs, run_info def _create_autofunction_class(self): - class _ORTModuleFunction(torch.autograd.Function): - '''Use a custom torch.autograd.Function to associate self.backward_graph as the - gradient implementation for self.forward_graph.''' + """Use a custom torch.autograd.Function to associate self.backward_graph as the + gradient implementation for self.forward_graph.""" @staticmethod def forward(ctx, *inputs): - '''Performs forward pass based on user input and PyTorch initializer + """Performs forward pass based on user input and PyTorch initializer Autograd Function's apply() doesn't support keyword arguments, so `*inputs` has all the arguments - keyword arguments converted to positional/keywords during `TrainingManager.forward`. Module outputs are returned to the user - ''' + """ if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: # Assert that the input and model device match @@ -93,7 +84,8 @@ def forward(ctx, *inputs): self._onnx_models.optimized_model, self._device, self._gradient_accumulation_manager, - *inputs) + *inputs, + ) # Disable materializing grads then None object will not be # converted to a tensor filled with zeros prior to calling backward. @@ -120,9 +112,9 @@ def forward(ctx, *inputs): @staticmethod def backward(ctx, *grad_outputs): - '''Performs backward pass based on grad wrt module output''' + """Performs backward pass based on grad wrt module output""" - assert ctx.run_info is not None, 'forward() or __call__() methods must be called before backward()' + assert ctx.run_info is not None, "forward() or __call__() methods must be called before backward()" if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE) is False: _utils._check_same_device(self._device, "Input argument to backward", *grad_outputs) @@ -136,11 +128,12 @@ def backward(ctx, *grad_outputs): backward_inputs.reserve(len(grad_outputs)) for idx, grad_output in enumerate(grad_outputs): if idx in self._graph_info.output_grad_indices_non_differentiable: - assert grad_output is None, "ORT found the {}-th module output '{}' is " \ - "non-differentiable according to the onnx graph. " \ - "However, the gradient value is still provided by " \ - "PyTorch's autograd engine." \ - .format(idx, self._graph_info.user_output_names[idx]) + assert grad_output is None, ( + "ORT found the {}-th module output '{}' is " + "non-differentiable according to the onnx graph. " + "However, the gradient value is still provided by " + "PyTorch's autograd engine.".format(idx, self._graph_info.user_output_names[idx]) + ) continue if grad_output is None: @@ -148,14 +141,15 @@ def backward(ctx, *grad_outputs): if idx in self._graph_info.output_grad_indices_require_full_shape: grad_output = torch.zeros(shape, device=device, dtype=dtype) else: - grad_output = torch.tensor(0., device=device, dtype=dtype) + grad_output = torch.tensor(0.0, device=device, dtype=dtype) elif not grad_output.is_contiguous(): grad_output = grad_output.contiguous() - if grad_output.device.type == 'ort': + if grad_output.device.type == "ort": backward_inputs.push_back(C.aten_ort_tensor_to_ort_value(grad_output)) else: - backward_inputs.push_back(_utils._torch_tensor_to_dlpack(grad_output), - grad_output.dtype is torch.bool) + backward_inputs.push_back( + _utils._torch_tensor_to_dlpack(grad_output), grad_output.dtype is torch.bool + ) backward_inputs.shrink_to_fit() # Run and get results @@ -173,12 +167,12 @@ def backward(ctx, *grad_outputs): return _ORTModuleFunction def forward(self, *inputs, **kwargs): - '''Forward pass starts here and continues at `_ORTModuleFunction.forward` + """Forward pass starts here and continues at `_ORTModuleFunction.forward` ONNX model is exported the first time this method is executed. Next, we build a full training graph with module_graph_builder. Finally, we instantiate the ONNX Runtime InferenceSession. - ''' + """ # Fallback to PyTorch due to failures *external* to forward(), # typically from initialization @@ -186,20 +180,28 @@ def forward(self, *inputs, **kwargs): return self._fallback_manager.fallback(self._debug_options.logging.log_level, *inputs, **kwargs) try: - if self._first_skip_check_warning is True and self._skip_check.is_disabled() is False \ - and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING: + if ( + self._first_skip_check_warning is True + and self._skip_check.is_disabled() is False + and self._debug_options.logging.log_level <= _logger.LogLevel.WARNING + ): # Only change this after the firs time a warning is issued. self._first_skip_check_warning = False - warnings.warn(f"Fast path enabled - skipping checks." - f" Rebuild graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)}," - f" Execution agent: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)}," - f" Device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}", UserWarning) + warnings.warn( + f"Fast path enabled - skipping checks." + f" Rebuild graph: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT)}," + f" Execution agent: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT)}," + f" Device check: {self._skip_check.is_set(_SkipCheck.SKIP_CHECK_DEVICE)}", + UserWarning, + ) # If exporting module to ONNX for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. build_gradient_graph = False - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False or \ - not self._onnx_models.exported_model: + if ( + self._skip_check.is_set(_SkipCheck.SKIP_CHECK_BUILD_GRADIENT) is False + or not self._onnx_models.exported_model + ): build_gradient_graph = self._export_model(*inputs, **kwargs) if build_gradient_graph: # If model was exported, then initialize the graph builder @@ -208,17 +210,14 @@ def forward(self, *inputs, **kwargs): # since the schema was just extracted while trying to export the model and it was either # saved to self._input_info.schema or checked for equality with the self._input_info.schema # it should not need to be updated again. Pass it inside parse_inputs_for_onnx_export. - input_info = _io.parse_inputs_for_onnx_export(self._module_parameters, - self._onnx_models.exported_model, - self._input_info.schema, - inputs, - kwargs) + input_info = _io.parse_inputs_for_onnx_export( + self._module_parameters, self._onnx_models.exported_model, self._input_info.schema, inputs, kwargs + ) # Reinitialize graph builder if the inputs or initializers requiring gradient have changed. # Order of or operation is important here because we always need to call # _reinitialize_graph_builder irrespective of the value of build_gradient_graph. - build_gradient_graph = self._reinitialize_graph_builder( - input_info) or build_gradient_graph + build_gradient_graph = self._reinitialize_graph_builder(input_info) or build_gradient_graph # Build the gradient graph if build_gradient_graph: @@ -227,13 +226,15 @@ def forward(self, *inputs, **kwargs): # If creating the execution agent for the first time, this skip check will not take effect. # It will only take effect on subsequent forward calls. create_execution_session = False - if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or \ - not self._execution_agent: - device = _utils.get_device_from_module(self._original_module) or \ - _utils.get_device_from_inputs(inputs, kwargs) - create_execution_session = (build_gradient_graph or self._device != device or - torch.are_deterministic_algorithms_enabled() is not - _are_deterministic_algorithms_enabled()) + if self._skip_check.is_set(_SkipCheck.SKIP_CHECK_EXECUTION_AGENT) is False or not self._execution_agent: + device = _utils.get_device_from_module(self._original_module) or _utils.get_device_from_inputs( + inputs, kwargs + ) + create_execution_session = ( + build_gradient_graph + or self._device != device + or torch.are_deterministic_algorithms_enabled() is not _are_deterministic_algorithms_enabled() + ) _use_deterministic_algorithms(torch.are_deterministic_algorithms_enabled()) if self._device != device: self._device = device @@ -242,31 +243,36 @@ def forward(self, *inputs, **kwargs): # Create execution session creates the training_session self._create_execution_agent() - self._gradient_accumulation_manager.initialize(self._enable_grad_acc_optimization, - self._flattened_module, - self._graph_info) + self._gradient_accumulation_manager.initialize( + self._enable_grad_acc_optimization, self._flattened_module, self._graph_info + ) self._gradient_accumulation_manager.maybe_update_cache_before_run() - return _io.unflatten_user_output(self._module_output_schema, - self._forward_class.apply( - *_io._combine_input_buffers_initializers( - self._graph_initializers, - self._graph_info.user_input_names, - self._input_info, - self._flattened_module.named_buffers(), - inputs, - kwargs, - self._device))) + return _io.unflatten_user_output( + self._module_output_schema, + self._forward_class.apply( + *_io._combine_input_buffers_initializers( + self._graph_initializers, + self._graph_info.user_input_names, + self._input_info, + self._flattened_module.named_buffers(), + inputs, + kwargs, + self._device, + ) + ), + ) except ORTModuleFallbackException as e: # Exceptions subject to fallback are handled here - self._fallback_manager.handle_exception(exception=e, - log_level=self._debug_options.logging.log_level) + self._fallback_manager.handle_exception(exception=e, log_level=self._debug_options.logging.log_level) except Exception as e: # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here - self._fallback_manager.handle_exception(exception=e, - log_level=self._debug_options.logging.log_level, - override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD) + self._fallback_manager.handle_exception( + exception=e, + log_level=self._debug_options.logging.log_level, + override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD, + ) # Fallback to PyTorch due to failures *during* forward(), # (e.g. export, model/input post-processing, forward, output processing, etc) @@ -279,9 +285,11 @@ def _build_graph(self): super()._build_graph() if self._debug_options.save_onnx_models.save: - self._onnx_models.save_optimized_model(self._debug_options.save_onnx_models.path, - self._debug_options.save_onnx_models.name_prefix, - self._export_mode) + self._onnx_models.save_optimized_model( + self._debug_options.save_onnx_models.path, + self._debug_options.save_onnx_models.name_prefix, + self._export_mode, + ) def _create_execution_agent(self): """Creates a TrainingAgent that can run the forward and backward graph on the training model""" @@ -289,36 +297,41 @@ def _create_execution_agent(self): session_options, providers, provider_options = self._get_session_config() fw_feed_names = [input.name for input in self._onnx_models.optimized_model.graph.input] device_type = self._device if type(self._device) is str else self._device.type.lower() - if device_type == 'ort': - fw_outputs_device_info = [C.get_ort_device(self._device.index)] * (len(self._graph_info.user_output_names) + - len(self._graph_info.frontier_node_arg_map)) + if device_type == "ort": + fw_outputs_device_info = [C.get_ort_device(self._device.index)] * ( + len(self._graph_info.user_output_names) + len(self._graph_info.frontier_node_arg_map) + ) else: fw_outputs_device_info = [ - C.OrtDevice(get_ort_device_type(self._device), - C.OrtDevice.default_memory(), - _utils.get_device_index(self._device) - )] * (len(self._graph_info.user_output_names) + - len(self._graph_info.frontier_node_arg_map)) + C.OrtDevice( + get_ort_device_type(self._device), + C.OrtDevice.default_memory(), + _utils.get_device_index(self._device), + ) + ] * (len(self._graph_info.user_output_names) + len(self._graph_info.frontier_node_arg_map)) bw_fetches_names = [output.name for output in self._onnx_models.optimized_model.graph.output] - if device_type == 'ort': - bw_outputs_device_info = [ - C.get_ort_device(self._device.index)] * len(bw_fetches_names) + if device_type == "ort": + bw_outputs_device_info = [C.get_ort_device(self._device.index)] * len(bw_fetches_names) else: bw_outputs_device_info = [ - C.OrtDevice(get_ort_device_type(self._device), - C.OrtDevice.default_memory(), - _utils.get_device_index(self._device) - )] * len(bw_fetches_names) - - self._execution_agent = TrainingAgent(self._onnx_models.optimized_model.SerializeToString(), - fw_feed_names, - fw_outputs_device_info, - bw_fetches_names, - bw_outputs_device_info, - session_options, - providers, - provider_options) + C.OrtDevice( + get_ort_device_type(self._device), + C.OrtDevice.default_memory(), + _utils.get_device_index(self._device), + ) + ] * len(bw_fetches_names) + + self._execution_agent = TrainingAgent( + self._onnx_models.optimized_model.SerializeToString(), + fw_feed_names, + fw_outputs_device_info, + bw_fetches_names, + bw_outputs_device_info, + session_options, + providers, + provider_options, + ) def _reinitialize_graph_builder(self, input_info): """Return true if the module graph builder was reinitialized""" @@ -326,14 +339,18 @@ def _reinitialize_graph_builder(self, input_info): # Model may have unused params dropped after export and not part of self._graph_initializer_names_to_train # To see if any trainable initializers changed, compare self._graph_initializer_names_to_train # with initializers in module named_parameters that are known to the onnx graph. - initializer_names_to_train_set_user_model = {name for name, param in - self._flattened_module.named_parameters() - if param.requires_grad and name in self._graph_initializer_names} + initializer_names_to_train_set_user_model = { + name + for name, param in self._flattened_module.named_parameters() + if param.requires_grad and name in self._graph_initializer_names + } # If inputs requiring gradient change from forward to the next, the module_gradient_graph_builder # needs to be reinitialized so it can compute the backward output for the new inputs that require_grad - if input_info.require_grad_names != self._input_info.require_grad_names or \ - initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train: + if ( + input_info.require_grad_names != self._input_info.require_grad_names + or initializer_names_to_train_set_user_model != self._graph_initializer_names_to_train + ): self._input_info = input_info self._initialize_graph_builder(training=True) return True @@ -345,7 +362,7 @@ def __getstate__(self): # Only top level classes are pickleable. So, _ORTModuleFunction is # not pickleable. So, let's not pickle it, and redefine it when # loading the state. - del state['_forward_class'] + del state["_forward_class"] return state def __setstate__(self, state): diff --git a/orttraining/orttraining/python/training/ortmodule/_utils.py b/orttraining/orttraining/python/training/ortmodule/_utils.py index 58706ff3bba68..e6f43c064a0ff 100644 --- a/orttraining/orttraining/python/training/ortmodule/_utils.py +++ b/orttraining/orttraining/python/training/ortmodule/_utils.py @@ -6,8 +6,7 @@ from onnxruntime.capi.onnxruntime_inference_collection import OrtValue from onnxruntime.capi import _pybind_state as C from onnxruntime.tools import pytorch_export_contrib_ops -from ._fallback_exceptions import ( - ORTModuleDeviceException, wrap_exception, ORTModuleIOError) +from ._fallback_exceptions import ORTModuleDeviceException, wrap_exception, ORTModuleIOError from ._torch_module_pytorch import TorchModulePytorch from ._custom_op_symbolic_registry import CustomOpSymbolicRegistry from ._custom_gradient_registry import CustomGradientRegistry @@ -30,6 +29,7 @@ import random import numpy as np + def get_random_states(): r_state = random.getstate() np_state = np.random.get_state() @@ -37,6 +37,7 @@ def get_random_states(): torch_cuda_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None return r_state, np_state, torch_state, torch_cuda_state + def set_random_states(states): r_state, np_state, torch_state, torch_cuda_state = states random.setstate(r_state) @@ -53,7 +54,7 @@ def _ortvalue_from_torch_tensor(torch_tensor): # DLPack is discussing how to support bool type, we can remove this workaround once both DLPack # and PyTorch support bool type. is_bool_tensor = torch_tensor.dtype == torch.bool - if is_bool_tensor and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'): + if is_bool_tensor and LooseVersion(torch.__version__) >= LooseVersion("1.10.0"): torch_tensor = torch_tensor.to(torch.uint8) return C.OrtValue.from_dlpack(to_dlpack(torch_tensor), is_bool_tensor) @@ -62,8 +63,8 @@ def _ortvalues_to_torch_tensor(ortvalues, device): if len(ortvalues) == 0: return tuple() - if 'ort' == device.type: - if not hasattr(C, 'to_aten_ort_device_tensor'): + if "ort" == device.type: + if not hasattr(C, "to_aten_ort_device_tensor"): raise AttributeError("onnxruntime is missing to_aten_ort_device_tensor needed to support device == 'ort'.") return tuple(C.to_aten_ort_device_tensor(ov) for ov in ortvalues) @@ -96,8 +97,8 @@ def _ortvalues_to_torch_tensor_list(ortvalues, device, c_class=False): if len(ortvalues) == 0: return tuple() - if 'ort' == device.type: - if not hasattr(C, 'to_aten_ort_device_tensor'): + if "ort" == device.type: + if not hasattr(C, "to_aten_ort_device_tensor"): raise AttributeError("onnxruntime is missing to_aten_ort_device_tensor needed to support device == 'ort'.") return tuple(C.to_aten_ort_device_tensor(ov) for ov in ortvalues) @@ -125,24 +126,26 @@ def _torch_tensor_to_dlpack(tensor): # DLPack is discussing how to support bool type, we can remove this workaround once both DLPack # and PyTorch support bool type. if not tensor.is_contiguous(): - raise ORTModuleIOError( - "Only contiguous tensors are supported.") - if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion('1.10.0'): + raise ORTModuleIOError("Only contiguous tensors are supported.") + if tensor.dtype == torch.bool and LooseVersion(torch.__version__) >= LooseVersion("1.10.0"): tensor = tensor.to(torch.uint8) return to_dlpack(tensor) def _check_same_device(device, argument_str, *args): - '''Check that all tensor arguments in *args reside on the same device as the input device''' + """Check that all tensor arguments in *args reside on the same device as the input device""" - assert isinstance(device, torch.device), '`device` must be a valid `torch.device` object' + assert isinstance(device, torch.device), "`device` must be a valid `torch.device` object" for arg in args: if arg is not None and isinstance(arg, torch.Tensor): arg_device = torch.device(arg.device) if arg_device != device: - raise wrap_exception(ORTModuleDeviceException, - RuntimeError( - f"{argument_str} found on device {arg_device}, but expected it to be on module device {device}.")) + raise wrap_exception( + ORTModuleDeviceException, + RuntimeError( + f"{argument_str} found on device {arg_device}, but expected it to be on module device {device}." + ), + ) def get_device_index(device): @@ -157,36 +160,37 @@ def get_device_index(device): def get_device_str(device): if isinstance(device, str): # could be 'cuda:0', 'cuda:1', or 'cpu'. with cpu, set index=0 - if device.find(':') == -1: - device += ':' + str(torch.cuda.current_device()) + if device.find(":") == -1: + device += ":" + str(torch.cuda.current_device()) elif isinstance(device, int): - device = 'cuda:' + str(device) + device = "cuda:" + str(device) elif isinstance(device, torch.device): if device.index is None: - device = device.type + ':' + str(torch.cuda.current_device()) + device = device.type + ":" + str(torch.cuda.current_device()) else: - device = device.type + ':' + str(device.index) + device = device.type + ":" + str(device.index) else: - raise wrap_exception(ORTModuleDeviceException, RuntimeError('Unsupported device type')) + raise wrap_exception(ORTModuleDeviceException, RuntimeError("Unsupported device type")) return device def get_device_from_module(module): - '''Returns the first device found in the `module`'s parameters or None + """Returns the first device found in the `module`'s parameters or None Args: module (torch.nn.Module): PyTorch model to extract device from Raises: ORTModuleFallbackException: When more than one device is found at `module` - ''' + """ device = None try: device = next(module.parameters()).device for param in module.parameters(): if param.device != device: - raise wrap_exception(ORTModuleDeviceException, - RuntimeError('ORTModule supports a single device per model')) + raise wrap_exception( + ORTModuleDeviceException, RuntimeError("ORTModule supports a single device per model") + ) except StopIteration: # Model doesn't have a device set to any of the model parameters pass @@ -194,12 +198,12 @@ def get_device_from_module(module): def get_device_from_inputs(args, kwargs): - '''Returns device from first PyTorch Tensor within args or kwargs + """Returns device from first PyTorch Tensor within args or kwargs Args: args: List with inputs kwargs: Dictionary with inputs - ''' + """ device = None if args: @@ -210,15 +214,15 @@ def get_device_from_inputs(args, kwargs): def _create_iobinding(io_binding, inputs, model, device): - '''Creates IO binding for a `model` inputs and output''' + """Creates IO binding for a `model` inputs and output""" for idx, value_info in enumerate(model.graph.input): io_binding.bind_ortvalue_input(value_info.name, OrtValue(_ortvalue_from_torch_tensor(inputs[idx]))) for value_info in model.graph.output: io_binding.bind_output(value_info.name, device.type, device_id=get_device_index(device)) -def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.Module, - user_module: torch.nn.Module): + +def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn.Module, user_module: torch.nn.Module): """Warns if there are any common attributes between the user's model and ORTModule and binds user methods to ORTModule If there are methods defined on the user's model that ORTModule does not recognize (custom methods), @@ -240,25 +244,29 @@ def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn. for attribute_name, attribute in user_module_attributes: if inspect.ismethod(attribute): # Skip the dunder methods - if attribute_name.startswith('__'): + if attribute_name.startswith("__"): continue # if the attribute is not a torch attribute, or if the torch attribute # corresponding to attribute_name is not a method or the user attribute # does not equal the torch attribute, then this is a user defined method. - if attribute_name not in torch_module_attributes or \ - not inspect.ismethod(torch_module_attributes[attribute_name]) or \ - attribute.__func__ != torch_module_attributes[attribute_name].__func__: + if ( + attribute_name not in torch_module_attributes + or not inspect.ismethod(torch_module_attributes[attribute_name]) + or attribute.__func__ != torch_module_attributes[attribute_name].__func__ + ): # forward is expected to be defined by the user. - if attribute_name == 'forward': + if attribute_name == "forward": continue # This is a user defined/overriden method. Check for collisions. if attribute_name in ortmodule_attributes: # This is a user defined method, issue a warning. - warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. " - "User Module's method may not be called upon invocation through ORTModule.") + warnings.warn( + f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. " + "User Module's method may not be called upon invocation through ORTModule." + ) else: # This is a custom method, copy it and bind the copy to ORTModule. # This is needed for cases where the user's custom method invokes @@ -269,8 +277,11 @@ def check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule: torch.nn. if attribute_name not in torch_module_attributes and attribute_name in ortmodule_attributes: # This is a user defined attribute that collides with ORTModule if attribute_name in ortmodule_attributes: - warnings.warn(f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. " - "User Module's attribute may not be returned when trying to retrieve the attribute through ORTModule.") + warnings.warn( + f"User Module's attribute name {attribute_name} collides with ORTModule's attribute name. " + "User Module's attribute may not be returned when trying to retrieve the attribute through ORTModule." + ) + def get_state_after_deletion_of_non_ortmodule_methods(ortmodule, user_module): """Returns ORTModule state after deleting any user defined method from ORTModule state""" @@ -286,20 +297,22 @@ def get_state_after_deletion_of_non_ortmodule_methods(ortmodule, user_module): for attribute_name, attribute in user_module_attributes: if inspect.ismethod(attribute): # Skip the dunder methods - if attribute_name.startswith('__'): + if attribute_name.startswith("__"): continue # if the attribute is not a torch attribute, and if the attribute # corresponding to attribute_name is an ORTModule method and the user attribute # does equals the ORTModule attribute, then this is a user defined method and # must be dropped. - if attribute_name not in torch_module_attributes and \ - attribute_name in ortmodule_attributes and \ - inspect.ismethod(ortmodule_attributes[attribute_name]) and \ - attribute.__func__ == ortmodule_attributes[attribute_name].__func__: + if ( + attribute_name not in torch_module_attributes + and attribute_name in ortmodule_attributes + and inspect.ismethod(ortmodule_attributes[attribute_name]) + and attribute.__func__ == ortmodule_attributes[attribute_name].__func__ + ): # forward is expected to be defined by the user. - if attribute_name == 'forward': + if attribute_name == "forward": continue # This is a custom method, drop it from ORTModule state before serialization. @@ -307,19 +320,22 @@ def get_state_after_deletion_of_non_ortmodule_methods(ortmodule, user_module): return ortmodule_state + def parse_os_env_skip_check_flags(env_name): """Returns a list of SkipChecks as defined by os env variable env_name""" - return os.getenv(env_name).split('|') + return os.getenv(env_name).split("|") + def get_exception_as_string(exception): - assert isinstance(exception, Exception), 'exception must be a `Exception`' + assert isinstance(exception, Exception), "exception must be a `Exception`" try: raise exception except: return traceback.format_exc() + def switch_backend_to_pytorch(ortmodule, pytorch_module): ortmodule._torch_module = TorchModulePytorch(pytorch_module) @@ -337,47 +353,50 @@ def switch_backend_to_pytorch(ortmodule, pytorch_module): ortmodule._modules = pytorch_module._modules ortmodule.forward = pytorch_module.forward + def warn_of_constant_inputs(data): - warnings.warn(f"Received input of type {type(data)} which may be treated as a constant by ORT by default." - " Please consider moving constant arguments to the model constructor.") + warnings.warn( + f"Received input of type {type(data)} which may be treated as a constant by ORT by default." + " Please consider moving constant arguments to the model constructor." + ) + def patch_torch_module_ort_forward_method(torch_module_ort): def _forward(self, *inputs, **kwargs): - '''Forward pass starts here and continues at `_ORTModuleFunction.forward` + """Forward pass starts here and continues at `_ORTModuleFunction.forward` ONNX model is exported the first time this method is executed. Next, we build a full training graph with module_gradient_graph_builder. Finally, we instantiate the ONNX Runtime InferenceSession. - ''' + """ - return torch_module_ort._execution_manager( - torch_module_ort.is_training()).forward(*inputs, **kwargs) + return torch_module_ort._execution_manager(torch_module_ort.is_training()).forward(*inputs, **kwargs) # Bind the forward method. torch_module_ort.forward = _forward.__get__(torch_module_ort) # Copy the forward signature from the PyTorch module. - functools.update_wrapper( - torch_module_ort.forward.__func__, torch_module_ort._original_module.forward.__func__) + functools.update_wrapper(torch_module_ort.forward.__func__, torch_module_ort._original_module.forward.__func__) + def patch_ortmodule_forward_method(ortmodule): # Create forward dynamically, so each ORTModule instance will have its own copy. # This is needed to be able to copy the forward signatures from the original PyTorch models # and possibly have different signatures for different instances. def _forward(self, *inputs, **kwargs): - '''Forward pass starts here and continues at `_ORTModuleFunction.forward` + """Forward pass starts here and continues at `_ORTModuleFunction.forward` ONNX model is exported the first time this method is executed. Next, we build a full training graph with module_gradient_graph_builder. Finally, we instantiate the ONNX Runtime InferenceSession. - ''' + """ return ortmodule._torch_module.forward(*inputs, **kwargs) # Bind the forward method. ortmodule.forward = _forward.__get__(ortmodule) # Copy the forward signature from the _torch_module's forward signature. - functools.update_wrapper( - ortmodule.forward.__func__, ortmodule._torch_module.forward.__func__) + functools.update_wrapper(ortmodule.forward.__func__, ortmodule._torch_module.forward.__func__) + def reinitialize_ortmodule(ortmodule): # Re-register contrib OPs @@ -391,13 +410,15 @@ def reinitialize_ortmodule(ortmodule): # Re-bind users custom methods to ORTModule check_for_name_collisions_and_bind_methods_to_ortmodule(ortmodule, ortmodule.module) + def reinitialize_torch_module_ort(torch_module): # Re-initialize the forward method patch_torch_module_ort_forward_method(torch_module) + def reinitialize_graph_execution_manager(graph_execution_manager): # Instantiate the onnx models so they can populated on the first call to forward - if hasattr(graph_execution_manager, '_onnx_models'): + if hasattr(graph_execution_manager, "_onnx_models"): del graph_execution_manager._onnx_models graph_execution_manager._onnx_models = _onnx_models.ONNXModels() @@ -411,6 +432,7 @@ def reinitialize_graph_execution_manager(graph_execution_manager): # Load ATen op executor extension. load_aten_op_executor_cpp_extension() + def reinitialize_training_manager(training_manager): # Redefine training managers forward_class training_manager._forward_class = training_manager._create_autofunction_class() diff --git a/orttraining/orttraining/python/training/ortmodule/debug_options.py b/orttraining/orttraining/python/training/ortmodule/debug_options.py index 7c2f3844e536b..700789668a313 100644 --- a/orttraining/orttraining/python/training/ortmodule/debug_options.py +++ b/orttraining/orttraining/python/training/ortmodule/debug_options.py @@ -11,7 +11,7 @@ class _SaveOnnxOptions: """Configurable option to save ORTModule intermediate onnx models.""" # class variable - _path_environment_key = 'ORTMODULE_SAVE_ONNX_PATH' + _path_environment_key = "ORTMODULE_SAVE_ONNX_PATH" def __init__(self, save, name_prefix): self._save, self._name_prefix, self._path = self._extract_info(save, name_prefix) @@ -27,7 +27,9 @@ def _extract_info(self, save, name_prefix): def _validate(self, save, name_prefix, destination_path): # check if directory is writable if not os.access(destination_path, os.W_OK): - raise OSError(f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path.") + raise OSError( + f"Directory {destination_path} is not writable. Please set the {_SaveOnnxOptions._path_environment_key} environment variable to a writable path." + ) # check if input prefix is a string if not isinstance(name_prefix, str): @@ -54,7 +56,7 @@ class _LoggingOptions: """Configurable option to set the log level in ORTModule.""" # class variable - _log_level_environment_key = 'ORTMODULE_LOG_LEVEL' + _log_level_environment_key = "ORTMODULE_LOG_LEVEL" def __init__(self, log_level): self._log_level = self._extract_info(log_level) @@ -75,6 +77,7 @@ def _validate(self, log_level): def log_level(self): return self._log_level + class DebugOptions: """Configurable debugging options for ORTModule. @@ -98,10 +101,7 @@ class DebugOptions: """ - def __init__(self, - log_level=LogLevel.WARNING, - save_onnx=False, - onnx_prefix=''): + def __init__(self, log_level=LogLevel.WARNING, save_onnx=False, onnx_prefix=""): self._save_onnx_models = _SaveOnnxOptions(save_onnx, onnx_prefix) self._logging = _LoggingOptions(log_level) diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py index ad2af5df6b346..650a51ecb1b55 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/hierarchical_ortmodule/_hierarchical_ortmodule.py @@ -10,10 +10,13 @@ # nn.Module's in this set are considered exportable to ONNX. # For other nn.Module's, torch.onnx.export is called to check if # they are exportable. -_force_exportable_set = set([torch.nn.Linear, torch.nn.Identity, torch.nn.modules.linear.NonDynamicallyQuantizableLinear]) +_force_exportable_set = set( + [torch.nn.Linear, torch.nn.Identity, torch.nn.modules.linear.NonDynamicallyQuantizableLinear] +) + class HierarchicalORTModule(torch.nn.Module): - ''' + """ Recursively wraps submodules of `module` as ORTModule whenever possible Similarly to ORTModule, the actual wrapping happens in its first `forward` call during Pytorch-to-ONNX export. Supported computation is delegated to ONNX Runtime and unsupported computation is still done by PyTorch. @@ -46,7 +49,7 @@ def custom_forward(x_): m = HierarchicalORTModule(Foo()) y = m(x) - ''' + """ def __init__(self, module, debug_options=None): self._initialized = False @@ -97,12 +100,17 @@ def check_exportable(module): # Check if this leaf module is exportable. for args in module_arg_pool[module]: try: - with tempfile.NamedTemporaryFile(prefix='sub-module') as temp: + with tempfile.NamedTemporaryFile(prefix="sub-module") as temp: torch.onnx.export( - module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION, - do_constant_folding=False, export_params=False, + module, + args, + temp, + opset_version=ortmodule.ONNX_OPSET_VERSION, + do_constant_folding=False, + export_params=False, keep_initializers_as_inputs=True, - training=torch.onnx.TrainingMode.TRAINING) + training=torch.onnx.TrainingMode.TRAINING, + ) except Exception as e: exportable = False @@ -131,12 +139,17 @@ def check_exportable(module): module_exportable = True for args in module_arg_pool[module]: try: - with tempfile.NamedTemporaryFile(prefix='sub-module') as temp: + with tempfile.NamedTemporaryFile(prefix="sub-module") as temp: torch.onnx.export( - module, args, temp, opset_version=ortmodule.ONNX_OPSET_VERSION, - do_constant_folding=False, export_params=False, + module, + args, + temp, + opset_version=ortmodule.ONNX_OPSET_VERSION, + do_constant_folding=False, + export_params=False, keep_initializers_as_inputs=True, - training=torch.onnx.TrainingMode.TRAINING) + training=torch.onnx.TrainingMode.TRAINING, + ) except Exception as e: # If this module is not exportable for one arg # group, we say this module is not exportable. @@ -179,17 +192,13 @@ def recursive_wrap(module): # Let's wrap them one-by-one. for name1, sub1 in sub._modules.items(): if is_supported(sub1): - sub._modules[name1] = ORTModule( - sub1, - debug_options=self._debug_options) + sub._modules[name1] = ORTModule(sub1, debug_options=self._debug_options) else: recursive_wrap(sub1) else: if is_supported(sub): # Just wrap it as ORTModule when possible. - sub_dict[name] = ORTModule( - sub, - debug_options=self._debug_options) + sub_dict[name] = ORTModule(sub, debug_options=self._debug_options) else: # This sub-module is not exportable to ONNX # Let's check its sub-modules. diff --git a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py index 371ffe3feea20..e896b9f6e6085 100644 --- a/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py +++ b/orttraining/orttraining/python/training/ortmodule/experimental/json_config/_load_config_from_json.py @@ -38,8 +38,9 @@ def _load_propagate_cast_ops(ortmodule_config_accessor, data): log.info(f"Found keyword {_load_propagate_cast_ops.loading_key} in json. Loading attributes from file.") def _update_strategy(): - ortmodule_config_accessor._propagate_cast_ops_strategy = \ - C.PropagateCastOpsStrategy.__members__[data.PropagateCastOps.Strategy] + ortmodule_config_accessor._propagate_cast_ops_strategy = C.PropagateCastOpsStrategy.__members__[ + data.PropagateCastOps.Strategy + ] def _update_level(): ortmodule_config_accessor._propagate_cast_ops_level = data.PropagateCastOps.Level @@ -47,11 +48,7 @@ def _update_level(): def _update_allow(): ortmodule_config_accessor._propagate_cast_ops_allow = data.PropagateCastOps.Allow - key_to_function_mapping = { - "Strategy": _update_strategy, - "Level": _update_level, - "Allow": _update_allow - } + key_to_function_mapping = {"Strategy": _update_strategy, "Level": _update_level, "Allow": _update_allow} for key, _ in data.PropagateCastOps.__dict__.items(): key_to_function_mapping[key]() @@ -63,7 +60,9 @@ def _load_use_external_gpu_allocator(ortmodule_config_accessor, data): assert hasattr(data, _load_use_external_gpu_allocator.loading_key) log.info(f"Found keyword {_load_use_external_gpu_allocator.loading_key} in json. Loading attributes from file.") - assert isinstance(data.UseExternalGPUAllocator, bool), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean" + assert isinstance( + data.UseExternalGPUAllocator, bool + ), f"{_load_use_external_gpu_allocator.loading_key} must be a boolean" ortmodule_config_accessor._use_external_gpu_allocator = data.UseExternalGPUAllocator ortmodule_config_accessor._get_torch_gpu_allocator_function_addresses() @@ -72,9 +71,13 @@ def _load_enable_custom_autograd_function(ortmodule_config_accessor, data): """Loads EnableCustomAutogradFunction from json file onto ORTModule.""" assert hasattr(data, _load_enable_custom_autograd_function.loading_key) - log.info(f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file.") + log.info( + f"Found keyword {_load_enable_custom_autograd_function.loading_key} in json. Loading attributes from file." + ) - assert isinstance(data.EnableCustomAutogradFunction, bool), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean" + assert isinstance( + data.EnableCustomAutogradFunction, bool + ), f"{_load_enable_custom_autograd_function.loading_key} must be a boolean" ortmodule_config_accessor._enable_custom_autograd_function = data.EnableCustomAutogradFunction @@ -84,7 +87,9 @@ def _load_allow_layer_norm_mod_precision(ortmodule_config_accessor, data): assert hasattr(data, _load_allow_layer_norm_mod_precision.loading_key) log.info(f"Found keyword {_load_allow_layer_norm_mod_precision.loading_key} in json. Loading attributes from file.") - assert isinstance(data.AllowLayerNormModPrecision, bool), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean" + assert isinstance( + data.AllowLayerNormModPrecision, bool + ), f"{_load_allow_layer_norm_mod_precision.loading_key} must be a boolean" ortmodule_config_accessor._allow_layer_norm_mod_precision = data.AllowLayerNormModPrecision @@ -94,7 +99,9 @@ def _load_enable_grad_acc_optimization(ortmodule_config_accessor, data): assert hasattr(data, _load_enable_grad_acc_optimization.loading_key) log.info(f"Found keyword {_load_enable_grad_acc_optimization.loading_key} in json. Loading attributes from file.") - assert isinstance(data.EnableGradAccOptimization, bool), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean" + assert isinstance( + data.EnableGradAccOptimization, bool + ), f"{_load_enable_grad_acc_optimization.loading_key} must be a boolean" ortmodule_config_accessor._enable_grad_acc_optimization = data.EnableGradAccOptimization @@ -104,7 +111,9 @@ def _load_run_symbolic_shape_infer(ortmodule_config_accessor, data): assert hasattr(data, _load_run_symbolic_shape_infer.loading_key) log.info(f"Found keyword {_load_run_symbolic_shape_infer.loading_key} in json. Loading attributes from file.") - assert isinstance(data.RunSymbolicShapeInference, bool), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean" + assert isinstance( + data.RunSymbolicShapeInference, bool + ), f"{_load_run_symbolic_shape_infer.loading_key} must be a boolean" ortmodule_config_accessor._run_symbolic_shape_infer = data.RunSymbolicShapeInference @@ -147,7 +156,7 @@ def _update_save_onnx(): nonlocal save_onnx save_onnx = data.DebugOptions.SaveONNX - onnx_prefix = '' + onnx_prefix = "" def _update_onnx_prefix(): nonlocal onnx_prefix @@ -160,7 +169,7 @@ def _update_onnx_path(): "LogLevel": _update_log_level, "SaveONNX": _update_save_onnx, "ONNXPrefix": _update_onnx_prefix, - "SaveONNXPath": _update_onnx_path + "SaveONNXPath": _update_onnx_path, } for key, _ in data.DebugOptions.__dict__.items(): @@ -176,7 +185,9 @@ def _load_use_memory_efficient_gradient(ortmodule_config_accessor, data): assert hasattr(data, _load_use_memory_efficient_gradient.loading_key) log.info(f"Found keyword {_load_use_memory_efficient_gradient.loading_key} in json. Loading attributes from file.") - assert isinstance(data.UseMemoryEfficientGradient, bool), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean" + assert isinstance( + data.UseMemoryEfficientGradient, bool + ), f"{_load_use_memory_efficient_gradient.loading_key} must be a boolean" ortmodule_config_accessor._use_memory_efficient_gradient = data.UseMemoryEfficientGradient @@ -274,7 +285,8 @@ def load_from_json(ortmodule, path=None): if path is None: raise ValueError( "Path to json is not provided." - f"Provide the path through function call or setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}") + f"Provide the path through function call or setting the environment variable {JSON_PATH_ENVIRONMENT_KEY}" + ) # load the entire json file data = _load_data_from_json(path) @@ -295,7 +307,7 @@ def load_from_json(ortmodule, path=None): _load_debug_options.loading_key: _load_debug_options, _load_use_memory_efficient_gradient.loading_key: _load_use_memory_efficient_gradient, _load_fallback_policy.loading_key: _load_fallback_policy, - _load_onnx_opset_version.loading_key: _load_onnx_opset_version + _load_onnx_opset_version.loading_key: _load_onnx_opset_version, } for training_mode in [True, False]: diff --git a/orttraining/orttraining/python/training/ortmodule/ortmodule.py b/orttraining/orttraining/python/training/ortmodule/ortmodule.py index d2b79e1ab71f5..c91bdb74596fa 100644 --- a/orttraining/orttraining/python/training/ortmodule/ortmodule.py +++ b/orttraining/orttraining/python/training/ortmodule/ortmodule.py @@ -10,9 +10,7 @@ from ._custom_gradient_registry import CustomGradientRegistry from . import _utils from .debug_options import DebugOptions -from ._fallback import (_FallbackManager, - _FallbackPolicy, - ORTModuleFallbackException) +from ._fallback import _FallbackManager, _FallbackPolicy, ORTModuleFallbackException from onnxruntime.training import ortmodule from onnxruntime.tools import pytorch_export_contrib_ops @@ -22,7 +20,7 @@ # Needed to override PyTorch methods -T = TypeVar('T', bound='Module') +T = TypeVar("T", bound="Module") class ORTModule(torch.nn.Module): @@ -56,9 +54,9 @@ def __init__(self, module, debug_options=None): debug_options = DebugOptions() # Fallback settings - self._fallback_manager = _FallbackManager(pytorch_module=module, - policy=ortmodule.ORTMODULE_FALLBACK_POLICY, - retry=ortmodule.ORTMODULE_FALLBACK_RETRY) + self._fallback_manager = _FallbackManager( + pytorch_module=module, policy=ortmodule.ORTMODULE_FALLBACK_POLICY, retry=ortmodule.ORTMODULE_FALLBACK_RETRY + ) try: # Read ORTModule module initialization status @@ -87,17 +85,18 @@ def __init__(self, module, debug_options=None): _utils.switch_backend_to_pytorch(self, module) # Exceptions subject to fallback are handled here - self._fallback_manager.handle_exception(exception=e, - log_level=debug_options.logging.log_level) + self._fallback_manager.handle_exception(exception=e, log_level=debug_options.logging.log_level) except Exception as e: # Although backend is switched to PyTorch here, # it is up to _FallbackManager to actually terminate execution or fallback _utils.switch_backend_to_pytorch(self, module) # Catch-all FALLBACK_FORCE_TORCH_FORWARD fallback is handled here - self._fallback_manager.handle_exception(exception=e, - log_level=debug_options.logging.log_level, - override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD) + self._fallback_manager.handle_exception( + exception=e, + log_level=debug_options.logging.log_level, + override_policy=_FallbackPolicy.FALLBACK_FORCE_TORCH_FORWARD, + ) # Finally, ORTModule initialization is complete. # Assign self._is_initialized to True after all the ORTModule class attributes have been assigned @@ -108,7 +107,7 @@ def __init__(self, module, debug_options=None): # This declaration is for automatic document generation purposes only # The actual forward implementation is bound during ORTModule initialization def forward(self, *inputs, **kwargs): - '''Delegate the :meth:`~torch.nn.Module.forward` pass of PyTorch training to + """Delegate the :meth:`~torch.nn.Module.forward` pass of PyTorch training to ONNX Runtime. The first call to forward performs setup and checking steps. During this call, @@ -125,7 +124,7 @@ def forward(self, *inputs, **kwargs): The output as expected from the forward method defined by the user's PyTorch module. Output values supported include tensors, nested sequences of tensors and nested dictionaries of tensor values. - ''' + """ def _replicate_for_data_parallel(self): """Raises a NotImplementedError exception since ORTModule is not compatible with torch.nn.DataParallel @@ -145,7 +144,7 @@ def _replicate_for_data_parallel(self): return self._torch_module._replicate_for_data_parallel() - def add_module(self, name: str, module: Optional['Module']) -> None: + def add_module(self, name: str, module: Optional["Module"]) -> None: """Raises a ORTModuleTorchModelException exception since ORTModule does not support adding modules to it""" self._torch_module.add_module(name, module) @@ -176,7 +175,7 @@ def _apply(self, fn): self._torch_module._apply(fn) return self - def apply(self: T, fn: Callable[['Module'], None]) -> T: + def apply(self: T, fn: Callable[["Module"], None]) -> T: """Override :meth:`~torch.nn.Module.apply` to delegate execution to ONNX Runtime""" self._torch_module.apply(fn) @@ -198,14 +197,12 @@ def train(self: T, mode: bool = True) -> T: self._torch_module.train(mode) return self - def state_dict(self, destination=None, prefix='', keep_vars=False): + def state_dict(self, destination=None, prefix="", keep_vars=False): """Override :meth:`~torch.nn.Module.state_dict` to delegate execution to ONNX Runtime""" - return self._torch_module.state_dict( - destination=destination, prefix=prefix, keep_vars=keep_vars) + return self._torch_module.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars) - def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', - strict: bool = True): + def load_state_dict(self, state_dict: "OrderedDict[str, Tensor]", strict: bool = True): """Override :meth:`~torch.nn.Module.load_state_dict` to delegate execution to ONNX Runtime""" return self._torch_module.load_state_dict(state_dict, strict=strict) @@ -235,7 +232,7 @@ def parameters(self, recurse: bool = True) -> Iterator[torch.nn.Parameter]: yield from self._torch_module.parameters(recurse=recurse) - def named_parameters(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: + def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.nn.Parameter]]: """Override :meth:`~torch.nn.Module.named_parameters`""" yield from self._torch_module.named_parameters(prefix=prefix, recurse=recurse) @@ -245,24 +242,26 @@ def buffers(self, recurse: bool = True) -> Iterator[torch.Tensor]: yield from self._torch_module.buffers(recurse=recurse) - def named_buffers(self, prefix: str = '', recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: + def named_buffers(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, torch.Tensor]]: """Override :meth:`~torch.nn.Module.named_buffers`""" yield from self._torch_module.named_buffers(prefix=prefix, recurse=recurse) - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): """Override original method to delegate execution to the original PyTorch user module""" - self._torch_module._load_from_state_dict(state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs) + self._torch_module._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) - def named_children(self) -> Iterator[Tuple[str, 'Module']]: + def named_children(self) -> Iterator[Tuple[str, "Module"]]: """Override :meth:`~torch.nn.Module.named_children`""" yield from self._torch_module.named_children() - def modules(self) -> Iterator['Module']: + def modules(self) -> Iterator["Module"]: """Override :meth:`~torch.nn.Module.modules`""" yield from self._torch_module.modules() @@ -273,11 +272,11 @@ def named_modules(self, *args, **kwargs): yield from self._torch_module.named_modules(*args, **kwargs) def __getattr__(self, name: str): - if '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] is True: + if "_is_initialized" in self.__dict__ and self.__dict__["_is_initialized"] is True: # If ORTModule is initialized and attribute is not found in ORTModule, # it must be present in the user's torch.nn.Module. Forward the call to # the user's model. - assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model" + assert "_torch_module" in self.__dict__, "ORTModule does not have a reference to the user's model" return getattr(self.module, name) else: return super(ORTModule, self).__getattr__(name) @@ -288,9 +287,9 @@ def __setattr__(self, name: str, value) -> None: # If the name is an attribute of ORTModule, update only ORTModule self.__dict__[name] = value - elif '_is_initialized' in self.__dict__ and self.__dict__['_is_initialized'] is True: + elif "_is_initialized" in self.__dict__ and self.__dict__["_is_initialized"] is True: - assert '_torch_module' in self.__dict__, "ORTModule does not have a reference to the user's model" + assert "_torch_module" in self.__dict__, "ORTModule does not have a reference to the user's model" # If the name is an attribute of user model, or is a new attribute, update there. # Set the attribute on the user's original module diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py index dbb5e4fcee26e..765f33dd9a50a 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/__init__.py @@ -33,7 +33,7 @@ def is_installed(torch_cpp_extension_path): - torch_cpp_exts = glob(os.path.join(torch_cpp_extension_path, '*.so')) - torch_cpp_exts.extend(glob(os.path.join(torch_cpp_extension_path, '*.dll'))) - torch_cpp_exts.extend(glob(os.path.join(torch_cpp_extension_path, '*.dylib'))) + torch_cpp_exts = glob(os.path.join(torch_cpp_extension_path, "*.so")) + torch_cpp_exts.extend(glob(os.path.join(torch_cpp_extension_path, "*.dll"))) + torch_cpp_exts.extend(glob(os.path.join(torch_cpp_extension_path, "*.dylib"))) return len(torch_cpp_exts) > 0 diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py index a17f4369da812..7a7695167443d 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/__init__.py @@ -10,6 +10,7 @@ def run_once_aten_op_executor(f): :param f: function to be run only once during execution time despite the number of calls :return: The original function with the params passed to it if it hasn't already been run before """ + @wraps(f) def aten_op_executor_wrapper(*args, **kwargs): if not aten_op_executor_wrapper.has_run: @@ -26,5 +27,7 @@ def aten_op_executor_wrapper(*args, **kwargs): @run_once_aten_op_executor def load_aten_op_executor_cpp_extension(): from onnxruntime.training.ortmodule.torch_cpp_extensions import aten_op_executor - C.register_aten_op_executor(str(aten_op_executor.is_tensor_argument_address()), - str(aten_op_executor.execute_aten_operator_address())) + + C.register_aten_op_executor( + str(aten_op_executor.is_tensor_argument_address()), str(aten_op_executor.execute_aten_operator_address()) + ) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py index efe2841f90193..5485170be84ce 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/aten_op_executor/setup.py @@ -7,9 +7,9 @@ from setuptools import setup, Extension from torch.utils import cpp_extension -filename = os.path.join(os.path.dirname(__file__), - 'aten_op_executor.cc') -setup(name='aten_op_executor', - ext_modules=[cpp_extension.CppExtension(name='aten_op_executor', - sources=[filename])], - cmdclass={'build_ext': cpp_extension.BuildExtension}) +filename = os.path.join(os.path.dirname(__file__), "aten_op_executor.cc") +setup( + name="aten_op_executor", + ext_modules=[cpp_extension.CppExtension(name="aten_op_executor", sources=[filename])], + cmdclass={"build_ext": cpp_extension.BuildExtension}, +) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/__init__.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/__init__.py index 7082e0d948c69..d6e78073f05d2 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/__init__.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/__init__.py @@ -1,9 +1,11 @@ - def clear_all_grad_fns(): from onnxruntime.training.ortmodule.torch_cpp_extensions import torch_interop_utils + torch_interop_utils.clear_all_grad_fns() + import atexit + # Clear all gradient functions, to avoid a deadlock issue. # Check the called function for more detailed comments. atexit.register(clear_all_grad_fns) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py index cb9bff33c0a81..42fc1c747a6c3 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cpu/torch_interop_utils/setup.py @@ -7,9 +7,9 @@ from setuptools import setup, Extension from torch.utils import cpp_extension -filename = os.path.join(os.path.dirname(__file__), - 'torch_interop_utils.cc') -setup(name='torch_interop_utils', - ext_modules=[cpp_extension.CppExtension(name='torch_interop_utils', - sources=[filename])], - cmdclass={'build_ext': cpp_extension.BuildExtension}) +filename = os.path.join(os.path.dirname(__file__), "torch_interop_utils.cc") +setup( + name="torch_interop_utils", + ext_modules=[cpp_extension.CppExtension(name="torch_interop_utils", sources=[filename])], + cmdclass={"build_ext": cpp_extension.BuildExtension}, +) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py index 9d7297702643a..d0202f4637796 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/setup.py @@ -10,23 +10,22 @@ from setuptools import setup from torch.utils import cpp_extension -filenames = [os.path.join(os.path.dirname(__file__), 'fused_ops_frontend.cpp'), - os.path.join(os.path.dirname(__file__), 'multi_tensor_adam.cu'), - os.path.join(os.path.dirname(__file__), 'multi_tensor_scale_kernel.cu'), - os.path.join(os.path.dirname(__file__), 'multi_tensor_axpby_kernel.cu')] +filenames = [ + os.path.join(os.path.dirname(__file__), "fused_ops_frontend.cpp"), + os.path.join(os.path.dirname(__file__), "multi_tensor_adam.cu"), + os.path.join(os.path.dirname(__file__), "multi_tensor_scale_kernel.cu"), + os.path.join(os.path.dirname(__file__), "multi_tensor_axpby_kernel.cu"), +] -use_rocm = True if os.environ['ONNXRUNTIME_ROCM_VERSION'] else False -extra_compile_args = { - 'cxx': ['-O3'] -} +use_rocm = True if os.environ["ONNXRUNTIME_ROCM_VERSION"] else False +extra_compile_args = {"cxx": ["-O3"]} if not use_rocm: - extra_compile_args.update({ - 'nvcc': ['-lineinfo', '-O3', '--use_fast_math'] - }) + extra_compile_args.update({"nvcc": ["-lineinfo", "-O3", "--use_fast_math"]}) -setup(name='fused_ops', - ext_modules=[cpp_extension.CUDAExtension(name='fused_ops', - sources=filenames, - extra_compile_args=extra_compile_args - )], - cmdclass={'build_ext': cpp_extension.BuildExtension}) +setup( + name="fused_ops", + ext_modules=[ + cpp_extension.CUDAExtension(name="fused_ops", sources=filenames, extra_compile_args=extra_compile_args) + ], + cmdclass={"build_ext": cpp_extension.BuildExtension}, +) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py index 6a36e16699bab..7a71c95a3b465 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/cuda/torch_gpu_allocator/setup.py @@ -12,20 +12,20 @@ # TODO: Implement a cleaner way to auto-generate torch_gpu_allocator.cc -use_rocm = True if os.environ['ONNXRUNTIME_ROCM_VERSION'] else False +use_rocm = True if os.environ["ONNXRUNTIME_ROCM_VERSION"] else False gpu_identifier = "hip" if use_rocm else "cuda" gpu_allocator_header = "HIPCachingAllocator" if use_rocm else "CUDACachingAllocator" -filename = os.path.join(os.path.dirname(__file__), - 'torch_gpu_allocator.cc') +filename = os.path.join(os.path.dirname(__file__), "torch_gpu_allocator.cc") with fileinput.FileInput(filename, inplace=True) as file: for line in file: - if '___gpu_identifier___' in line: - line = line.replace('___gpu_identifier___', gpu_identifier) - if '___gpu_allocator_header___' in line: - line = line.replace('___gpu_allocator_header___', gpu_allocator_header) + if "___gpu_identifier___" in line: + line = line.replace("___gpu_identifier___", gpu_identifier) + if "___gpu_allocator_header___" in line: + line = line.replace("___gpu_allocator_header___", gpu_allocator_header) sys.stdout.write(line) -setup(name='torch_gpu_allocator', - ext_modules=[cpp_extension.CUDAExtension(name='torch_gpu_allocator', - sources=[filename])], - cmdclass={'build_ext': cpp_extension.BuildExtension}) +setup( + name="torch_gpu_allocator", + ext_modules=[cpp_extension.CUDAExtension(name="torch_gpu_allocator", sources=[filename])], + cmdclass={"build_ext": cpp_extension.BuildExtension}, +) diff --git a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py index 7511266e7a500..fbd11797d229e 100644 --- a/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py +++ b/orttraining/orttraining/python/training/ortmodule/torch_cpp_extensions/install.py @@ -16,75 +16,63 @@ def _list_extensions(path): extensions = [] for root, _, files in os.walk(path): for name in files: - if name.lower() == 'setup.py': + if name.lower() == "setup.py": extensions.append(os.path.join(root, name)) return extensions def _list_cpu_extensions(): - return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cpu')) + return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, "cpu")) def _list_cuda_extensions(): - return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, 'cuda')) + return _list_extensions(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, "cuda")) def _install_extension(ext_name, ext_path, cwd): - ret_code = subprocess.call(f"{sys.executable} {ext_path} build", - cwd=cwd, - shell=True) + ret_code = subprocess.call(f"{sys.executable} {ext_path} build", cwd=cwd, shell=True) if ret_code != 0: print(f'There was an error compiling "{ext_name}" PyTorch CPP extension') sys.exit(ret_code) def build_torch_cpp_extensions(): - '''Builds PyTorch CPP extensions and returns metadata''' + """Builds PyTorch CPP extensions and returns metadata""" # Run this from within onnxruntime package folder - is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or\ - ortmodule.ONNXRUNTIME_ROCM_VERSION is not None + is_gpu_available = ortmodule.ONNXRUNTIME_CUDA_VERSION is not None or ortmodule.ONNXRUNTIME_ROCM_VERSION is not None os.chdir(ortmodule.ORTMODULE_TORCH_CPP_DIR) # Extensions might leverage CUDA/ROCM versions internally - os.environ["ONNXRUNTIME_CUDA_VERSION"] = ortmodule.ONNXRUNTIME_CUDA_VERSION \ - if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else '' - os.environ["ONNXRUNTIME_ROCM_VERSION"] = ortmodule.ONNXRUNTIME_ROCM_VERSION \ - if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else '' + os.environ["ONNXRUNTIME_CUDA_VERSION"] = ( + ortmodule.ONNXRUNTIME_CUDA_VERSION if not ortmodule.ONNXRUNTIME_CUDA_VERSION is None else "" + ) + os.environ["ONNXRUNTIME_ROCM_VERSION"] = ( + ortmodule.ONNXRUNTIME_ROCM_VERSION if not ortmodule.ONNXRUNTIME_ROCM_VERSION is None else "" + ) ############################################################################ # Pytorch CPP Extensions that DO require CUDA/ROCM ############################################################################ if is_gpu_available: for ext_setup in _list_cuda_extensions(): - _install_extension(ext_setup.split( - os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) + _install_extension(ext_setup.split(os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) ############################################################################ # Pytorch CPP Extensions that DO NOT require CUDA/ROCM ############################################################################ for ext_setup in _list_cpu_extensions(): - _install_extension(ext_setup.split( - os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) + _install_extension(ext_setup.split(os.sep)[-2], ext_setup, ortmodule.ORTMODULE_TORCH_CPP_DIR) ############################################################################ # Install Pytorch CPP Extensions into local onnxruntime package folder ############################################################################ - torch_cpp_exts = glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, - 'build', - 'lib.*', - '*.so')) - torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, - 'build', - 'lib.*', - '*.dll'))) - torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, - 'build', - 'lib.*', - '*.dylib'))) + torch_cpp_exts = glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, "build", "lib.*", "*.so")) + torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, "build", "lib.*", "*.dll"))) + torch_cpp_exts.extend(glob(os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, "build", "lib.*", "*.dylib"))) for ext in torch_cpp_exts: dest_ext = os.path.join(ortmodule.ORTMODULE_TORCH_CPP_DIR, os.path.basename(ext)) - print(f'Installing {ext} -> {dest_ext}') + print(f"Installing {ext} -> {dest_ext}") copyfile(ext, dest_ext) # Tear down @@ -92,5 +80,5 @@ def build_torch_cpp_extensions(): os.environ.pop("ONNXRUNTIME_ROCM_VERSION") -if __name__ == '__main__': +if __name__ == "__main__": build_torch_cpp_extensions() diff --git a/orttraining/orttraining/python/training/orttrainer.py b/orttraining/orttraining/python/training/orttrainer.py index af95713b47fd2..ed84582546280 100644 --- a/orttraining/orttraining/python/training/orttrainer.py +++ b/orttraining/orttraining/python/training/orttrainer.py @@ -14,6 +14,7 @@ from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference + class TrainStepInfo(object): r"""Private class used to store runtime information from current train step. @@ -43,16 +44,13 @@ class TrainStepInfo(object): """ def __init__(self, optimizer_config, all_finite=True, fetches=[], optimization_step=0, step=0): - assert isinstance(optimizer_config, optim._OptimizerConfig),\ - "optimizer_config must be a optim._OptimizerConfig" - assert isinstance(all_finite, bool),\ - "all_finite must be a bool" - assert isinstance(fetches, list) and all([isinstance(item, str) for item in fetches]),\ - "fetches must be a list of str" - assert isinstance(optimization_step, int) and optimization_step >= 0,\ - "optimization_step must be a positive int" - assert (isinstance(step, int) and step >= 0),\ - "step must be a positive int" + assert isinstance(optimizer_config, optim._OptimizerConfig), "optimizer_config must be a optim._OptimizerConfig" + assert isinstance(all_finite, bool), "all_finite must be a bool" + assert isinstance(fetches, list) and all( + [isinstance(item, str) for item in fetches] + ), "fetches must be a list of str" + assert isinstance(optimization_step, int) and optimization_step >= 0, "optimization_step must be a positive int" + assert isinstance(step, int) and step >= 0, "step must be a positive int" self.optimizer_config = optimizer_config self.all_finite = all_finite @@ -120,17 +118,18 @@ class ORTTrainer(object): ort_trainer = ORTTrainer(model, model_desc, optim_config, loss_fn) """ - def __init__(self, model, model_desc, optim_config, - loss_fn=None, - options=None): + def __init__(self, model, model_desc, optim_config, loss_fn=None, options=None): assert model is not None, "'model' is required and must be either a 'torch.nn.Module' or ONNX model" assert isinstance(model_desc, dict), "'model_desc' must be a 'dict'" - assert isinstance(optim_config, optim._OptimizerConfig),\ - "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" - assert loss_fn is None or (callable(loss_fn) and len(signature(loss_fn).parameters) == 2),\ - "'loss_fn' must be either 'None' or a callable with two parameters" - assert options is None or isinstance(options, ORTTrainerOptions),\ - "'options' must be either 'None' or 'ORTTrainerOptions'" + assert isinstance( + optim_config, optim._OptimizerConfig + ), "'optim_config' is required and must be any of 'AdamConfig', 'LambConfig' or 'SGDConfig'" + assert loss_fn is None or ( + callable(loss_fn) and len(signature(loss_fn).parameters) == 2 + ), "'loss_fn' must be either 'None' or a callable with two parameters" + assert options is None or isinstance( + options, ORTTrainerOptions + ), "'options' must be either 'None' or 'ORTTrainerOptions'" # Model + Loss validation # Supported combinarios are @@ -144,8 +143,9 @@ def __init__(self, model, model_desc, optim_config, self._torch_model = None self._onnx_model = None if isinstance(model, torch.nn.Module): - assert loss_fn is None or isinstance(model, torch.nn.Module),\ - "'loss_fn' must be either 'None' or 'torch.nn.Module'" + assert loss_fn is None or isinstance( + model, torch.nn.Module + ), "'loss_fn' must be either 'None' or 'torch.nn.Module'" self._torch_model = model self.loss_fn = loss_fn # TODO: Remove when experimental checkpoint functions are removed. @@ -178,13 +178,13 @@ def __init__(self, model, model_desc, optim_config, # When input model is already ONNX (and not exported from Pytorch within ORTTrainer), # append 'dtype' from ONNX into model description's for idx_i, i_desc in enumerate(self.model_desc.inputs): - dtype = None - for onnx_input in self._onnx_model.graph.input: - if onnx_input.name == i_desc.name: - dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type) - self.model_desc.add_type_to_input_description(idx_i, dtype) - break - assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})" + dtype = None + for onnx_input in self._onnx_model.graph.input: + if onnx_input.name == i_desc.name: + dtype = _utils.dtype_onnx_to_torch(onnx_input.type.tensor_type.elem_type) + self.model_desc.add_type_to_input_description(idx_i, dtype) + break + assert dtype is not None, f"ONNX model with unknown input type ({i_desc.name})" for idx_o, o_desc in enumerate(self.model_desc.outputs): dtype = None for onnx_output in self._onnx_model.graph.output: @@ -196,7 +196,8 @@ def __init__(self, model, model_desc, optim_config, try: from torch.utils.cpp_extension import ROCM_HOME - self.is_rocm_pytorch = (True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False) + + self.is_rocm_pytorch = True if ((torch.version.hip is not None) and (ROCM_HOME is not None)) else False except ImportError: self.is_rocm_pytorch = False @@ -206,8 +207,10 @@ def __init__(self, model, model_desc, optim_config, self._train_step_info = TrainStepInfo(self.optim_config) self._training_session = None self._load_state_dict = None - self._init_session(provider_options=self.options._validated_opts['provider_options'], - session_options=self.options.session_options) + self._init_session( + provider_options=self.options._validated_opts["provider_options"], + session_options=self.options.session_options, + ) def eval_step(self, *args, **kwargs): r"""Evaluation step method @@ -220,8 +223,7 @@ def eval_step(self, *args, **kwargs): ordered :py:obj:`list` with model outputs as described by :py:attr:`.ORTTrainer.model_desc` """ # Get data. CombineTorchModelLossFn takes label as last input and outputs loss first - sample_input = self._prepare_model_input(self.model_desc.inputs, - None, None, *args, **kwargs) + sample_input = self._prepare_model_input(self.model_desc.inputs, None, None, *args, **kwargs) # Export model to ONNX if self._onnx_model is None: @@ -248,16 +250,13 @@ def eval_step(self, *args, **kwargs): run_options.training_mode = False # Run a eval step and return - session_run_results = self._training_session_run_helper(False, - sample_input, - inputs_desc, - outputs_desc, - run_options) + session_run_results = self._training_session_run_helper( + False, sample_input, inputs_desc, outputs_desc, run_options + ) # Output must be returned in the same order as defined in the model description results = [session_run_results[o_desc.name] for o_desc in outputs_desc] - return results[0] if len (results) == 1 else results - + return results[0] if len(results) == 1 else results def save_as_onnx(self, path): r"""Persists ONNX model into :py:attr:`path` @@ -273,8 +272,10 @@ def save_as_onnx(self, path): ValueError: raised when `path` is not valid path """ if not self._training_session: - warnings.warn("Training session is not initialized yet. " - "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'.") + warnings.warn( + "Training session is not initialized yet. " + "'train_step' or 'eval_step' methods must be executed at least once before calling 'save_as_onnx()'." + ) return state_tensors = self._training_session.get_state() self._update_onnx_model_initializers(state_tensors) @@ -294,10 +295,11 @@ def _check_model_export(self, input): import numpy as np from numpy.testing import assert_allclose import _test_helpers + onnx_model_copy = copy.deepcopy(self._onnx_model) # Mute the dropout nodes - dropout_nodes = [n for n in onnx_model_copy.graph.node if n.op_type == 'Dropout'] + dropout_nodes = [n for n in onnx_model_copy.graph.node if n.op_type == "Dropout"] for node in dropout_nodes: ratio_node = [n for n in onnx_model_copy.graph.node if node.input[1] in n.output][0] training_mode_node = [n for n in onnx_model_copy.graph.node if node.input[2] in n.output][0] @@ -315,18 +317,25 @@ def _check_model_export(self, input): training_mode_node.attribute[0].name = "value" ratio_node.attribute[0].name = "value" - _inference_sess = ort.InferenceSession(onnx_model_copy.SerializeToString(), providers=ort.get_available_providers()) + _inference_sess = ort.InferenceSession( + onnx_model_copy.SerializeToString(), providers=ort.get_available_providers() + ) inf_inputs = {} for i, input_elem in enumerate(input): inf_inputs[_inference_sess.get_inputs()[i].name] = input_elem.cpu().numpy() _inference_outs = _inference_sess.run(None, inf_inputs) for torch_item, ort_item in zip(self.torch_sample_outputs, _inference_outs): - assert_allclose(torch_item, ort_item, rtol=1e-2, atol=1e-6, - err_msg="Mismatch between outputs of PyTorch model and exported ONNX model. " - "Note that different backends may exhibit small computational differences." - "If this is within acceptable margin, or if there is random generator " - "in the model causing inevitable mismatch, you can proceed training by " - "setting the flag debug.check_model_export to False.") + assert_allclose( + torch_item, + ort_item, + rtol=1e-2, + atol=1e-6, + err_msg="Mismatch between outputs of PyTorch model and exported ONNX model. " + "Note that different backends may exhibit small computational differences." + "If this is within acceptable margin, or if there is random generator " + "in the model causing inevitable mismatch, you can proceed training by " + "setting the flag debug.check_model_export to False.", + ) def train_step(self, *args, **kwargs): r"""Train step method @@ -351,7 +360,6 @@ def train_step(self, *args, **kwargs): if self.options.debug.check_model_export: self._check_model_export(sample_input) - # Prepare inputs+lr and output descriptions inputs_desc = self._model_desc_inputs_with_lr outputs_desc = self.model_desc.outputs @@ -398,8 +406,7 @@ def train_step(self, *args, **kwargs): args = (args,) # Run a train step and return - session_run_results = self._training_session_run_helper(True, input, inputs_desc, - outputs_desc, run_options) + session_run_results = self._training_session_run_helper(True, input, inputs_desc, outputs_desc, run_options) if mixed_precision_without_fetches: # After session run with all_fp32_gradients_finite, we need to clear the training I/O binding's output # Otherwise next run with only_execute_path_to_fetches will lead to gradient all reduce @@ -423,7 +430,7 @@ def train_step(self, *args, **kwargs): results = [session_run_results[o_desc] for o_desc in self._train_step_info.fetches] else: results = [session_run_results[o_desc.name] for o_desc in self.model_desc.outputs] - return results[0] if len (results) == 1 else results + return results[0] if len(results) == 1 else results def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): # Dynamic axes @@ -448,9 +455,13 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): if isinstance(inputs, dict): sample_inputs = [inputs[k.name_].to(device=device) for k in self.model_desc.inputs] elif isinstance(inputs, (list, tuple)): - sample_inputs = [input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs)] + sample_inputs = [ + input.to(device=device) for i, input in enumerate(inputs) if i < len(self.model_desc.inputs) + ] else: - raise RuntimeError("Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported.") + raise RuntimeError( + "Unexpected input type. Only torch.Tensor, or dict/list/tuple of torch.Tensor is supported." + ) # PyTorch ONNX exporter does not match argument names # This is an issue because the ONNX graph depends on all inputs to be specified @@ -468,8 +479,7 @@ def _convert_torch_model_loss_fn_to_onnx(self, inputs, device): # Label from loss_fn goes after model input if self.loss_fn: - ordered_input_list = [*ordered_input_list, - list(sig_loss.parameters.keys())[1]] + ordered_input_list = [*ordered_input_list, list(sig_loss.parameters.keys())[1]] class CombineTorchModelLossFnWrapInput(torch.nn.Module): def __init__(self, model, loss_fn, input_names): @@ -506,8 +516,10 @@ def forward(self, *inputs): model_copy = copy.deepcopy(model) except Exception: model_copy = model - warnings.warn("This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." - " Compute will continue, but unexpected results may occur!") + warnings.warn( + "This model cannot be deep copied (or pickled), which is a required step for stateful models to be properly exported to ONNX." + " Compute will continue, but unexpected results may occur!" + ) sample_outputs = model_copy(*sample_inputs_copy) self.torch_sample_outputs = sample_outputs model.train() @@ -518,12 +530,10 @@ def forward(self, *inputs): # Append 'dtype' for model description's inputs/outputs for idx_i, sample_input in enumerate(sample_inputs): if idx_i < len(self.model_desc.inputs): - self.model_desc.add_type_to_input_description( - idx_i, sample_input.dtype) + self.model_desc.add_type_to_input_description(idx_i, sample_input.dtype) for idx_o, sample_output in enumerate(sample_outputs): if idx_o < len(self.model_desc.outputs): - self.model_desc.add_type_to_output_description( - idx_o, sample_output.dtype) + self.model_desc.add_type_to_output_description(idx_o, sample_output.dtype) # Export the model to ONNX f = io.BytesIO() @@ -533,6 +543,7 @@ def forward(self, *inputs): # Handle contrib OPs support from onnxruntime.tools import pytorch_export_contrib_ops + if self.options._internal_use.enable_onnx_contrib_ops: pytorch_export_contrib_ops.register() else: @@ -540,21 +551,25 @@ def forward(self, *inputs): pytorch_export_contrib_ops.unregister() # Export torch.nn.Module to ONNX - torch.onnx.export(model, tuple(sample_inputs_copy), f, - input_names=[input.name for input in self.model_desc.inputs], - output_names=[output.name for output in self.model_desc.outputs], - opset_version=self.options._internal_use.onnx_opset_version, - dynamic_axes=dynamic_axes, - do_constant_folding=False, - training=torch.onnx.TrainingMode.TRAINING) + torch.onnx.export( + model, + tuple(sample_inputs_copy), + f, + input_names=[input.name for input in self.model_desc.inputs], + output_names=[output.name for output in self.model_desc.outputs], + opset_version=self.options._internal_use.onnx_opset_version, + dynamic_axes=dynamic_axes, + do_constant_folding=False, + training=torch.onnx.TrainingMode.TRAINING, + ) onnx_model = onnx.load_model_from_string(f.getvalue()) # Remove 'model.' prefix introduced by CombineTorchModelLossFn class if isinstance(model, CombineTorchModelLossFnWrapInput): replace_name_dict = {} for n in onnx_model.graph.initializer: - if n.name.startswith('model.'): - replace_name_dict[n.name] = n.name[len('model.'):] + if n.name.startswith("model."): + replace_name_dict[n.name] = n.name[len("model.") :] n.name = replace_name_dict[n.name] for n in onnx_model.graph.node: for i, name in enumerate(n.input): @@ -563,16 +578,17 @@ def forward(self, *inputs): return onnx_model - def _create_ort_training_session(self, - optimizer_state_dict={}, - session_options=None, - provider_options=None): + def _create_ort_training_session(self, optimizer_state_dict={}, session_options=None, provider_options=None): # Validating frozen_weights names - unused_frozen_weights = [n for n in self.options.utils.frozen_weights\ - if n not in [i.name for i in self._onnx_model.graph.initializer]] + unused_frozen_weights = [ + n + for n in self.options.utils.frozen_weights + if n not in [i.name for i in self._onnx_model.graph.initializer] + ] if unused_frozen_weights: - raise RuntimeError("{} params from 'frozen_weights' not found in the ONNX model.".format( - unused_frozen_weights)) + raise RuntimeError( + "{} params from 'frozen_weights' not found in the ONNX model.".format(unused_frozen_weights) + ) # Get loss name from model description loss_name = [item.name for item in self.model_desc.outputs if item.is_loss] @@ -591,12 +607,12 @@ def _create_ort_training_session(self, optimizer_int_attributes_map[initializer.name] = {} not_in_param_groups = True for param_group in self.optim_config.params: - if initializer.name not in param_group['params']: + if initializer.name not in param_group["params"]: continue # keep looking for a matching param_group not_in_param_groups = False for k, v in param_group.items(): # 'params' is not a hyper parameter, skip it. 'lr' per weight is not supported - if k == 'params' or k == 'lr': + if k == "params" or k == "lr": continue if isinstance(v, float): optimizer_attributes_map[initializer.name][k] = v @@ -608,7 +624,7 @@ def _create_ort_training_session(self, # set default values for params not found in groups if not_in_param_groups: for k, v in self.optim_config.defaults.items(): - if k == 'lr': + if k == "lr": continue if isinstance(v, float): optimizer_attributes_map[initializer.name][k] = v @@ -618,7 +634,9 @@ def _create_ort_training_session(self, raise ValueError("Optimizer attributes must be either float or int.") self.options.distributed.horizontal_parallel_size = max(self.options.distributed.horizontal_parallel_size, 1) - self.options.distributed.data_parallel_size = self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size + self.options.distributed.data_parallel_size = ( + self.options.distributed.world_size // self.options.distributed.horizontal_parallel_size + ) # TrainingParameters ort_parameters = ort.TrainingParameters() @@ -649,7 +667,9 @@ def _create_ort_training_session(self, ort_parameters.data_parallel_size = self.options.distributed.data_parallel_size ort_parameters.horizontal_parallel_size = self.options.distributed.horizontal_parallel_size ort_parameters.pipeline_parallel_size = self.options.distributed.pipeline_parallel.pipeline_parallel_size - ort_parameters.num_pipeline_micro_batches = self.options.distributed.pipeline_parallel.num_pipeline_micro_batches + ort_parameters.num_pipeline_micro_batches = ( + self.options.distributed.pipeline_parallel.num_pipeline_micro_batches + ) ort_parameters.pipeline_cut_info_string = self.options.distributed.pipeline_parallel.pipeline_cut_info_string # We have special handling for dictionary-typed option. # sliced_schema._validated_opts is the original dictionary while sliced_schema is a _ORTTrainerOptionsInternal. @@ -659,19 +679,29 @@ def _create_ort_training_session(self, ort_parameters.sliced_axes = self.options.distributed.pipeline_parallel.sliced_axes._validated_opts ort_parameters.sliced_tensor_names = self.options.distributed.pipeline_parallel.sliced_tensor_names - ort_parameters.model_after_graph_transforms_path = self.options.debug.graph_save_paths.model_after_graph_transforms_path - ort_parameters.model_with_gradient_graph_path = self.options.debug.graph_save_paths.model_with_gradient_graph_path - ort_parameters.model_with_training_graph_path = self.options.debug.graph_save_paths.model_with_training_graph_path + ort_parameters.model_after_graph_transforms_path = ( + self.options.debug.graph_save_paths.model_after_graph_transforms_path + ) + ort_parameters.model_with_gradient_graph_path = ( + self.options.debug.graph_save_paths.model_with_gradient_graph_path + ) + ort_parameters.model_with_training_graph_path = ( + self.options.debug.graph_save_paths.model_with_training_graph_path + ) # SessionOptions session_options = ort.SessionOptions() if session_options is None else session_options session_options.use_deterministic_compute = self.options.debug.deterministic_compute - if (self.options.graph_transformer.attn_dropout_recompute or - self.options.graph_transformer.gelu_recompute or - self.options.graph_transformer.transformer_layer_recompute): + if ( + self.options.graph_transformer.attn_dropout_recompute + or self.options.graph_transformer.gelu_recompute + or self.options.graph_transformer.transformer_layer_recompute + ): session_options.execution_order = ort.ExecutionOrder.PRIORITY_BASED if len(self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path) > 0: - session_options.optimized_model_filepath = self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path + session_options.optimized_model_filepath = ( + self.options.debug.graph_save_paths.model_with_training_graph_after_optimization_path + ) # old ort session may already exists and occupies GPU memory when creating new session, this may cause OOM error. # for example, load_state_dict will be called before returing the function, and it calls _init_session again @@ -685,25 +715,28 @@ def get_providers(provider_options): providers[providers.index(provider_name)] = (provider_name, provider_options[provider_name]) else: providers.insert(0, (provider_name, provider_options[provider_name])) - #default: using cuda - elif 'cuda' in self.options.device.id.lower(): + # default: using cuda + elif "cuda" in self.options.device.id.lower(): gpu_ep_options = {"device_id": _utils.get_device_index(self.options.device.id)} - gpu_ep_name = ("ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider") + gpu_ep_name = "ROCMExecutionProvider" if self.is_rocm_pytorch else "CUDAExecutionProvider" if self.options.device.mem_limit > 0: gpu_ep_options["gpu_mem_limit"] = self.options.device.mem_limit if gpu_ep_name not in providers: raise RuntimeError( "ORTTrainer options specify a CUDA device but the {} provider is unavailable.".format( - cuda_ep_name)) + cuda_ep_name + ) + ) providers[providers.index(gpu_ep_name)] = (gpu_ep_name, gpu_ep_options) return providers # TrainingSession - self._training_session = ort.TrainingSession(self._onnx_model.SerializeToString(), ort_parameters, - session_options, get_providers(provider_options)) + self._training_session = ort.TrainingSession( + self._onnx_model.SerializeToString(), ort_parameters, session_options, get_providers(provider_options) + ) # I/O bindings self._train_io_binding = self._training_session.io_binding() @@ -722,7 +755,7 @@ def _init_onnx_model(self, inputs): self.options.utils.frozen_weights.extend(torch_buffers) # Export to ONNX - self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs, 'cpu') + self._onnx_model = self._convert_torch_model_loss_fn_to_onnx(inputs, "cpu") # Post processing for ONNX models expported from PyTorch if self.options._internal_use.enable_internal_postprocess: @@ -734,31 +767,36 @@ def _init_onnx_model(self, inputs): if self._load_state_dict: optimizer_state_dict = self._load_state_dict() - self._init_session(optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts['provider_options']) + self._init_session( + optimizer_state_dict, + session_options=self.options.session_options, + provider_options=self.options._validated_opts["provider_options"], + ) - def _init_session(self, optimizer_state_dict={}, - session_options=None, - provider_options=None): + def _init_session(self, optimizer_state_dict={}, session_options=None, provider_options=None): if self._onnx_model is None: return if self.options.utils.run_symbolic_shape_infer: - self._onnx_model = SymbolicShapeInference.infer_shapes(self._onnx_model, auto_merge=True, guess_output_rank=True) + self._onnx_model = SymbolicShapeInference.infer_shapes( + self._onnx_model, auto_merge=True, guess_output_rank=True + ) # Create training session used by train_step # pass all optimizer states to the backend - self._create_ort_training_session(optimizer_state_dict, - session_options=session_options, - provider_options=provider_options) + self._create_ort_training_session( + optimizer_state_dict, session_options=session_options, provider_options=provider_options + ) # Update model description to update dtype when mixed precision is enabled # C++ backend modifies model's output dtype from float32 to float16 for mixed precision # Note that for training we must use float32 and for evaluation we must use float16 for idx, o_desc in enumerate(self.model_desc.outputs): - if (self.options.mixed_precision.enabled and o_desc.dtype == torch.float32 and - not self._training_session.is_output_fp32_node(o_desc.name)): + if ( + self.options.mixed_precision.enabled + and o_desc.dtype == torch.float32 + and not self._training_session.is_output_fp32_node(o_desc.name) + ): self.model_desc.add_type_to_output_description(idx, o_desc.dtype, torch.float16) # Update model description @@ -768,7 +806,9 @@ def _init_session(self, optimizer_state_dict={}, if self.options.mixed_precision.enabled: self.model_desc.loss_scale_input = self._training_session.loss_scale_input_name self._model_desc_inputs_with_lr_and_loss_scale = [ - *self._model_desc_inputs_with_lr, self.model_desc.loss_scale_input] + *self._model_desc_inputs_with_lr, + self.model_desc.loss_scale_input, + ] self.model_desc.all_finite = _utils.get_all_gradients_finite_name_from_session(self._training_session) self._model_desc_outputs_with_all_finite = [*self.model_desc.outputs, self.model_desc.all_finite] elif self.options.mixed_precision.loss_scaler: @@ -782,9 +822,13 @@ def _init_session(self, optimizer_state_dict={}, # Update Gradient Accumulation, if applicable if self.options.batch.gradient_accumulation_steps > 1: - self.model_desc.gradient_accumulation = _utils.get_gradient_accumulation_name_from_session(self._training_session) + self.model_desc.gradient_accumulation = _utils.get_gradient_accumulation_name_from_session( + self._training_session + ) self._model_desc_outputs_with_gradient_accumulation = [ - *self.model_desc.outputs, self.model_desc.gradient_accumulation] + *self.model_desc.outputs, + self.model_desc.gradient_accumulation, + ] # TODO: Remove when experimental checkpoint functions are removed if self._state_dict: @@ -831,8 +875,7 @@ def _resolve_symbolic_dimensions(self, inputs, inputs_desc, outputs_desc): if i_axis not in resolved_dims: resolved_dims[i_axis] = input.size()[i_idx] else: - assert resolved_dims[i_axis] == input.size()[i_idx],\ - f"Mismatch in dynamic shape {i_axis}" + assert resolved_dims[i_axis] == input.size()[i_idx], f"Mismatch in dynamic shape {i_axis}" for o_desc in outputs: for idx_o, o_axis in enumerate(o_desc.shape): @@ -860,12 +903,14 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de for input, input_desc in zip(inputs, inputs_desc): if input_desc.name in input_node_names: device_index = _utils.get_device_index_from_input(input) - iobinding.bind_input(input_desc.name, - input.device.type, - device_index, - _utils.dtype_torch_to_numpy(input.dtype), - list(input.size()), - input.data_ptr()) + iobinding.bind_input( + input_desc.name, + input.device.type, + device_index, + _utils.dtype_torch_to_numpy(input.dtype), + list(input.size()), + input.data_ptr(), + ) # Bind output tensors outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc) @@ -875,7 +920,7 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de if self.options.mixed_precision.enabled and output_desc.name == self.model_desc.all_finite.name: # Keep all finite flag on CPU to match backend implementation # This prevents CPU -> GPU -> CPU copies between frontend and backend - target_device = 'cpu' + target_device = "cpu" # the self.options.device may be a device that pytorch does not recognize. # in that case, we temporary prefer to leave the input/output on CPU and let ORT session # to move the data between device and host. @@ -883,15 +928,23 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de try: test_pt_device = torch.device(target_device) except: - #in this case, input/output must on CPU - assert(input.device.type == 'cpu') - target_device = 'cpu' - - torch_tensor = torch.zeros(output_desc.shape, device=target_device, - dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype) - iobinding.bind_output(output_desc.name, torch_tensor.device.type, _utils.get_device_index(target_device), - _utils.dtype_torch_to_numpy(torch_tensor.dtype), - list(torch_tensor.size()), torch_tensor.data_ptr()) + # in this case, input/output must on CPU + assert input.device.type == "cpu" + target_device = "cpu" + + torch_tensor = torch.zeros( + output_desc.shape, + device=target_device, + dtype=output_desc.dtype_amp if output_desc.dtype_amp else output_desc.dtype, + ) + iobinding.bind_output( + output_desc.name, + torch_tensor.device.type, + _utils.get_device_index(target_device), + _utils.dtype_torch_to_numpy(torch_tensor.dtype), + list(torch_tensor.size()), + torch_tensor.data_ptr(), + ) result[output_desc.name] = torch_tensor # Run a train/eval step @@ -899,7 +952,7 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de return result def _update_onnx_model_initializers(self, state_tensors): - r""" Updates ONNX graph initializers with state_tensors's values + r"""Updates ONNX graph initializers with state_tensors's values Usually called to save or load an ONNX model. @@ -930,22 +983,28 @@ def _extract_model_states(self, state_dict, pytorch_format): state_dict[_utils.state_dict_model_key()][precision] = {} for model_state_key in model_states[precision]: if pytorch_format: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = \ - torch.from_numpy(model_states[precision][model_state_key]) - else: - state_dict[_utils.state_dict_model_key()][precision][model_state_key] = \ + state_dict[_utils.state_dict_model_key()][precision][model_state_key] = torch.from_numpy( model_states[precision][model_state_key] + ) + else: + state_dict[_utils.state_dict_model_key()][precision][model_state_key] = model_states[precision][ + model_state_key + ] # extract untrained (frozen) model weights for node in self._onnx_model.graph.initializer: - if node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] and \ - node.name in self.options.utils.frozen_weights: + if ( + node.name not in state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()] + and node.name in self.options.utils.frozen_weights + ): if pytorch_format: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][node.name] = \ - torch.from_numpy(onnx.numpy_helper.to_array(node)) + state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ + node.name + ] = torch.from_numpy(onnx.numpy_helper.to_array(node)) else: - state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][node.name] = \ - onnx.numpy_helper.to_array(node) + state_dict[_utils.state_dict_model_key()][_utils.state_dict_full_precision_key()][ + node.name + ] = onnx.numpy_helper.to_array(node) def _extract_trainer_options(self, state_dict): """Extract relevant trainer configuration and load it into the state_dict""" @@ -960,8 +1019,9 @@ def _extract_trainer_options(self, state_dict): state_dict[_utils.state_dict_trainer_options_key()] = {} state_dict[_utils.state_dict_trainer_options_key()][mixed_precision] = self.options.mixed_precision.enabled - state_dict[_utils.state_dict_trainer_options_key()][zero_stage] = \ - self.options.distributed.deepspeed_zero_optimization.stage + state_dict[_utils.state_dict_trainer_options_key()][ + zero_stage + ] = self.options.distributed.deepspeed_zero_optimization.stage state_dict[_utils.state_dict_trainer_options_key()][world_rank] = self.options.distributed.world_rank state_dict[_utils.state_dict_trainer_options_key()][world_size] = self.options.distributed.world_size state_dict[_utils.state_dict_trainer_options_key()][optimizer_name] = self.optim_config.name @@ -1138,9 +1198,11 @@ def state_dict(self, pytorch_format=False): A dictionary with `ORTTrainer` state """ if not self._training_session: - warnings.warn("ONNX Runtime training session is not initialized yet. " - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", - UserWarning) + warnings.warn( + "ONNX Runtime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict().", + UserWarning, + ) return self._load_state_dict.args[0] if self._load_state_dict else {} state_dict = {} @@ -1166,7 +1228,10 @@ def state_dict(self, pytorch_format=False): self._extract_train_step_info(state_dict) # add partition information in case of a distributed run - if self.options.distributed.deepspeed_zero_optimization.stage > 0 or self.options.distributed.horizontal_parallel_size > 1: + if ( + self.options.distributed.deepspeed_zero_optimization.stage > 0 + or self.options.distributed.horizontal_parallel_size > 1 + ): state_dict[_utils.state_dict_partition_info_key()] = self._training_session.get_partition_info_map() return state_dict @@ -1204,8 +1269,9 @@ def _check_optimizer_mismatch(state_dict): # the state_dict optimizer_name can be a byte string (if coming from checkpoint file) # or can be a regular string (coming from user) - optimizer_name = \ - state_dict[_utils.state_dict_trainer_options_key()][_utils.state_dict_trainer_options_optimizer_name_key()] + optimizer_name = state_dict[_utils.state_dict_trainer_options_key()][ + _utils.state_dict_trainer_options_optimizer_name_key() + ] # optimizer_name can be either a regular string or a byte string. # if it is a byte string, convert to regular string using decode() @@ -1214,8 +1280,9 @@ def _check_optimizer_mismatch(state_dict): optimizer_name = optimizer_name.decode() except AttributeError: pass - assert self.optim_config.name == optimizer_name, \ - "Optimizer mismatch: expected {}, got {}".format(self.optim_config.name, optimizer_name) + assert self.optim_config.name == optimizer_name, "Optimizer mismatch: expected {}, got {}".format( + self.optim_config.name, optimizer_name + ) if _utils.state_dict_optimizer_key() not in state_dict: return @@ -1232,9 +1299,9 @@ def _check_optimizer_mismatch(state_dict): if model_state_key not in current_state_dict[_utils.state_dict_optimizer_key()]: current_state_dict[_utils.state_dict_optimizer_key()][model_state_key] = {} for optimizer_state_key, optimizer_state_value in optimizer_dict.items(): - current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][optimizer_state_key] = \ - optimizer_state_value - + current_state_dict[_utils.state_dict_optimizer_key()][model_state_key][ + optimizer_state_key + ] = optimizer_state_value def _load_state_dict_impl(self, state_dict, strict=True): """Load the state dictionary onto the onnx model and on the training session graph""" @@ -1263,28 +1330,41 @@ def _check_model_key_mismatch(current_state_dict, state_dict, allow_unexpected=F # check unxexpected and missing precision keys in the model state_dict compared to the training # session model state_dict - _mismatch_keys(current_state_dict[_utils.state_dict_model_key()], - state_dict[_utils.state_dict_model_key()], 'state_dict[model]', allow_unexpected) + _mismatch_keys( + current_state_dict[_utils.state_dict_model_key()], + state_dict[_utils.state_dict_model_key()], + "state_dict[model]", + allow_unexpected, + ) # check for model state key mismatch for precision_key in current_state_dict[_utils.state_dict_model_key()]: - _mismatch_keys(current_state_dict[_utils.state_dict_model_key()][precision_key], - state_dict[_utils.state_dict_model_key()][precision_key], - 'state_dict[model][{}]'.format(precision_key), allow_unexpected) + _mismatch_keys( + current_state_dict[_utils.state_dict_model_key()][precision_key], + state_dict[_utils.state_dict_model_key()][precision_key], + "state_dict[model][{}]".format(precision_key), + allow_unexpected, + ) def _check_optimizer_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): """Check if there is any mismatch in the optimizer sub state dictionary between the two state_dicts""" # check for model state key mismatch for the optimizer state_dict - _mismatch_keys(current_state_dict[_utils.state_dict_optimizer_key()], - state_dict[_utils.state_dict_optimizer_key()], - 'state_dict[optimizer]', allow_unexpected) + _mismatch_keys( + current_state_dict[_utils.state_dict_optimizer_key()], + state_dict[_utils.state_dict_optimizer_key()], + "state_dict[optimizer]", + allow_unexpected, + ) # check for optimizer state keys mismatch for model_state_key in current_state_dict[_utils.state_dict_optimizer_key()]: - _mismatch_keys(current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], - state_dict[_utils.state_dict_optimizer_key()][model_state_key], - 'state_dict[optimizer][{}]'.format(model_state_key), allow_unexpected) + _mismatch_keys( + current_state_dict[_utils.state_dict_optimizer_key()][model_state_key], + state_dict[_utils.state_dict_optimizer_key()][model_state_key], + "state_dict[optimizer][{}]".format(model_state_key), + allow_unexpected, + ) def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): """Check if there is a mismatch in the keys (model and optimizer) in the two state_dicts""" @@ -1320,8 +1400,11 @@ def _check_key_mismatch(current_state_dict, state_dict, allow_unexpected=False): # dictionary self._load_optimizer_states(current_state_dict, state_dict) - return current_state_dict[_utils.state_dict_optimizer_key()] if \ - _utils.state_dict_optimizer_key() in current_state_dict else {} + return ( + current_state_dict[_utils.state_dict_optimizer_key()] + if _utils.state_dict_optimizer_key() in current_state_dict + else {} + ) def _load_train_step_info(self, state_dict): """Load the train step info settings from state dict""" @@ -1365,9 +1448,11 @@ def load_state_dict(self, state_dict, strict=True): # create a new training session after loading initializer states onto the onnx graph # pass the populated states to the training session to populate the backend graph - self._init_session(optimizer_state_dict, - session_options=self.options.session_options, - provider_options=self.options._validated_opts['provider_options']) + self._init_session( + optimizer_state_dict, + session_options=self.options.session_options, + provider_options=self.options._validated_opts["provider_options"], + ) def save_checkpoint(self, path, user_dict={}, include_optimizer_states=True): """Persists ORTTrainer state dictionary on disk along with user_dict. @@ -1404,8 +1489,10 @@ def _aggregation_required(self, loaded_trainer_options): # To load states in the backend, aggregation is required for every ZeRO # or Megatron checkpoint - return loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 or \ - loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 + return ( + loaded_trainer_options[_utils.state_dict_trainer_options_zero_stage_key()] > 0 + or loaded_trainer_options[_utils.state_dict_trainer_options_horizontal_parallel_size_key()] > 1 + ) def load_checkpoint(self, *paths, strict=True): """Loads the saved checkpoint state dictionary into the ORTTrainer diff --git a/orttraining/orttraining/python/training/orttrainer_options.py b/orttraining/orttraining/python/training/orttrainer_options.py index 44ad78c12fab6..9431e2eaa48ed 100644 --- a/orttraining/orttraining/python/training/orttrainer_options.py +++ b/orttraining/orttraining/python/training/orttrainer_options.py @@ -6,6 +6,7 @@ from . import PropagateCastOpsStrategy import onnxruntime as ort + class ORTTrainerOptions(object): r"""Settings used by ONNX Runtime training backend @@ -389,7 +390,7 @@ class ORTTrainerOptions(object): graph_transformer.propagate_cast_ops_config.allow(list of str, []) List of opcodes to be considered safe to move before/after cast operation if propagate_cast_ops_level is zero. graph_transformer.allow_layer_norm_mod_precision(bool, default False) - Enable LayerNormalization/SimplifiedLayerNormalization fusion + Enable LayerNormalization/SimplifiedLayerNormalization fusion even if it requires modified compute precision attn_dropout_recompute (bool, default is False): enable recomputing attention dropout to save memory @@ -443,8 +444,8 @@ class ORTTrainerOptions(object): This flag may be removed anytime in the future. session_options (onnxruntime.SessionOptions): The SessionOptions instance that TrainingSession will use. - provider_options (dict): - The provider_options for customized execution providers. it is dict map from EP name to + provider_options (dict): + The provider_options for customized execution providers. it is dict map from EP name to a key-value pairs, like {'EP1' : {'key1' : 'val1'}, ....} Example: @@ -465,7 +466,7 @@ class ORTTrainerOptions(object): } }) fp16_enabled = opts.mixed_precision.enabled - """ + """ def __init__(self, options={}): # Keep a copy of original input for debug @@ -476,20 +477,23 @@ def __init__(self, options={}): # Validates user input self._validated_opts = dict(self._original_opts) - validator = ORTTrainerOptionsValidator( - _ORTTRAINER_OPTIONS_SCHEMA) + validator = ORTTrainerOptionsValidator(_ORTTRAINER_OPTIONS_SCHEMA) self._validated_opts = validator.validated(self._validated_opts) if self._validated_opts is None: - raise ValueError(f'Invalid options: {validator.errors}') + raise ValueError(f"Invalid options: {validator.errors}") # Convert dict in object for k, v in self._validated_opts.items(): setattr(self, k, self._wrap(v)) def __repr__(self): - return '{%s}' % str(', '.join("'%s': %s" % (k, repr(v)) - for (k, v) in self.__dict__.items() - if k not in ['_original_opts', '_validated_opts', '_main_class_name'])) + return "{%s}" % str( + ", ".join( + "'%s': %s" % (k, repr(v)) + for (k, v) in self.__dict__.items() + if k not in ["_original_opts", "_validated_opts", "_main_class_name"] + ) + ) def _wrap(self, v): if isinstance(v, (tuple, list, set, frozenset)): @@ -517,22 +521,20 @@ def __init__(self, main_class_name, options): class ORTTrainerOptionsValidator(cerberus.Validator): - _LR_SCHEDULER = cerberus.TypeDefinition( - 'lr_scheduler', (lr_scheduler._LRScheduler,), ()) - _LOSS_SCALER = cerberus.TypeDefinition( - 'loss_scaler', (loss_scaler.LossScaler,), ()) + _LR_SCHEDULER = cerberus.TypeDefinition("lr_scheduler", (lr_scheduler._LRScheduler,), ()) + _LOSS_SCALER = cerberus.TypeDefinition("loss_scaler", (loss_scaler.LossScaler,), ()) - _SESSION_OPTIONS = cerberus.TypeDefinition( - 'session_options', (ort.SessionOptions,),()) + _SESSION_OPTIONS = cerberus.TypeDefinition("session_options", (ort.SessionOptions,), ()) _PROPAGATE_CAST_OPS_STRATEGY = cerberus.TypeDefinition( - "propagate_cast_ops_strategy", (PropagateCastOpsStrategy,),()) + "propagate_cast_ops_strategy", (PropagateCastOpsStrategy,), () + ) types_mapping = cerberus.Validator.types_mapping.copy() - types_mapping['lr_scheduler'] = _LR_SCHEDULER - types_mapping['loss_scaler'] = _LOSS_SCALER - types_mapping['session_options'] = _SESSION_OPTIONS - types_mapping['propagate_cast_ops_strategy'] = _PROPAGATE_CAST_OPS_STRATEGY + types_mapping["lr_scheduler"] = _LR_SCHEDULER + types_mapping["loss_scaler"] = _LOSS_SCALER + types_mapping["session_options"] = _SESSION_OPTIONS + types_mapping["propagate_cast_ops_strategy"] = _PROPAGATE_CAST_OPS_STRATEGY def _check_is_callable(field, value, error): @@ -542,303 +544,157 @@ def _check_is_callable(field, value, error): result = value is None or callable(value) except: # Python 3 but < 3.2 - if hasattr(value, '__call__'): + if hasattr(value, "__call__"): result = True if not result: error(field, "Must be callable or None") _ORTTRAINER_OPTIONS_SCHEMA = { - 'batch': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'gradient_accumulation_steps': { - 'type': 'integer', - 'min': 1, - 'default': 1 - } - }, + "batch": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": {"gradient_accumulation_steps": {"type": "integer", "min": 1, "default": 1}}, }, - 'device': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'id': { - 'type': 'string', - 'default': 'cuda' - }, - 'mem_limit': { - 'type': 'integer', - 'min': 0, - 'default': 0 - } - } + "device": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "id": {"type": "string", "default": "cuda"}, + "mem_limit": {"type": "integer", "min": 0, "default": 0}, + }, }, - 'distributed': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'world_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'world_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'local_rank': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'data_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'horizontal_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_parallel' : { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'pipeline_parallel_size': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'num_pipeline_micro_batches': { - 'type': 'integer', - 'min': 1, - 'default': 1 - }, - 'pipeline_cut_info_string': { - 'type': 'string', - 'default': '' + "distributed": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "world_rank": {"type": "integer", "min": 0, "default": 0}, + "world_size": {"type": "integer", "min": 1, "default": 1}, + "local_rank": {"type": "integer", "min": 0, "default": 0}, + "data_parallel_size": {"type": "integer", "min": 1, "default": 1}, + "horizontal_parallel_size": {"type": "integer", "min": 1, "default": 1}, + "pipeline_parallel": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "pipeline_parallel_size": {"type": "integer", "min": 1, "default": 1}, + "num_pipeline_micro_batches": {"type": "integer", "min": 1, "default": 1}, + "pipeline_cut_info_string": {"type": "string", "default": ""}, + "sliced_schema": { + "type": "dict", + "default_setter": lambda _: {}, + "keysrules": {"type": "string"}, + "valuesrules": {"type": "list", "schema": {"type": "integer"}}, }, - 'sliced_schema': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': { - 'type': 'list', - 'schema': {'type': 'integer'} - } - }, - 'sliced_axes': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'keysrules': {'type': 'string'}, - 'valuesrules': {'type': 'integer'} + "sliced_axes": { + "type": "dict", + "default_setter": lambda _: {}, + "keysrules": {"type": "string"}, + "valuesrules": {"type": "integer"}, }, - 'sliced_tensor_names': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - }, - 'allreduce_post_accumulation': { - 'type': 'boolean', - 'default': False + "sliced_tensor_names": {"type": "list", "schema": {"type": "string"}, "default": []}, + }, }, - 'deepspeed_zero_optimization': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'stage': { - 'type': 'integer', - 'min': 0, - 'max': 1, - 'default': 0 - }, - } + "allreduce_post_accumulation": {"type": "boolean", "default": False}, + "deepspeed_zero_optimization": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "stage": {"type": "integer", "min": 0, "max": 1, "default": 0}, + }, }, - 'enable_adasum': { - 'type': 'boolean', - 'default': False - } - } - }, - 'lr_scheduler': { - 'type': 'lr_scheduler', - 'nullable': True, - 'default': None + "enable_adasum": {"type": "boolean", "default": False}, + }, }, - 'mixed_precision': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'enabled': { - 'type': 'boolean', - 'default': False - }, - 'loss_scaler': { - 'type': 'loss_scaler', - 'nullable': True, - 'default': None - } - } + "lr_scheduler": {"type": "lr_scheduler", "nullable": True, "default": None}, + "mixed_precision": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "enabled": {"type": "boolean", "default": False}, + "loss_scaler": {"type": "loss_scaler", "nullable": True, "default": None}, + }, }, - 'graph_transformer': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'attn_dropout_recompute': { - 'type': 'boolean', - 'default': False - }, - 'gelu_recompute': { - 'type': 'boolean', - 'default': False - }, - 'transformer_layer_recompute': { - 'type': 'boolean', - 'default': False - }, - 'number_recompute_layers': { - 'type': 'integer', - 'min': 0, - 'default': 0 - }, - 'allow_layer_norm_mod_precision': { - 'type': 'boolean', - 'default': False - }, - 'propagate_cast_ops_config': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'strategy': { - 'type': 'propagate_cast_ops_strategy', - 'nullable': True, - 'default': PropagateCastOpsStrategy.FLOOD_FILL - }, - 'level': { - 'type': 'integer', - 'min': -1, - 'default': 1 + "graph_transformer": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "attn_dropout_recompute": {"type": "boolean", "default": False}, + "gelu_recompute": {"type": "boolean", "default": False}, + "transformer_layer_recompute": {"type": "boolean", "default": False}, + "number_recompute_layers": {"type": "integer", "min": 0, "default": 0}, + "allow_layer_norm_mod_precision": {"type": "boolean", "default": False}, + "propagate_cast_ops_config": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "strategy": { + "type": "propagate_cast_ops_strategy", + "nullable": True, + "default": PropagateCastOpsStrategy.FLOOD_FILL, }, - 'allow': { - 'type': 'list', - 'schema': {'type': 'string'}, - 'default': [] - } - } - } - } - }, - 'utils': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'frozen_weights': { - 'type': 'list', - 'default': [] - }, - 'grad_norm_clip': { - 'type': 'boolean', - 'default': True - }, - 'memory_efficient_gradient': { - 'type': 'boolean', - 'default': False + "level": {"type": "integer", "min": -1, "default": 1}, + "allow": {"type": "list", "schema": {"type": "string"}, "default": []}, + }, }, - 'run_symbolic_shape_infer': { - 'type': 'boolean', - 'default': False - } - } + }, }, - 'debug': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'deterministic_compute': { - 'type': 'boolean', - 'default': False - }, - 'check_model_export': { - 'type': 'boolean', - 'default': False - }, - 'graph_save_paths' : { - 'type' : 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'model_after_graph_transforms_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_gradient_graph_path':{ - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_path': { - 'type': 'string', - 'default': '' - }, - 'model_with_training_graph_after_optimization_path': { - 'type': 'string', - 'default': '' - }, - } - }, - } + "utils": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "frozen_weights": {"type": "list", "default": []}, + "grad_norm_clip": {"type": "boolean", "default": True}, + "memory_efficient_gradient": {"type": "boolean", "default": False}, + "run_symbolic_shape_infer": {"type": "boolean", "default": False}, + }, }, - '_internal_use': { - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'schema': { - 'enable_internal_postprocess': { - 'type': 'boolean', - 'default': True - }, - 'extra_postprocess': { - 'check_with': _check_is_callable, - 'nullable': True, - 'default': None - }, - 'onnx_opset_version': { - 'type': 'integer', - 'min': 12, - 'max': 14, - 'default': 14 + "debug": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "deterministic_compute": {"type": "boolean", "default": False}, + "check_model_export": {"type": "boolean", "default": False}, + "graph_save_paths": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "model_after_graph_transforms_path": {"type": "string", "default": ""}, + "model_with_gradient_graph_path": {"type": "string", "default": ""}, + "model_with_training_graph_path": {"type": "string", "default": ""}, + "model_with_training_graph_after_optimization_path": {"type": "string", "default": ""}, + }, }, - 'enable_onnx_contrib_ops': { - 'type': 'boolean', - 'default': True - } - } + }, }, - 'provider_options':{ - 'type': 'dict', - 'default_setter': lambda _: {}, - 'required': False, - 'allow_unknown': True, - 'schema': {} + "_internal_use": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "schema": { + "enable_internal_postprocess": {"type": "boolean", "default": True}, + "extra_postprocess": {"check_with": _check_is_callable, "nullable": True, "default": None}, + "onnx_opset_version": {"type": "integer", "min": 12, "max": 14, "default": 14}, + "enable_onnx_contrib_ops": {"type": "boolean", "default": True}, + }, }, - 'session_options': { - 'type': 'session_options', - 'nullable': True, - 'default': None + "provider_options": { + "type": "dict", + "default_setter": lambda _: {}, + "required": False, + "allow_unknown": True, + "schema": {}, }, + "session_options": {"type": "session_options", "nullable": True, "default": None}, } diff --git a/orttraining/orttraining/python/training/postprocess.py b/orttraining/orttraining/python/training/postprocess.py index 1f406f471c0bd..ff77a05e41e31 100644 --- a/orttraining/orttraining/python/training/postprocess.py +++ b/orttraining/orttraining/python/training/postprocess.py @@ -21,13 +21,15 @@ def run_postprocess(model): model = fix_expand_shape_pt_1_5(model) return model + def find_input_node(model, arg): result = [] for node in model.graph.node: for output in node.output: if output == arg: result.append(node) - return result[0] if len(result)== 1 else None + return result[0] if len(result) == 1 else None + def find_output_node(model, arg): result = [] @@ -37,22 +39,25 @@ def find_output_node(model, arg): result.append(node) return result[0] if len(result) == 1 else result + def add_name(model): i = 0 for node in model.graph.node: - node.name = '%s_%d' %(node.op_type, i) - i += 1 + node.name = "%s_%d" % (node.op_type, i) + i += 1 return model + # Expand Shape PostProcess + def fix_expand_shape(model): - expand_nodes = [n for n in model.graph.node if n.op_type == 'Expand'] + expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] model_inputs_names = [i.name for i in model.graph.input] for expand_node in expand_nodes: shape = find_input_node(model, expand_node.input[1]) - if shape.op_type == 'Shape': + if shape.op_type == "Shape": # an expand subgraph # Input Input2 # | | @@ -73,6 +78,7 @@ def fix_expand_shape(model): expand_out.type.CopyFrom(model.graph.input[index].type) return model + def fix_expand_shape_pt_1_5(model): # expand subgraph # Constant @@ -120,47 +126,47 @@ def fix_expand_shape_pt_1_5(model): # output # # This pass will copy Input's shape to the output of Expand. - expand_nodes = [n for n in model.graph.node if n.op_type == 'Expand'] + expand_nodes = [n for n in model.graph.node if n.op_type == "Expand"] model_inputs_names = [i.name for i in model.graph.input] for expand_node in expand_nodes: n_where = find_input_node(model, expand_node.input[1]) - if n_where.op_type != 'Where': + if n_where.op_type != "Where": continue n_equal = find_input_node(model, n_where.input[0]) n_cos = find_input_node(model, n_where.input[1]) n_reshape = find_input_node(model, n_where.input[2]) - if n_equal.op_type != 'Equal' or n_cos.op_type != 'ConstantOfShape' or n_reshape.op_type != 'Reshape': + if n_equal.op_type != "Equal" or n_cos.op_type != "ConstantOfShape" or n_reshape.op_type != "Reshape": continue n_reshape_e = find_input_node(model, n_equal.input[0]) n_mul = find_input_node(model, n_equal.input[1]) - if n_reshape_e != n_reshape or n_mul.op_type != 'Mul': + if n_reshape_e != n_reshape or n_mul.op_type != "Mul": continue n_cos_m = find_input_node(model, n_mul.input[0]) n_constant = find_input_node(model, n_mul.input[1]) - if n_cos_m != n_cos or n_constant.op_type != 'Constant': + if n_cos_m != n_cos or n_constant.op_type != "Constant": continue n_concat = find_input_node(model, n_reshape.input[0]) n_constant_r = find_input_node(model, n_reshape.input[1]) - if n_concat.op_type != 'Concat' or n_constant_r.op_type != 'Constant': + if n_concat.op_type != "Concat" or n_constant_r.op_type != "Constant": continue n_input_candidates = [] for concat_in in n_concat.input: n_unsqueeze = find_input_node(model, concat_in) - if n_unsqueeze.op_type != 'Unsqueeze': + if n_unsqueeze.op_type != "Unsqueeze": break n_gather = find_input_node(model, n_unsqueeze.input[0]) - if n_gather.op_type != 'Gather': + if n_gather.op_type != "Gather": break n_shape = find_input_node(model, n_gather.input[0]) n_constant_g = find_input_node(model, n_gather.input[1]) - if n_shape.op_type != 'Shape' or n_constant_g.op_type != 'Constant': + if n_shape.op_type != "Shape" or n_constant_g.op_type != "Constant": break n_input = n_shape.input[0] if not n_input in model_inputs_names: @@ -176,8 +182,10 @@ def fix_expand_shape_pt_1_5(model): expand_out.type.CopyFrom(model.graph.input[index].type) return model + # LayerNorm PostProcess + def find_nodes(graph, op_type): nodes = [] for node in graph.node: @@ -185,18 +193,20 @@ def find_nodes(graph, op_type): nodes.append(node) return nodes + def is_type(node, op_type): if node is None or isinstance(node, list): return False return node.op_type == op_type -def add_const(model, name, output, t_value = None, f_value = None): + +def add_const(model, name, output, t_value=None, f_value=None): const_node = model.graph.node.add() - const_node.op_type = 'Constant' + const_node.op_type = "Constant" const_node.name = name const_node.output.extend([output]) attr = const_node.attribute.add() - attr.name = 'value' + attr.name = "value" if t_value is not None: attr.type = 4 attr.t.CopyFrom(t_value) @@ -205,6 +215,7 @@ def add_const(model, name, output, t_value = None, f_value = None): attr.f = f_value return const_node + def layer_norm_transform(model): # DEPRECATED: This pass is no longer needed as the transform is handled at the backend. # Converting below subgraph @@ -302,7 +313,7 @@ def layer_norm_transform(model): optional_mul = find_output_node(model, div.output[0]) if not is_type(optional_mul, "Mul"): optional_mul = None - continue # default bias and weight not supported + continue # default bias and weight not supported # check if mul output is Add if optional_mul is not None: @@ -311,8 +322,7 @@ def layer_norm_transform(model): optional_add = find_output_node(model, div.output[0]) if not is_type(optional_add, "Add"): optional_add = None - continue # default bias and weight not supported - + continue # default bias and weight not supported # add nodes to remove_nodes remove_nodes.extend([reduce_mean, sub, div, pow, reduce_mean2, add, sqrt]) @@ -340,20 +350,22 @@ def layer_norm_transform(model): else: layer_norm_output.append(div.output[0]) - layer_norm_output.append('saved_mean_' + str(id)) - layer_norm_output.append('saved_inv_std_var_' + str(id)) + layer_norm_output.append("saved_mean_" + str(id)) + layer_norm_output.append("saved_inv_std_var_" + str(id)) epsilon_node = find_input_node(model, add.input[1]) epsilon = epsilon_node.attribute[0].t.raw_data - epsilon = struct.unpack('f', epsilon)[0] - - layer_norm = helper.make_node("LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + str(id), - None, - axis = reduce_mean.attribute[0].ints[0], - epsilon = epsilon) + epsilon = struct.unpack("f", epsilon)[0] + + layer_norm = helper.make_node( + "LayerNormalization", + layer_norm_input, + layer_norm_output, + "LayerNormalization_" + str(id), + None, + axis=reduce_mean.attribute[0].ints[0], + epsilon=epsilon, + ) layer_norm_nodes.append(layer_norm) id += 1 @@ -380,8 +392,10 @@ def layer_norm_transform(model): graph.node.extend(all_nodes) return model + # Fuse SoftmaxCrossEntropy + def fuse_softmaxNLL_to_softmaxCE(onnx_model): # Converting below subgraph # @@ -448,9 +462,18 @@ def fuse_softmaxNLL_to_softmaxCE(onnx_model): probability_output_name = softmax_node.output[0] node = onnx_model.graph.node.add() - inputs = [softmax_node.input[0], label_input_name, weight_input_name] if weight_input_name else [softmax_node.input[0], label_input_name] - node.CopyFrom(onnx.helper.make_node("SparseSoftmaxCrossEntropy", inputs, - [nll_loss_node.output[0], probability_output_name], - "nll_loss_node_" + str(nll_count))) + inputs = ( + [softmax_node.input[0], label_input_name, weight_input_name] + if weight_input_name + else [softmax_node.input[0], label_input_name] + ) + node.CopyFrom( + onnx.helper.make_node( + "SparseSoftmaxCrossEntropy", + inputs, + [nll_loss_node.output[0], probability_output_name], + "nll_loss_node_" + str(nll_count), + ) + ) return onnx_model diff --git a/orttraining/orttraining/python/training/utils/data/sampler.py b/orttraining/orttraining/python/training/utils/data/sampler.py index e0e42b2b3caab..932f9e76dc13c 100644 --- a/orttraining/orttraining/python/training/utils/data/sampler.py +++ b/orttraining/orttraining/python/training/utils/data/sampler.py @@ -26,9 +26,9 @@ def _shard_wrapped_indices_across_workers(dataset_index_list, num_shards, num_sa def shard_wrapped_indices_for_worker(dataset_index_list, shard_id, num_shards): """Shard wrapped around dataset_index_list across num_shards and return the indices for this shard_id""" num_samples_per_worker = (len(dataset_index_list) + num_shards - 1) // num_shards - sharded_indices = list(_shard_wrapped_indices_across_workers(dataset_index_list, - num_shards, - num_samples_per_worker)) + sharded_indices = list( + _shard_wrapped_indices_across_workers(dataset_index_list, num_shards, num_samples_per_worker) + ) return [sharded_indices[i][shard_id] for i in range(len(sharded_indices))] @@ -122,10 +122,7 @@ def __init__( raise RuntimeError("Requires distributed package to be available") rank = dist.get_rank() if rank >= world_size or rank < 0: - raise ValueError( - "Invalid rank {}, rank should be in the interval" - " [0, {}]".format(rank, world_size - 1) - ) + raise ValueError("Invalid rank {}, rank should be in the interval" " [0, {}]".format(rank, world_size - 1)) self.dataset = dataset self.world_size = world_size self.rank = rank @@ -152,11 +149,7 @@ def __init__( self.ordered_sample_complexities = None if random_level < 0.0 or random_level > 1.0: - raise ValueError( - "Invalid random level {}, shoule be in the range [0.0, 1.0]".format( - random_level - ) - ) + raise ValueError("Invalid random level {}, shoule be in the range [0.0, 1.0]".format(random_level)) self.random_level = random_level self.random_number = None @@ -178,8 +171,9 @@ def sort_in_groups(sample_complexities, group_size): # Sort the dataset samples inside each group of the dataset based on sample complexity. for group_begin_index in range(0, len(sample_complexities), group_size): group_end_index = min(group_begin_index + group_size, len(sample_complexities)) - sorted_indices = \ - group_begin_index + np.argsort(sample_complexities[group_begin_index:group_end_index, 1]) + sorted_indices = group_begin_index + np.argsort( + sample_complexities[group_begin_index:group_end_index, 1] + ) sample_complexities[group_begin_index:group_end_index, :] = sample_complexities[sorted_indices] return sample_complexities @@ -224,12 +218,13 @@ def sort_in_groups(sample_complexities, group_size): end = 0 sample_complexities_copy = ordered_sample_complexities.copy() for group_index in group_order: - original_list_begin_index = self.group_size*group_index - original_list_end_index = min(original_list_begin_index+self.group_size, len(sample_complexities)) + original_list_begin_index = self.group_size * group_index + original_list_end_index = min(original_list_begin_index + self.group_size, len(sample_complexities)) begin = end end = begin + (original_list_end_index - original_list_begin_index) - sample_complexities_copy[begin:end, :] = \ - sample_complexities[original_list_begin_index:original_list_end_index, :] + sample_complexities_copy[begin:end, :] = sample_complexities[ + original_list_begin_index:original_list_end_index, : + ] ordered_sample_complexities = sample_complexities_copy # Shard the data across the different workers. @@ -237,7 +232,7 @@ def sort_in_groups(sample_complexities, group_size): _shard_wrapped_indices_across_workers( [index_complexity_tuple[0] for index_complexity_tuple in ordered_sample_complexities], self.world_size, - self.num_samples + self.num_samples, ) ) @@ -252,9 +247,7 @@ def sort_in_groups(sample_complexities, group_size): if padding_size <= len(chunk_indices): chunk_indices += chunk_indices[:padding_size] else: - chunk_indices += ( - chunk_indices * math.ceil(padding_size / len(chunk_indices)) - )[:padding_size] + chunk_indices += (chunk_indices * math.ceil(padding_size / len(chunk_indices)))[:padding_size] else: # Remove tail of data to make it evenly divisible. chunk_indices = chunk_indices[: self.num_samples] @@ -314,9 +307,7 @@ def __init__( drop_last: bool = False, ) -> None: if not isinstance(sampler, LoadBalancingDistributedSampler): - raise ValueError( - "sampler should be of LoadBalancingDistributedSampler type." - ) + raise ValueError("sampler should be of LoadBalancingDistributedSampler type.") if sampler.drop_last: raise ValueError("drop_last of sampler should be False") @@ -338,20 +329,14 @@ def generate_batches(self): sub_indices = [index_chunks[i][rank] for i in chunk_indices] batches.append(self.batch_fn(sub_indices)) - self.total_batch = ( - max([len(b) for b in batches]) - if not self.drop_last - else min([len(b) for b in batches]) - ) + self.total_batch = max([len(b) for b in batches]) if not self.drop_last else min([len(b) for b in batches]) # here {len(batches[self.rank]) - self.total_batch} batches dropped for # rank {self.rank} if self.total_batch < len(batches[self.rank]): pass - self.padded_batches = [ - batch + batch[: self.total_batch - len(batch)] for batch in batches - ] + self.padded_batches = [batch + batch[: self.total_batch - len(batch)] for batch in batches] def __iter__(self): return iter(self.padded_batches[self.rank]) diff --git a/orttraining/orttraining/test/external_custom_ops/setup.py b/orttraining/orttraining/test/external_custom_ops/setup.py index 65d5ecc29e2ad..57ba10b91ad2d 100644 --- a/orttraining/orttraining/test/external_custom_ops/setup.py +++ b/orttraining/orttraining/test/external_custom_ops/setup.py @@ -24,28 +24,32 @@ def build_extension(self, ext): if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) subprocess.check_call( - ["cmake", - "-DPYBIND11_PYTHON_VERSION={}.{}.{}".format(sys.version_info.major, sys.version_info.minor, - sys.version_info.micro), - "-Dpybind11_DIR={}".format(pybind11.get_cmake_dir()), - "-DONNX_INCLUDE={}".format( - os.path.dirname(os.path.dirname(onnx.__file__))), - "-DONNXRUNTIME_EXTERNAL_INCLUDE={}".format(os.path.join(os.path.join(os.path.dirname(onnxruntime.__file__), - "external"), "include")), - "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}".format(extdir), - ext.sourcedir], cwd=self.build_temp) - subprocess.check_call( - ["cmake", "--build", "."], cwd=self.build_temp + [ + "cmake", + "-DPYBIND11_PYTHON_VERSION={}.{}.{}".format( + sys.version_info.major, sys.version_info.minor, sys.version_info.micro + ), + "-Dpybind11_DIR={}".format(pybind11.get_cmake_dir()), + "-DONNX_INCLUDE={}".format(os.path.dirname(os.path.dirname(onnx.__file__))), + "-DONNXRUNTIME_EXTERNAL_INCLUDE={}".format( + os.path.join(os.path.join(os.path.dirname(onnxruntime.__file__), "external"), "include") + ), + "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY={}".format(extdir), + ext.sourcedir, + ], + cwd=self.build_temp, ) + subprocess.check_call(["cmake", "--build", "."], cwd=self.build_temp) -setup(name='orttraining_external_custom_ops', - version='0.1', - author='', - author_email='', - description='External custom ops example', - long_description='', - ext_modules=[CMakeExtension('orttrainng_external_custom_ops')], - cmdclass=dict(build_ext=CMakeBuild), - zip_safe=False - ) +setup( + name="orttraining_external_custom_ops", + version="0.1", + author="", + author_email="", + description="External custom ops example", + long_description="", + ext_modules=[CMakeExtension("orttrainng_external_custom_ops")], + cmdclass=dict(build_ext=CMakeBuild), + zip_safe=False, +) diff --git a/orttraining/orttraining/test/external_custom_ops/test.py b/orttraining/orttraining/test/external_custom_ops/test.py index e1474425218b7..7d3e4edf48bd8 100644 --- a/orttraining/orttraining/test/external_custom_ops/test.py +++ b/orttraining/orttraining/test/external_custom_ops/test.py @@ -5,14 +5,16 @@ import sys import numpy as np -# Expose available (onnx::* and protobuf::*) symbols from onnxruntime to resolve references in +# Expose available (onnx::* and protobuf::*) symbols from onnxruntime to resolve references in # the custom ops shared library. Deepbind flag is required to avoid conflicts with other # instances of onnx/protobuf libraries. import onnxruntime + # Restore dlopen flags. import orttraining_external_custom_ops + so = onnxruntime.SessionOptions() sess = onnxruntime.InferenceSession("testdata/model.onnx", so) input = np.random.rand(2, 2).astype(np.float32) -output = sess.run(None, {"input1" : input})[0] +output = sess.run(None, {"input1": input})[0] np.testing.assert_equal(input, output) diff --git a/orttraining/orttraining/test/external_custom_ops/testdata/gen_model.py b/orttraining/orttraining/test/external_custom_ops/testdata/gen_model.py index 84de72bd7afc2..ce1419e0cd42a 100644 --- a/orttraining/orttraining/test/external_custom_ops/testdata/gen_model.py +++ b/orttraining/orttraining/test/external_custom_ops/testdata/gen_model.py @@ -3,12 +3,13 @@ import onnx from onnx import helper + graph = helper.make_graph( [helper.make_node("Foo", ["input1"], ["output1"], "", "", "com.examples")], "external_custom_op_example_model", - [helper.make_tensor_value_info("input1", helper.TensorProto.FLOAT, [2,2])], - [helper.make_tensor_value_info("output1", helper.TensorProto.FLOAT, [2,2])], - [] + [helper.make_tensor_value_info("input1", helper.TensorProto.FLOAT, [2, 2])], + [helper.make_tensor_value_info("output1", helper.TensorProto.FLOAT, [2, 2])], + [], ) model = helper.make_model(graph) opset = model.opset_import.add() diff --git a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py index 2d87a9e82432e..a1377d2448bfd 100644 --- a/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py +++ b/orttraining/orttraining/test/external_transformer/test/external_transformers_test.py @@ -2,10 +2,12 @@ import threading import time + class OutputGrabber(object): """ Class used to grab standard output or another stream. """ + escape_char = "\b" def __init__(self, stream=None, threaded=False): @@ -40,7 +42,7 @@ def start(self): self.workerThread.start() # Make sure that the thread is running and os.read() has executed: time.sleep(0.01) - + def stop(self): """ Stop capturing the stream data and save the text in `capturedtext`. @@ -70,11 +72,12 @@ def readOutput(self): and save the text in `capturedtext`. """ while True: - char = os.read(self.pipe_out,1).decode(self.origstream.encoding) + char = os.read(self.pipe_out, 1).decode(self.origstream.encoding) if not char or self.escape_char in char: break self.capturedtext += char + import torch from onnxruntime.capi import _pybind_state as torch_ort_eager import torch.nn as nn @@ -84,9 +87,11 @@ def readOutput(self): from onnxruntime.training import optim, orttrainer, orttrainer_options import unittest + def my_loss(x, target): return F.nll_loss(F.log_softmax(x, dim=1), target) + class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() @@ -100,6 +105,7 @@ def forward(self, x, target): out = self.fc2(out) return my_loss(out, target) + class OrtEPTests(unittest.TestCase): def test_external_graph_transformer_triggering(self): input_size = 784 @@ -107,20 +113,30 @@ def test_external_graph_transformer_triggering(self): num_classes = 10 batch_size = 128 model = NeuralNet(input_size, hidden_size, num_classes) - - model_desc = {'inputs': [('x', [batch_size, input_size]), - ('target', [batch_size,])], - 'outputs': [('loss', [], True)]} + + model_desc = { + "inputs": [ + ("x", [batch_size, input_size]), + ( + "target", + [ + batch_size, + ], + ), + ], + "outputs": [("loss", [], True)], + } optim_config = optim.SGDConfig() - opts = orttrainer.ORTTrainerOptions({'device':{'id':'cpu'}}) + opts = orttrainer.ORTTrainerOptions({"device": {"id": "cpu"}}) model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) # because orttrainer is lazy initialized, feed in a random data to trigger the graph transformer data = torch.rand(batch_size, input_size) target = torch.randint(0, 10, (batch_size,)) - + with OutputGrabber() as out: loss = model.train_step(data, target) - assert '******************Trigger Customized Graph Transformer: MyGraphTransformer!' in out.capturedtext + assert "******************Trigger Customized Graph Transformer: MyGraphTransformer!" in out.capturedtext + -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py b/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py index 0e3607ab0e1a8..f1e5c22bfa486 100644 --- a/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py +++ b/orttraining/orttraining/test/python/_orttraining_ortmodule_models.py @@ -8,7 +8,7 @@ class MyCustomClassInputNet(torch.nn.Module): def forward(self, x, custom_class_obj): if custom_class_obj.x == 1: - return x+1 + return x + 1 return x @@ -68,10 +68,11 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors grad_input = grad_output.clone() grad_input[input < 0] = 0 return grad_input + self.relu = MyReLU.apply def forward(self, input): diff --git a/orttraining/orttraining/test/python/_test_commons.py b/orttraining/orttraining/test/python/_test_commons.py index d9dd7d34af589..1b418df81de63 100644 --- a/orttraining/orttraining/test/python/_test_commons.py +++ b/orttraining/orttraining/test/python/_test_commons.py @@ -11,31 +11,33 @@ import onnxruntime from onnxruntime.training import optim, _utils -def _single_run(execution_file, scenario, checkopint_dir = None): + +def _single_run(execution_file, scenario, checkopint_dir=None): cmd = [sys.executable, execution_file] if scenario: - cmd += ['--scenario', scenario] + cmd += ["--scenario", scenario] if checkopint_dir: - cmd += ['--checkpoint_dir', checkopint_dir] + cmd += ["--checkpoint_dir", checkopint_dir] assert subprocess.call(cmd) == 0 -def _distributed_run(execution_file, scenario, checkopint_dir = None): + +def _distributed_run(execution_file, scenario, checkopint_dir=None): ngpus = torch.cuda.device_count() - cmd = ['mpirun', '-n', str(ngpus), '--tag-output', sys.executable, execution_file] + cmd = ["mpirun", "-n", str(ngpus), "--tag-output", sys.executable, execution_file] if scenario: - cmd += ['--scenario', scenario] + cmd += ["--scenario", scenario] if checkopint_dir: - cmd += ['--checkpoint_dir', checkopint_dir] + cmd += ["--checkpoint_dir", checkopint_dir] assert subprocess.call(cmd) == 0 + def is_windows(): return sys.platform.startswith("win") -def run_subprocess(args, cwd=None, capture=False, dll_path=None, - shell=False, env={}, log=None): + +def run_subprocess(args, cwd=None, capture=False, dll_path=None, shell=False, env={}, log=None): if log: - log.info("Running subprocess in '{0}'\n{1}".format( - cwd or os.getcwd(), args)) + log.info("Running subprocess in '{0}'\n{1}".format(cwd or os.getcwd(), args)) my_env = os.environ.copy() if dll_path: if is_windows(): @@ -46,16 +48,12 @@ def run_subprocess(args, cwd=None, capture=False, dll_path=None, else: my_env["LD_LIBRARY_PATH"] = dll_path - stdout, stderr = (subprocess.PIPE, subprocess.STDOUT) if capture else ( - None, None) + stdout, stderr = (subprocess.PIPE, subprocess.STDOUT) if capture else (None, None) my_env.update(env) - completed_process = subprocess.run( - args, cwd=cwd, check=True, stdout=stdout, stderr=stderr, - env=my_env, shell=shell) - + completed_process = subprocess.run(args, cwd=cwd, check=True, stdout=stdout, stderr=stderr, env=my_env, shell=shell) + if log: - log.debug("Subprocess completed. Return code=" + - str(completed_process.returncode)) + log.debug("Subprocess completed. Return code=" + str(completed_process.returncode)) return completed_process @@ -78,7 +76,6 @@ def legacy_cosine_lr_scheduler(global_step, initial_lr, total_steps, warmup, cyc return new_lr - def legacy_linear_lr_scheduler(global_step, initial_lr, total_steps, warmup): num_warmup_steps = warmup * total_steps if global_step < num_warmup_steps: @@ -98,7 +95,7 @@ def legacy_poly_lr_scheduler(global_step, initial_lr, total_steps, warmup, power lr_range = initial_lr - lr_end decay_steps = total_steps - num_warmup_steps pct_remaining = 1 - (global_step - num_warmup_steps) / decay_steps - decay = lr_range * pct_remaining ** power + lr_end + decay = lr_range * pct_remaining**power + lr_end new_lr = decay return new_lr @@ -132,30 +129,24 @@ def generate_dummy_optim_state(model, optimizer): if isinstance(optimizer, optim.LambConfig): step_val = np.full([1], 5, dtype=np.int64) optim_state[shared_state_key] = {step_key: step_val} - return { - 'optimizer': optim_state, - 'trainer_options': { - 'optimizer_name': optimizer.name - } - } + return {"optimizer": optim_state, "trainer_options": {"optimizer_name": optimizer.name}} + def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False, data_dir=None): # Loads external Pytorch TransformerModel into utils - root = 'samples' + root = "samples" if not os.path.exists(root): root = os.path.normpath( - os.path.join( - os.path.dirname( - os.path.abspath(__file__)), "..", "..", "..", "..", "samples")) + os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "..", "..", "samples") + ) if not os.path.exists(root): - raise FileNotFoundError( - "Unable to find folder 'samples', tried %r." % root) - pytorch_transformer_path = os.path.join(root, 'python', 'training', 'orttrainer', 'pytorch_transformer') - pt_model_path = os.path.join(pytorch_transformer_path, 'pt_model.py') + raise FileNotFoundError("Unable to find folder 'samples', tried %r." % root) + pytorch_transformer_path = os.path.join(root, "python", "training", "orttrainer", "pytorch_transformer") + pt_model_path = os.path.join(pytorch_transformer_path, "pt_model.py") pt_model = _utils.import_module_from_file(pt_model_path) - ort_utils_path = os.path.join(pytorch_transformer_path, 'ort_utils.py') + ort_utils_path = os.path.join(pytorch_transformer_path, "ort_utils.py") ort_utils = _utils.import_module_from_file(ort_utils_path) - utils_path = os.path.join(pytorch_transformer_path, 'utils.py') + utils_path = os.path.join(pytorch_transformer_path, "utils.py") utils = _utils.import_module_from_file(utils_path) # Modeling @@ -172,20 +163,20 @@ def _load_pytorch_transformer_model(device, dynamic_axes=False, legacy_api=False else: model_desc = ort_utils.transformer_model_description() - # Preparing data train_data, val_data, test_data = utils.prepare_data(device, 20, 20, data_dir) return model, model_desc, my_loss, utils.get_batch, train_data, val_data, test_data -def generate_random_input_from_bart_model_desc(desc, seed=1, device = "cuda:0"): - '''Generates a sample input for the BART model using the model desc''' + +def generate_random_input_from_bart_model_desc(desc, seed=1, device="cuda:0"): + """Generates a sample input for the BART model using the model desc""" torch.manual_seed(seed) onnxruntime.set_seed(seed) dtype = torch.int64 vocab_size = 30528 sample_input = [] - for index, input in enumerate(desc['inputs']): + for index, input in enumerate(desc["inputs"]): size = [] for s in input[1]: if isinstance(s, (int)): @@ -195,47 +186,70 @@ def generate_random_input_from_bart_model_desc(desc, seed=1, device = "cuda:0"): sample_input.append(torch.randint(0, vocab_size, tuple(size), dtype=dtype).to(device)) return sample_input + def _load_bart_model(): - bart_onnx_model_path = os.path.join('testdata', "bart_tiny.onnx") + bart_onnx_model_path = os.path.join("testdata", "bart_tiny.onnx") model = onnx.load(bart_onnx_model_path) batch = 2 seq_len = 1024 model_desc = { - 'inputs': [ - ('src_tokens', [batch, seq_len],), - ('prev_output_tokens', [batch, seq_len],), - ('target', [batch*seq_len],)], - 'outputs': [ - ('loss', [], True)]} + "inputs": [ + ( + "src_tokens", + [batch, seq_len], + ), + ( + "prev_output_tokens", + [batch, seq_len], + ), + ( + "target", + [batch * seq_len], + ), + ], + "outputs": [("loss", [], True)], + } return model, model_desc + def assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint, reshape_states=False): """Assert that the two ORTTrainer (hierarchical) state dictionaries are very close for all states""" - assert ('model' in state_dict_pre_checkpoint) == ('model' in state_dict_post_checkpoint) - assert ('optimizer' in state_dict_pre_checkpoint) == ('optimizer' in state_dict_post_checkpoint) + assert ("model" in state_dict_pre_checkpoint) == ("model" in state_dict_post_checkpoint) + assert ("optimizer" in state_dict_pre_checkpoint) == ("optimizer" in state_dict_post_checkpoint) - if 'model' in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint['model']['full_precision']: + if "model" in state_dict_pre_checkpoint: + for model_state_key in state_dict_pre_checkpoint["model"]["full_precision"]: if reshape_states: - assert_allclose(state_dict_pre_checkpoint['model']['full_precision'][model_state_key], - state_dict_post_checkpoint['model']['full_precision'][model_state_key]\ - .reshape(state_dict_pre_checkpoint['model']['full_precision'][model_state_key].shape)) + assert_allclose( + state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], + state_dict_post_checkpoint["model"]["full_precision"][model_state_key].reshape( + state_dict_pre_checkpoint["model"]["full_precision"][model_state_key].shape + ), + ) else: - assert_allclose(state_dict_pre_checkpoint['model']['full_precision'][model_state_key], - state_dict_post_checkpoint['model']['full_precision'][model_state_key]) - - if 'optimizer' in state_dict_pre_checkpoint: - for model_state_key in state_dict_pre_checkpoint['optimizer']: - for optimizer_state_key in state_dict_pre_checkpoint['optimizer'][model_state_key]: + assert_allclose( + state_dict_pre_checkpoint["model"]["full_precision"][model_state_key], + state_dict_post_checkpoint["model"]["full_precision"][model_state_key], + ) + + if "optimizer" in state_dict_pre_checkpoint: + for model_state_key in state_dict_pre_checkpoint["optimizer"]: + for optimizer_state_key in state_dict_pre_checkpoint["optimizer"][model_state_key]: if reshape_states: - assert_allclose(state_dict_pre_checkpoint['optimizer'][model_state_key][optimizer_state_key], - state_dict_post_checkpoint['optimizer'][model_state_key][optimizer_state_key]\ - .reshape(state_dict_pre_checkpoint['optimizer'][model_state_key][optimizer_state_key].shape)) + assert_allclose( + state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], + state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key].reshape( + state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key].shape + ), + ) else: - assert_allclose(state_dict_pre_checkpoint['optimizer'][model_state_key][optimizer_state_key], - state_dict_post_checkpoint['optimizer'][model_state_key][optimizer_state_key]) + assert_allclose( + state_dict_pre_checkpoint["optimizer"][model_state_key][optimizer_state_key], + state_dict_post_checkpoint["optimizer"][model_state_key][optimizer_state_key], + ) + def assert_all_states_close_pytorch(state_dict_pre_checkpoint, pytorch_model): """Assert that the state_dict_pre_checkpoint state dictionary is very close to the one extracted from the pytorch model after loading""" diff --git a/orttraining/orttraining/test/python/_test_helpers.py b/orttraining/orttraining/test/python/_test_helpers.py index aea0ed2fef134..e888aade9f9ab 100644 --- a/orttraining/orttraining/test/python/_test_helpers.py +++ b/orttraining/orttraining/test/python/_test_helpers.py @@ -17,17 +17,19 @@ pass except Exception as e: from onnxruntime.training.ortmodule._fallback import ORTModuleInitException + if isinstance(e, ORTModuleInitException): # ORTModule is present but not ready to run # That is OK because this file is also used by ORTTrainer tests pass raise + def is_all_or_nothing_fallback_enabled(model, policy=None): from onnxruntime.training.ortmodule import ORTMODULE_FALLBACK_POLICY from onnxruntime.training.ortmodule._fallback import _FallbackPolicy - if os.getenv('ORTMODULE_FALLBACK_POLICY') == _FallbackPolicy.FALLBACK_DISABLE.name: + if os.getenv("ORTMODULE_FALLBACK_POLICY") == _FallbackPolicy.FALLBACK_DISABLE.name: return False if not policy: @@ -36,10 +38,13 @@ def is_all_or_nothing_fallback_enabled(model, policy=None): fallback_on_env = policy in ORTMODULE_FALLBACK_POLICY fallback_on_model = False if model: - fallback_on_model = policy in model._torch_module._execution_manager(is_training=True)._fallback_manager.policy or\ - policy in model._torch_module._execution_manager(is_training=False)._fallback_manager.policy + fallback_on_model = ( + policy in model._torch_module._execution_manager(is_training=True)._fallback_manager.policy + or policy in model._torch_module._execution_manager(is_training=False)._fallback_manager.policy + ) return fallback_on_env or fallback_on_model + def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): r"""Asserts whether output_a and output_b difference is within specified tolerance @@ -49,15 +54,16 @@ def assert_model_outputs(output_a, output_b, verbose=False, rtol=1e-7, atol=0): rtol (float, default is 1e-7): Max relative difference atol (float, default is 1e-4): Max absolute difference """ - assert isinstance(output_a, list) and isinstance(output_b, list),\ - "output_a and output_b must be list of numbers" + assert isinstance(output_a, list) and isinstance(output_b, list), "output_a and output_b must be list of numbers" if len(output_a) != len(output_b): raise AssertionError( - "output_a and output_b must have the same length (%r != %r)." % (len(output_a), len(output_b))) + "output_a and output_b must have the same length (%r != %r)." % (len(output_a), len(output_b)) + ) # for idx in range(len(output_a)): assert_allclose(output_a, output_b, rtol=rtol, atol=atol, err_msg=f"Model output value mismatch") + def assert_onnx_weights(model_a, model_b, verbose=False, rtol=1e-7, atol=0): r"""Asserts whether weight difference between models a and b differences are within specified tolerance @@ -114,9 +120,10 @@ def _assert_state_dict_weights(state_dict_a, state_dict_b, verbose, rtol, atol): np_b_vals = np.array(b_val).flatten() assert np_a_vals.shape == np_b_vals.shape if verbose: - print(f'Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}') + print(f"Weight name: {a_name}: absolute difference: {np.abs(np_a_vals-np_b_vals).max()}") assert_allclose(a_val, b_val, rtol=rtol, atol=atol, err_msg=f"Weight mismatch for {a_name}") + def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): r"""Asserts whether optimizer state differences are within specified tolerance @@ -144,9 +151,15 @@ def assert_optim_state(expected_state, actual_state, rtol=1e-7, atol=0): """ assert expected_state.keys() == actual_state.keys() for param_name, a_state in actual_state.items(): - for k,v in a_state.items(): - assert_allclose(v, expected_state[param_name][k], rtol=rtol, atol=atol, - err_msg=f"Optimizer state mismatch for param {param_name}, key {k}") + for k, v in a_state.items(): + assert_allclose( + v, + expected_state[param_name][k], + rtol=rtol, + atol=atol, + err_msg=f"Optimizer state mismatch for param {param_name}, key {k}", + ) + def is_dynamic_axes(model): # Check inputs @@ -166,6 +179,7 @@ def is_dynamic_axes(model): return False return True + # TODO: thiagofc: Checkpoint related for redesign def _get_name(name): if os.path.exists(name): @@ -180,9 +194,12 @@ def _get_name(name): return res raise FileNotFoundError("Unable to find '{0}' or '{1}' or '{2}'".format(name, rel, res)) + # Depending on calling backward() from which outputs, it's possible that grad of some weights are not calculated. # none_pt_params is to tell what these weights are, so we will not compare the tensors. -def assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-05, atol=1e-06): +def assert_gradients_match_and_reset_gradient( + ort_model, pt_model, none_pt_params=[], reset_gradient=True, rtol=1e-05, atol=1e-06 +): ort_named_params = list(ort_model.named_parameters()) pt_named_params = list(pt_model.named_parameters()) assert len(ort_named_params) == len(pt_named_params) @@ -202,6 +219,7 @@ def assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_param ort_param.grad = None pt_param.grad = None + def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): are_close = torch.allclose(input, other, rtol=rtol, atol=atol) if not are_close: @@ -212,9 +230,11 @@ def assert_values_are_close(input, other, rtol=1e-05, atol=1e-06): err_msg = "The maximum atol is {}, maximum rtol is {}".format(max_atol, max_rtol) assert False, err_msg + def enable_custom_autograd_function(module): enable_custom_autograd_support() + def _run_model_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): if is_eval_mode: model.eval() @@ -255,12 +275,14 @@ def generate_inputs(input_list_, label_input_): grad_outputs.append(param.grad) return forward_outputs, grad_outputs + def run_with_pytorch_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): with torch.no_grad(): model = copy.deepcopy(model).to(device) return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice) + def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode=False, run_forward_twice=False): with torch.no_grad(): model = copy.deepcopy(model) @@ -270,29 +292,66 @@ def run_with_ort_on_device(device, model, input_list, label_input, is_eval_mode= return _run_model_on_device(device, model, input_list, label_input, is_eval_mode, run_forward_twice) + def compare_tensor_list(val_list_a, val_list_b): for val_a, val_b in zip(val_list_a, val_list_b): - assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6) - -def run_training_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, - run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): + assert_values_are_close(val_a, val_b, atol=1e-7, rtol=1e-6) + + +def run_training_test_and_compare( + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + run_forward_twice=False, + ignore_grad_compare=False, + expected_outputs=[], + expected_grads=[], +): cpu = torch.device("cpu") def cpu_barrier_func(): pass + run_training_test_on_device_and_compare( - cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cpu_barrier_func, - run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads) + cpu, + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + cpu_barrier_func, + run_forward_twice, + ignore_grad_compare, + expected_outputs, + expected_grads, + ) def cuda_barrier_func(): torch.cuda.synchronize() - cuda = torch.device('cuda:0') - run_training_test_on_device_and_compare( - cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, cuda_barrier_func, - run_forward_twice, ignore_grad_compare, expected_outputs, expected_grads) -def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, barrier_func, - run_forward_twice=False, ignore_grad_compare=False, expected_outputs=[], expected_grads=[]): + cuda = torch.device("cuda:0") + run_training_test_on_device_and_compare( + cuda, + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + cuda_barrier_func, + run_forward_twice, + ignore_grad_compare, + expected_outputs, + expected_grads, + ) + + +def run_training_test_on_device_and_compare( + device, + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + barrier_func, + run_forward_twice=False, + ignore_grad_compare=False, + expected_outputs=[], + expected_grads=[], +): repeats = 16 for i in range(repeats): m = pt_model_builder_func() @@ -303,11 +362,13 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo x_ort = copy.deepcopy(x) outputs, grads = run_with_pytorch_on_device( - device, m, [x], pt_model_label_input, run_forward_twice=run_forward_twice) + device, m, [x], pt_model_label_input, run_forward_twice=run_forward_twice + ) barrier_func() outputs_ort, grads_ort = run_with_ort_on_device( - device, m_ort, [x_ort], pt_model_label_input, run_forward_twice=run_forward_twice) + device, m_ort, [x_ort], pt_model_label_input, run_forward_twice=run_forward_twice + ) barrier_func() val_list_a = [o.detach().cpu() for o in outputs if o is not None] @@ -326,28 +387,47 @@ def run_training_test_on_device_and_compare(device, pt_model_builder_func, pt_mo if len(expected_grads) > 0: compare_tensor_list(val_list_a, expected_grads) -def run_evaluate_test_and_compare(pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, - run_forward_twice=False): + +def run_evaluate_test_and_compare( + pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, run_forward_twice=False +): cpu = torch.device("cpu") def cpu_barrier_func(): pass run_evaluate_test_on_device_and_compare( - cpu, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, - cpu_barrier_func, run_forward_twice=run_forward_twice) + cpu, + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + cpu_barrier_func, + run_forward_twice=run_forward_twice, + ) def cuda_barrier_func(): torch.cuda.synchronize() pass - cuda = torch.device('cuda:0') + cuda = torch.device("cuda:0") run_evaluate_test_on_device_and_compare( - cuda, pt_model_builder_func, pt_model_inputs_generator, pt_model_label_input, - cuda_barrier_func, run_forward_twice=run_forward_twice) - -def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_model_inputs_generator, - pt_model_label_input, barrier_func, run_forward_twice=False): + cuda, + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + cuda_barrier_func, + run_forward_twice=run_forward_twice, + ) + + +def run_evaluate_test_on_device_and_compare( + device, + pt_model_builder_func, + pt_model_inputs_generator, + pt_model_label_input, + barrier_func, + run_forward_twice=False, +): repeats = 16 for i in range(repeats): m = pt_model_builder_func() @@ -357,11 +437,13 @@ def run_evaluate_test_on_device_and_compare(device, pt_model_builder_func, pt_mo x_ort = copy.deepcopy(x) outputs, grads = run_with_pytorch_on_device( - device, m, [x], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice) + device, m, [x], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice + ) barrier_func() outputs_ort, grads_ort = run_with_ort_on_device( - device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice) + device, m_ort, [x_ort], pt_model_label_input, is_eval_mode=True, run_forward_twice=run_forward_twice + ) barrier_func() val_list_a = [o.detach().cpu() for o in outputs if o is not None] diff --git a/orttraining/orttraining/test/python/checkpoint/_test_helpers.py b/orttraining/orttraining/test/python/checkpoint/_test_helpers.py index 734ab53e48349..6bf0792bbf9e2 100644 --- a/orttraining/orttraining/test/python/checkpoint/_test_helpers.py +++ b/orttraining/orttraining/test/python/checkpoint/_test_helpers.py @@ -12,62 +12,75 @@ from onnxruntime.training import amp, checkpoint, _checkpoint_storage, optim, orttrainer from onnxruntime.capi._pybind_state import set_cuda_device_id, get_mpi_context_world_rank, get_mpi_context_world_size -from _test_commons import generate_random_input_from_bart_model_desc, generate_dummy_optim_state, \ - _load_pytorch_transformer_model, _load_bart_model, \ - assert_all_states_close_ort, assert_all_states_close_pytorch +from _test_commons import ( + generate_random_input_from_bart_model_desc, + generate_dummy_optim_state, + _load_pytorch_transformer_model, + _load_bart_model, + assert_all_states_close_ort, + assert_all_states_close_pytorch, +) from numpy.testing import assert_allclose, assert_array_equal global_fp16_fp32_atol = 1e-3 -def _train(trainer, train_data, batcher_fn, total_batch_steps = 5, seed = 1): + +def _train(trainer, train_data, batcher_fn, total_batch_steps=5, seed=1): """Runs train_step total_batch_steps number of times on the given trainer""" for i in range(total_batch_steps): torch.manual_seed(seed) set_seed(seed) - data, targets = batcher_fn(train_data, i*35) + data, targets = batcher_fn(train_data, i * 35) trainer.train_step(data, targets) + def makedir(checkpoint_dir): """Creates a directory if checkpoint_dir does not exist""" if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok = True) + os.makedirs(checkpoint_dir, exist_ok=True) + def _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=None): """Saves the ORTTrainer checkpoint and the complete state dictionary to the given checkpoint_dir directory""" # save current model parameters as a checkpoint makedir(checkpoint_dir) - checkpoint_file_name = 'checkpoint{}.ortcp'.format('' if world_rank is None else str(world_rank)) + checkpoint_file_name = "checkpoint{}.ortcp".format("" if world_rank is None else str(world_rank)) trainer.save_checkpoint(os.path.join(checkpoint_dir, checkpoint_file_name)) state_dict = trainer.state_dict() - with open(os.path.join(checkpoint_dir, state_dict_key_name+'.pkl'), "wb") as f: - pickle.dump({state_dict_key_name : state_dict}, f) + with open(os.path.join(checkpoint_dir, state_dict_key_name + ".pkl"), "wb") as f: + pickle.dump({state_dict_key_name: state_dict}, f) + def save_ort_ckpt(state_dict, filepath): _checkpoint_storage.save(state_dict, filepath) + def _chunkify(sequence, num_chunks): """Breaks down a given sequence into num_chunks chunks""" quo, rem = divmod(len(sequence), num_chunks) - return (sequence[i * quo + min(i, rem):(i + 1) * quo + min(i + 1, rem)] for i in range(num_chunks)) + return (sequence[i * quo + min(i, rem) : (i + 1) * quo + min(i + 1, rem)] for i in range(num_chunks)) + def _setup_test_infra(world_rank, world_size): """distributed setup just for testing purposes""" - os.environ['RANK'] = str(world_rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29500' + os.environ["RANK"] = str(world_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" set_cuda_device_id(world_rank) - dist.init_process_group(backend='nccl', world_size=world_size, rank=world_rank) + dist.init_process_group(backend="nccl", world_size=world_size, rank=world_rank) + def _is_model_parallel_run(trainer_options): zero = trainer_options.distributed.deepspeed_zero_optimization.stage > 0 megatron = trainer_options.distributed.horizontal_parallel_size > 1 return zero or megatron + def distributed_setup(func): """Decorator function for distributed tests. @@ -78,10 +91,11 @@ def distributed_setup(func): Also sets up the infrastructure required for the distributed tests such as setting up the torch distributed initialization """ + def setup(checkpoint_dir): world_rank = get_mpi_context_world_rank() world_size = get_mpi_context_world_size() - device = 'cuda:' + str(world_rank) + device = "cuda:" + str(world_rank) _setup_test_infra(world_rank, world_size) @@ -89,7 +103,10 @@ def setup(checkpoint_dir): return setup -def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1): + +def create_orttrainer_and_load_checkpoint( + device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1 +): """Instantiate and load checkpoint into trainer - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple transformer model @@ -100,13 +117,15 @@ def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, torch.manual_seed(seed) set_seed(seed) - # PyTorch transformer model setup + # PyTorch transformer model setup optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts)) + trainer = orttrainer.ORTTrainer( + model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts) + ) # load checkpoint into trainer - checkpoint_file_name = 'checkpoint*.ortcp' + checkpoint_file_name = "checkpoint*.ortcp" checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name)) trainer.load_checkpoint(*checkpoint_files) @@ -118,7 +137,10 @@ def create_orttrainer_and_load_checkpoint(device, trainer_opts, checkpoint_dir, return trainer.state_dict(), model -def create_orttrainer_and_load_checkpoint_bart(device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1): + +def create_orttrainer_and_load_checkpoint_bart( + device, trainer_opts, checkpoint_dir, use_lamb=True, seed=1, learning_rate=0.1 +): """Instantiate and load checkpoint into trainer - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model @@ -135,79 +157,85 @@ def create_orttrainer_and_load_checkpoint_bart(device, trainer_opts, checkpoint_ trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=orttrainer.ORTTrainerOptions(trainer_opts)) # load checkpoint into trainer - checkpoint_file_name = 'checkpoint*.ortcp' + checkpoint_file_name = "checkpoint*.ortcp" checkpoint_files = glob.glob(os.path.join(checkpoint_dir, checkpoint_file_name)) trainer.load_checkpoint(*checkpoint_files) # run an eval step to innitialize the graph - src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed = seed) + src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed) trainer.eval_step(src_tokens, prev_output_tokens, target) expected_state_dict = None - fname = os.path.join(checkpoint_dir, 'expected_state_dict.pkl') + fname = os.path.join(checkpoint_dir, "expected_state_dict.pkl") if os.path.isfile(fname): with open(fname, "rb") as f: expected_state_dict = pickle.load(f) return trainer.state_dict(), expected_state_dict, model + def create_initialized_orttrainer(device, trainer_opts, use_lamb=True, seed=1, learning_rate=1e-10): torch.manual_seed(seed) set_seed(seed) optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts)) + trainer = orttrainer.ORTTrainer( + model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts) + ) _train(trainer, train_data, batcher_fn) return trainer + def verify_model_state(trainer, expected_state_dict, is_mixedprecision): actual_model_state = trainer._training_session.get_model_state(include_mixed_precision_weights=True) for fp_or_mp, value in actual_model_state.items(): for weight_name in value: - assert weight_name.find('_view_') == -1 - assert len(expected_state_dict['fp32_param']) == len(actual_model_state['full_precision']), \ - "expected and actual should have same number of tensors" - for weight_name, tensor in expected_state_dict['fp32_param'].items(): - if not weight_name in actual_model_state['full_precision']: - assert '_view_' in weight_name, \ - "only zero shared weight may not match name" - weight_name = weight_name.split('_view_')[0] - assert_allclose(tensor, actual_model_state['full_precision'][weight_name]) + assert weight_name.find("_view_") == -1 + assert len(expected_state_dict["fp32_param"]) == len( + actual_model_state["full_precision"] + ), "expected and actual should have same number of tensors" + for weight_name, tensor in expected_state_dict["fp32_param"].items(): + if not weight_name in actual_model_state["full_precision"]: + assert "_view_" in weight_name, "only zero shared weight may not match name" + weight_name = weight_name.split("_view_")[0] + assert_allclose(tensor, actual_model_state["full_precision"][weight_name]) if is_mixedprecision: - assert 'mixed_precision' in actual_model_state.keys(), "missing 'mixed_precision' key in mixed precision run" - assert len(expected_state_dict['fp16_param']) == len(actual_model_state['mixed_precision']), \ - "expected and actual should have same number of tensors" - for weight_name, tensor in expected_state_dict['fp16_param'].items(): - weight_name = weight_name.split('_fp16')[0] - assert_allclose(tensor, actual_model_state['mixed_precision'][weight_name]) + assert "mixed_precision" in actual_model_state.keys(), "missing 'mixed_precision' key in mixed precision run" + assert len(expected_state_dict["fp16_param"]) == len( + actual_model_state["mixed_precision"] + ), "expected and actual should have same number of tensors" + for weight_name, tensor in expected_state_dict["fp16_param"].items(): + weight_name = weight_name.split("_fp16")[0] + assert_allclose(tensor, actual_model_state["mixed_precision"][weight_name]) + def verify_opt_state(trainer, expected_state_dict): actual_opt_state = trainer._training_session.get_optimizer_state() actual_opt_count = sum(len(v) for v in actual_opt_state.values()) - assert actual_opt_count == len(expected_state_dict['optimizer']) + assert actual_opt_count == len(expected_state_dict["optimizer"]) for weight_name in actual_opt_state: - assert weight_name.find('_view_') == -1 - for opt_name, expected_tensor in expected_state_dict['optimizer'].items(): + assert weight_name.find("_view_") == -1 + for opt_name, expected_tensor in expected_state_dict["optimizer"].items(): if opt_name == "Step": - actual_tensor = actual_opt_state['shared_optimizer_state']['Step'] + actual_tensor = actual_opt_state["shared_optimizer_state"]["Step"] else: - if opt_name.startswith('Moment_'): - prefix = opt_name[:len("Moment_0")] - weight_name = opt_name[len("Moment_0_"):] + if opt_name.startswith("Moment_"): + prefix = opt_name[: len("Moment_0")] + weight_name = opt_name[len("Moment_0_") :] if not weight_name in actual_opt_state: - assert '_view_' in weight_name, \ - "only zero shared weight may not match name" - weight_name = weight_name.split('_view_')[0] - elif opt_name.startswith('Update_Count_'): + assert "_view_" in weight_name, "only zero shared weight may not match name" + weight_name = weight_name.split("_view_")[0] + elif opt_name.startswith("Update_Count_"): prefix = "Update_Count" - weight_name = opt_name[len(prefix + 1):] + weight_name = opt_name[len(prefix + 1) :] actual_tensor = actual_opt_state[weight_name][prefix] assert_allclose(actual_tensor, expected_tensor, atol=global_fp16_fp32_atol) + def verify_part_info(trainer, expected_state_dict, is_mixedprecision, is_zero_run): part_info = trainer._training_session.get_partition_info_map() for weight_name, weight_info in part_info.items(): @@ -222,69 +250,82 @@ def verify_part_info(trainer, expected_state_dict, is_mixedprecision, is_zero_ru if is_zero_run: assert len(value) > 0, "original_dim should not be empty in zero run" if is_mixedprecision: - assert_array_equal(part_info[weight_name]['original_dim'], expected_state_dict['fp16_param'][weight_name + '_fp16'].shape) + assert_array_equal( + part_info[weight_name]["original_dim"], + expected_state_dict["fp16_param"][weight_name + "_fp16"].shape, + ) else: - assert_array_equal(part_info[weight_name]['original_dim'], expected_state_dict['fp32_param'][weight_name].shape) + assert_array_equal( + part_info[weight_name]["original_dim"], expected_state_dict["fp32_param"][weight_name].shape + ) + def split_state_dict(state_dict): """Given a flat state dictionary, split it into optimizer, fp32_param, fp16_param hierarchical dictionary and return""" - optimizer_keys = ['Moment_1_', 'Moment_2_', 'Update_Count_', 'Step'] - split_sd = {'optimizer': {}, 'fp32_param': {}, 'fp16_param': {}} + optimizer_keys = ["Moment_1_", "Moment_2_", "Update_Count_", "Step"] + split_sd = {"optimizer": {}, "fp32_param": {}, "fp16_param": {}} for k, v in state_dict.items(): - mode = 'fp32_param' + mode = "fp32_param" for optim_key in optimizer_keys: if k.startswith(optim_key): - mode = 'optimizer' + mode = "optimizer" break - if k.endswith('_fp16'): - mode = 'fp16_param' + if k.endswith("_fp16"): + mode = "fp16_param" split_sd[mode][k] = v return split_sd + def _split_name(name): """Splits given state name (model or optimizer state name) into the param_name, optimizer_key, view_num and the fp16_key""" - name_split = name.split('_view_') + name_split = name.split("_view_") view_num = None - if(len(name_split) > 1): + if len(name_split) > 1: view_num = int(name_split[1]) - optimizer_key = '' - fp16_key = '' - if name_split[0].startswith('Moment_1'): - optimizer_key = 'Moment_1_' - elif name_split[0].startswith('Moment_2'): - optimizer_key = 'Moment_2_' - elif name_split[0].startswith('Update_Count'): - optimizer_key = 'Update_Count_' - elif name_split[0].endswith('_fp16'): - fp16_key = '_fp16' + optimizer_key = "" + fp16_key = "" + if name_split[0].startswith("Moment_1"): + optimizer_key = "Moment_1_" + elif name_split[0].startswith("Moment_2"): + optimizer_key = "Moment_2_" + elif name_split[0].startswith("Update_Count"): + optimizer_key = "Update_Count_" + elif name_split[0].endswith("_fp16"): + fp16_key = "_fp16" param_name = name_split[0] - if optimizer_key != '': + if optimizer_key != "": param_name = param_name.split(optimizer_key)[1] - param_name = param_name.split('_fp16')[0] + param_name = param_name.split("_fp16")[0] return param_name, optimizer_key, view_num, fp16_key -def aggregate_states(checkpoint_dir, filename_prefix='state_dict', state_dict_key_name='state_dict'): + +def aggregate_states(checkpoint_dir, filename_prefix="state_dict", state_dict_key_name="state_dict"): """Aggregate state dictionaries saved in the checkpoint_dir as pickle files with name filename_prefix_world_rank.pkl""" aggregated_states = {} - num_states = len(glob.glob1(checkpoint_dir,"{}*".format(filename_prefix))) + num_states = len(glob.glob1(checkpoint_dir, "{}*".format(filename_prefix))) for rank in range(num_states): rank_state_dict = None - with open(os.path.join(checkpoint_dir, '{}_{}.pkl'.format(filename_prefix, rank)), 'rb') as f: + with open(os.path.join(checkpoint_dir, "{}_{}.pkl".format(filename_prefix, rank)), "rb") as f: rank_state_dict = pickle.load(f) # if state_dict_key_name is None, then the rank_state_dictis the loaded object if state_dict_key_name: # if it has a name, index into the loaded object to extract the rank_state_dict - rank_state_dict = rank_state_dict['{}_{}'.format(state_dict_key_name, rank)] + rank_state_dict = rank_state_dict["{}_{}".format(state_dict_key_name, rank)] - checkpoint._aggregate_model_states(rank_state_dict, {}, aggregated_states, rank_state_dict['trainer_options']['mixed_precision']) + checkpoint._aggregate_model_states( + rank_state_dict, {}, aggregated_states, rank_state_dict["trainer_options"]["mixed_precision"] + ) checkpoint._aggregate_optimizer_states(rank_state_dict, {}, aggregated_states) - + return aggregated_states -def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, state_dict_key_name='state_dict', use_lamb=True, seed=1, learning_rate=0.1): + +def create_orttrainer_and_save_checkpoint( + device, trainer_opts, checkpoint_dir, state_dict_key_name="state_dict", use_lamb=True, seed=1, learning_rate=0.1 +): torch.manual_seed(seed) set_seed(seed) @@ -293,8 +334,14 @@ def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=ort_trainer_opts) - if 'distributed' in trainer_opts: - train_data = next(islice(_chunkify(train_data, trainer_opts['distributed']['world_size']), trainer_opts['distributed']['world_rank'], None)) + if "distributed" in trainer_opts: + train_data = next( + islice( + _chunkify(train_data, trainer_opts["distributed"]["world_size"]), + trainer_opts["distributed"]["world_rank"], + None, + ) + ) # run train steps _train(trainer, train_data, batcher_fn) @@ -306,7 +353,10 @@ def create_orttrainer_and_save_checkpoint(device, trainer_opts, checkpoint_dir, else: _save(trainer, checkpoint_dir, state_dict_key_name) -def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_dir, state_dict_key_name='state_dict', use_lamb=True, seed=1, learning_rate=0.1): + +def create_orttrainer_and_save_checkpoint_bart( + device, trainer_opts, checkpoint_dir, state_dict_key_name="state_dict", use_lamb=True, seed=1, learning_rate=0.1 +): """Instantiate trainer and save checkpoint for BART. - Instantiates the ORTTrainer with given input trainer_opts configuration for a simple BART model @@ -328,7 +378,7 @@ def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_ trainer.load_state_dict(dummy_init_state) # run an eval step to innitialize the graph - src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed = seed) + src_tokens, prev_output_tokens, target = generate_random_input_from_bart_model_desc(model_desc, seed=seed) trainer.eval_step(src_tokens, prev_output_tokens, target) # save current model parameters as a checkpoint @@ -337,21 +387,24 @@ def create_orttrainer_and_save_checkpoint_bart(device, trainer_opts, checkpoint_ _save(trainer, checkpoint_dir, state_dict_key_name, world_rank=ort_trainer_opts.distributed.world_rank) # save the initial complete model and optimizer states if ort_trainer_opts.distributed.world_rank == 0: - init_state['model'] = {'full_precision': dict()} + init_state["model"] = {"full_precision": dict()} for initializer in model.graph.initializer: - init_state['model']['full_precision'][initializer.name] = numpy_helper.to_array(initializer) - with open(os.path.join(checkpoint_dir, 'expected_state_dict.pkl'), "wb") as f: + init_state["model"]["full_precision"][initializer.name] = numpy_helper.to_array(initializer) + with open(os.path.join(checkpoint_dir, "expected_state_dict.pkl"), "wb") as f: pickle.dump(init_state, f) else: _save(trainer, checkpoint_dir, state_dict_key_name) + def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True, seed=1, learning_rate=0.1): torch.manual_seed(seed) set_seed(seed) optim_config = optim.LambConfig(lr=learning_rate) if use_lamb else optim.AdamConfig(lr=learning_rate) model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts)) + trainer = orttrainer.ORTTrainer( + model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(trainer_opts) + ) # load dummy state dummy_init_state = generate_dummy_optim_state(model, optim_config) @@ -362,15 +415,16 @@ def load_model_optim_state_and_eval(device, trainer_opts, use_lamb=True, seed=1, trainer.eval_step(data, targets) optimizer_state_dict = trainer.state_dict() - del optimizer_state_dict['model'] + del optimizer_state_dict["model"] return dummy_init_state, optimizer_state_dict + def assert_all_states_close(checkpoint_dir, state_dict_key_name, state_dict_post_checkpoint, pytorch_model=None): """Extract previously saved state dictionary from pickle file and compare against state_dict_post_checkpoint for all states""" state = None - with open(os.path.join(checkpoint_dir, '{}.pkl'.format(state_dict_key_name)), 'rb') as f: + with open(os.path.join(checkpoint_dir, "{}.pkl".format(state_dict_key_name)), "rb") as f: state = pickle.load(f) state_dict_pre_checkpoint = state[state_dict_key_name] diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_backend_api.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_backend_api.py index 255b632954685..6ca2e574fbfc4 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_backend_api.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_backend_api.py @@ -10,99 +10,97 @@ import argparse import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from _test_helpers import _train, distributed_setup, create_initialized_orttrainer, \ - split_state_dict, global_fp16_fp32_atol, verify_model_state, verify_opt_state, verify_part_info +from _test_helpers import ( + _train, + distributed_setup, + create_initialized_orttrainer, + split_state_dict, + global_fp16_fp32_atol, + verify_model_state, + verify_opt_state, + verify_part_info, +) + -def test_single_node_full_precision_lamb(device = 'cuda', checkpoint_dir=''): - opts_dict = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} +def test_single_node_full_precision_lamb(device="cuda", checkpoint_dir=""): + opts_dict = {"device": {"id": device}, "debug": {"deterministic_compute": True}} is_mixedprecision = False is_zero_run = False - + trainer = create_initialized_orttrainer(device, opts_dict, True) expected_state_dict = trainer._training_session.get_state() expected_state_dict = split_state_dict(expected_state_dict) verify_model_state(trainer, expected_state_dict, is_mixedprecision) - + verify_opt_state(trainer, expected_state_dict) verify_part_info(trainer, expected_state_dict, is_mixedprecision, is_zero_run) + @distributed_setup def test_distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): is_mixedprecision = False is_zero_run = True opts_dict = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': is_mixedprecision - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - + "device": {"id": device}, + "mixed_precision": {"enabled": is_mixedprecision}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + trainer = create_initialized_orttrainer(device, opts_dict, True) expected_state_dict = trainer._training_session.get_state() expected_state_dict = split_state_dict(expected_state_dict) verify_model_state(trainer, expected_state_dict, is_mixedprecision) - + verify_opt_state(trainer, expected_state_dict) verify_part_info(trainer, expected_state_dict, is_mixedprecision, is_zero_run) + @distributed_setup def test_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): is_mixedprecision = True is_zero_run = True opts_dict = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': is_mixedprecision - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - + "device": {"id": device}, + "mixed_precision": {"enabled": is_mixedprecision}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + trainer = create_initialized_orttrainer(device, opts_dict, True) expected_state_dict = trainer._training_session.get_state() expected_state_dict = split_state_dict(expected_state_dict) verify_model_state(trainer, expected_state_dict, is_mixedprecision) - + verify_opt_state(trainer, expected_state_dict) verify_part_info(trainer, expected_state_dict, is_mixedprecision, is_zero_run) + # To run single node test locally, from build directory # python3 checkpoint/orttraining_test_backend_api.py -# test_single_node_full_precision_lamb() +# test_single_node_full_precision_lamb() # To run distributed test locally, from build directory # mpirun -n 4 -x NCCL_DEBUG=INFO python3 checkpoint/orttraining_test_backend_api.py @@ -110,12 +108,14 @@ def test_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, c # test_distributed_zero_mixed_precision_lamb(checkpoint_dir='') function_map = { - 'test_single_node_full_precision_lamb': test_single_node_full_precision_lamb, - 'test_distributed_zero_full_precision_lamb': test_distributed_zero_full_precision_lamb, - 'test_distributed_zero_mixed_precision_lamb': test_distributed_zero_mixed_precision_lamb + "test_single_node_full_precision_lamb": test_single_node_full_precision_lamb, + "test_distributed_zero_full_precision_lamb": test_distributed_zero_full_precision_lamb, + "test_distributed_zero_mixed_precision_lamb": test_distributed_zero_mixed_precision_lamb, } -parser = argparse.ArgumentParser(description='Test saved states of trainers to loaded states') -parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test saved and loaded states', required=True) -parser.add_argument('--checkpoint_dir', help='path to the saved states directory', required=True) +parser = argparse.ArgumentParser(description="Test saved states of trainers to loaded states") +parser.add_argument( + "--scenario", choices=function_map.keys(), help="training scenario to test saved and loaded states", required=True +) +parser.add_argument("--checkpoint_dir", help="path to the saved states directory", required=True) args = parser.parse_args() -function_map[args.scenario](checkpoint_dir=args.checkpoint_dir) \ No newline at end of file +function_map[args.scenario](checkpoint_dir=args.checkpoint_dir) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py index dee7da9c86890..da0f5557b4e59 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_checkpoint_aggregation.py @@ -17,18 +17,25 @@ import torch.distributed as dist import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import onnxruntime from onnxruntime.training import checkpoint -from _test_helpers import distributed_setup, create_orttrainer_and_load_checkpoint, create_orttrainer_and_load_checkpoint_bart, aggregate_states +from _test_helpers import ( + distributed_setup, + create_orttrainer_and_load_checkpoint, + create_orttrainer_and_load_checkpoint_bart, + aggregate_states, +) from _test_commons import assert_all_states_close_ort def test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision): # get aggregated state dict independently - aggregate_state_dict_from_checkpoint = \ - checkpoint.aggregate_checkpoints(glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")), pytorch_format=False) + aggregate_state_dict_from_checkpoint = checkpoint.aggregate_checkpoints( + glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")), pytorch_format=False + ) # verify loaded state and aggregated states match: assert_all_states_close_ort(loaded_state_dict, aggregate_state_dict_from_checkpoint) @@ -37,180 +44,179 @@ def test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision): aggregate_state_dict_from_test = aggregate_states(checkpoint_dir) # compare state dictionaries between the manually aggregated state dictionary with the aggregated state dictionary from the ORTTrainer - assert_all_states_close_ort(aggregate_state_dict_from_test, aggregate_state_dict_from_checkpoint, reshape_states=True) + assert_all_states_close_ort( + aggregate_state_dict_from_test, aggregate_state_dict_from_checkpoint, reshape_states=True + ) def test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision): # get aggregated state dict independently - aggregate_state_dict_from_checkpoint = \ - checkpoint.aggregate_checkpoints(glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")), pytorch_format=False) + aggregate_state_dict_from_checkpoint = checkpoint.aggregate_checkpoints( + glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")), pytorch_format=False + ) # verify loaded state and aggregated states match: assert_all_states_close_ort(loaded_state_dict, aggregate_state_dict_from_checkpoint) - #compare with expected state dict + # compare with expected state dict assert_all_states_close_ort(expected_state_dict, loaded_state_dict) -def test_aggregation_from_distributed_zero_full_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero/full_precision/adam/'): - opts = {'device': {'id': device}, - 'debug': {'deterministic_compute': True}} + +def test_aggregation_from_distributed_zero_full_precision_adam( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero/full_precision/adam/" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare loaded_state_dict, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir, use_lamb=False) test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision=False) -def test_aggregation_from_distributed_zero_mixed_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero/mixed_precision/adam/'): - opts = { - 'device': {'id': device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug': {'deterministic_compute': True} - } +def test_aggregation_from_distributed_zero_mixed_precision_adam( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero/mixed_precision/adam/" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare loaded_state_dict, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir, use_lamb=False) test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision=True) -def test_aggregation_from_distributed_zero_full_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero/full_precision/lamb/'): - opts = {'device': {'id': device}, - 'debug': {'deterministic_compute': True}} +def test_aggregation_from_distributed_zero_full_precision_lamb( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero/full_precision/lamb/" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare loaded_state_dict, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir, use_lamb=True) test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision=False) -def test_aggregation_from_distributed_zero_mixed_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero/mixed_precision/lamb/'): - opts = { - 'device': {'id': device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug': {'deterministic_compute': True} - } +def test_aggregation_from_distributed_zero_mixed_precision_lamb( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero/mixed_precision/lamb/" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare loaded_state_dict, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir, use_lamb=True) test_zero_aggregation(checkpoint_dir, loaded_state_dict, is_mixedprecision=True) -def test_aggregation_from_distributed_megatron_full_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/full_precision/adam/'): - opts = {'device': {'id': device}, - 'debug': {'deterministic_compute': True}} +def test_aggregation_from_distributed_megatron_full_precision_adam( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_megatron/full_precision/adam/" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=False + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) -def test_aggregation_from_distributed_megatron_mixed_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/mixed_precision/adam/'): - opts = { - 'device': {'id': device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug': {'deterministic_compute': True} - } +def test_aggregation_from_distributed_megatron_mixed_precision_adam( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_megatron/mixed_precision/adam/" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=False + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) -def test_aggregation_from_distributed_megatron_full_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/full_precision/lamb/'): - opts = {'device': {'id': device}, - 'debug': {'deterministic_compute': True}} +def test_aggregation_from_distributed_megatron_full_precision_lamb( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_megatron/full_precision/lamb/" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=True + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) -def test_aggregation_from_distributed_megatron_mixed_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_megatron/mixed_precision/lamb/'): - opts = { - 'device': {'id': device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug': {'deterministic_compute': True} - } +def test_aggregation_from_distributed_megatron_mixed_precision_lamb( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_megatron/mixed_precision/lamb/" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=True + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) -def test_aggregation_from_distributed_zero_megatron_full_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/full_precision/adam/'): - opts = {'device': {'id': device}, - 'debug': {'deterministic_compute': True}} + +def test_aggregation_from_distributed_zero_megatron_full_precision_adam( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero_megatron/full_precision/adam/" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=False + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) -def test_aggregation_from_distributed_zero_megatron_mixed_precision_adam(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/mixed_precision/adam/'): - opts = { - 'device': {'id': device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug': {'deterministic_compute': True} - } +def test_aggregation_from_distributed_zero_megatron_mixed_precision_adam( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero_megatron/mixed_precision/adam/" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=False) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=False + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) -def test_aggregation_from_distributed_zero_megatron_full_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/full_precision/lamb/'): - opts = {'device': {'id': device}, - 'debug': {'deterministic_compute': True}} +def test_aggregation_from_distributed_zero_megatron_full_precision_lamb( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero_megatron/full_precision/lamb/" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=True + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=False) -def test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb(device='cuda', checkpoint_dir='checkpoint_dir/distributed_zero_megatron/mixed_precision/lamb/'): - opts = { - 'device': {'id': device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug': {'deterministic_compute': True} - } +def test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb( + device="cuda", checkpoint_dir="checkpoint_dir/distributed_zero_megatron/mixed_precision/lamb/" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir, use_lamb=True) + loaded_state_dict, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir, use_lamb=True + ) test_megatron_aggregation(checkpoint_dir, loaded_state_dict, expected_state_dict, is_mixedprecision=True) function_map = { # all config to single node config - 'test_aggregation_from_distributed_zero_full_precision_adam': test_aggregation_from_distributed_zero_full_precision_adam, - 'test_aggregation_from_distributed_zero_mixed_precision_adam': test_aggregation_from_distributed_zero_mixed_precision_adam, - 'test_aggregation_from_distributed_zero_mixed_precision_lamb': test_aggregation_from_distributed_zero_mixed_precision_lamb, - 'test_aggregation_from_distributed_zero_full_precision_lamb': test_aggregation_from_distributed_zero_full_precision_lamb, - 'test_aggregation_from_distributed_megatron_full_precision_adam': test_aggregation_from_distributed_megatron_full_precision_adam, - 'test_aggregation_from_distributed_megatron_mixed_precision_adam': test_aggregation_from_distributed_megatron_mixed_precision_adam, - 'test_aggregation_from_distributed_megatron_mixed_precision_lamb': test_aggregation_from_distributed_megatron_mixed_precision_lamb, - 'test_aggregation_from_distributed_megatron_full_precision_lamb': test_aggregation_from_distributed_megatron_full_precision_lamb, - 'test_aggregation_from_distributed_zero_megatron_full_precision_adam': test_aggregation_from_distributed_zero_megatron_full_precision_adam, - 'test_aggregation_from_distributed_zero_megatron_mixed_precision_adam': test_aggregation_from_distributed_zero_megatron_mixed_precision_adam, - 'test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb': test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb, - 'test_aggregation_from_distributed_zero_megatron_full_precision_lamb': test_aggregation_from_distributed_zero_megatron_full_precision_lamb + "test_aggregation_from_distributed_zero_full_precision_adam": test_aggregation_from_distributed_zero_full_precision_adam, + "test_aggregation_from_distributed_zero_mixed_precision_adam": test_aggregation_from_distributed_zero_mixed_precision_adam, + "test_aggregation_from_distributed_zero_mixed_precision_lamb": test_aggregation_from_distributed_zero_mixed_precision_lamb, + "test_aggregation_from_distributed_zero_full_precision_lamb": test_aggregation_from_distributed_zero_full_precision_lamb, + "test_aggregation_from_distributed_megatron_full_precision_adam": test_aggregation_from_distributed_megatron_full_precision_adam, + "test_aggregation_from_distributed_megatron_mixed_precision_adam": test_aggregation_from_distributed_megatron_mixed_precision_adam, + "test_aggregation_from_distributed_megatron_mixed_precision_lamb": test_aggregation_from_distributed_megatron_mixed_precision_lamb, + "test_aggregation_from_distributed_megatron_full_precision_lamb": test_aggregation_from_distributed_megatron_full_precision_lamb, + "test_aggregation_from_distributed_zero_megatron_full_precision_adam": test_aggregation_from_distributed_zero_megatron_full_precision_adam, + "test_aggregation_from_distributed_zero_megatron_mixed_precision_adam": test_aggregation_from_distributed_zero_megatron_mixed_precision_adam, + "test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb": test_aggregation_from_distributed_zero_megatron_mixed_precision_lamb, + "test_aggregation_from_distributed_zero_megatron_full_precision_lamb": test_aggregation_from_distributed_zero_megatron_full_precision_lamb, } -parser = argparse.ArgumentParser(description='Test aggregation of states for Zero-1') -parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test saved and loaded states', required=True) -parser.add_argument('--checkpoint_dir', help='path to the saved states directory', required=True) +parser = argparse.ArgumentParser(description="Test aggregation of states for Zero-1") +parser.add_argument( + "--scenario", choices=function_map.keys(), help="training scenario to test saved and loaded states", required=True +) +parser.add_argument("--checkpoint_dir", help="path to the saved states directory", required=True) args = parser.parse_args() function_map[args.scenario](checkpoint_dir=args.checkpoint_dir) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py index ffcc6eea69642..44fbd33b3696e 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_checkpoint.py @@ -15,120 +15,104 @@ import torch.distributed as dist import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import onnxruntime from onnxruntime.training import checkpoint -from _test_helpers import distributed_setup, create_orttrainer_and_load_checkpoint, create_orttrainer_and_load_checkpoint_bart, aggregate_states, assert_all_states_close, save_ort_ckpt +from _test_helpers import ( + distributed_setup, + create_orttrainer_and_load_checkpoint, + create_orttrainer_and_load_checkpoint_bart, + aggregate_states, + assert_all_states_close, + save_ort_ckpt, +) from _test_commons import assert_all_states_close_ort, assert_all_states_close_pytorch -def test_load_from_single_node_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_single_node_full_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) -def test_load_from_single_node_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_single_node_mixed_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) -def test_load_from_single_node_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_single_node_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) -def test_load_from_single_node_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_single_node_full_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + -def test_load_from_data_parallelism_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} +def test_load_from_data_parallelism_full_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) -def test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + -def test_load_from_distributed_zero_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} +def test_load_from_distributed_zero_full_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -140,13 +124,13 @@ def test_load_from_distributed_zero_full_precision_into_single_node_full_precisi assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -158,19 +142,13 @@ def test_load_from_distributed_zero_mixed_precision_into_single_node_full_precis assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -182,19 +160,13 @@ def test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_preci assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -206,288 +178,277 @@ def test_load_from_distributed_zero_full_precision_into_single_node_mixed_precis assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir): + +def test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir +): # compare the expected dictionary with the aggregated state dictionary from the ORTTrainer assert_all_states_close_ort(expected_state_dict, state_dict_post_checkpoint) - # TODO: aggregate checkpoints previously saved and load it into the pytorch model for comparison, + # TODO: aggregate checkpoints previously saved and load it into the pytorch model for comparison, # need to add support to add the bart pytorch model to unit tests instead of current onnx model # checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) # agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) # assert_all_states_close_pytorch(agg_state_dict, model) -def test_load_from_distributed_megatron_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_distributed_megatron_full_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) -def test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) -def test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) -def test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + -def test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} +def test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision( + checkpoint_dir, device="cuda" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) -def test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision( + checkpoint_dir, device="cuda" +): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) -def test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision( + checkpoint_dir, device="cuda" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) -def test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision( + checkpoint_dir, device="cuda" +): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_single_node_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) # compare all states - assert_all_states_close(checkpoint_dir, 'state_dict', state_dict_post_checkpoint, model) + assert_all_states_close(checkpoint_dir, "state_dict", state_dict_post_checkpoint, model) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -499,22 +460,20 @@ def test_load_from_distributed_zero_full_precision_into_data_parallelism_full_pr assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -526,26 +485,21 @@ def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_p assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -557,26 +511,21 @@ def test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_ assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, model = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -588,575 +537,572 @@ def test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_p assert_all_states_close_ort(aggregated_state_dict, state_dict_post_checkpoint, reshape_states=True) # aggregate checkpoints previously saved and load it into the pytorch model for comparison - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) agg_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=True) assert_all_states_close_pytorch(agg_state_dict, model) @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) - test_load_from_megatron_to_non_model_parallel_node(state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) + test_load_from_megatron_to_non_model_parallel_node( + state_dict_post_checkpoint, expected_state_dict, model, checkpoint_dir + ) + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the single node state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the single node state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the single node state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the single node state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the data parallel state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the data parallel state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the data parallel state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel node trainer to the state dictionary from a zero run: # - Save the state dictionaries for each rank for the zero run in a pickle file (distributed_state_world_rank.pkl) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate states from the previously saved state dictionary in a pickle file - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the manually aggregated state dictionary with the data parallel state dictionary that was previously saved in a pickle file assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict_" + str(world_rank) + ".pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + state_dict_pre_checkpoint = state["state_dict_" + str(world_rank)] # compare all states for each rank independently assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -1172,75 +1118,69 @@ def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_p # Which is why the need to compare the aggregated state dictionary (which returns a single state dictionary with all model and optimizer states) # as opposed to comparing the state dictionary for each rank independently. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) aggregated_state_dict1 = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # aggregate checkpoints from the previous mixed precision zero trainer - aggregated_state_dict2 = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict2 = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(aggregated_state_dict2, aggregated_state_dict1, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict_" + str(world_rank) + ".pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + state_dict_pre_checkpoint = state["state_dict_" + str(world_rank)] # compare all states for each rank independently assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _ = create_orttrainer_and_load_checkpoint(device, opts, checkpoint_dir) @@ -1256,42 +1196,45 @@ def test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_p # Which is why the need to compare the aggregated state dictionary (which returns a single state dictionary with all model and optimizer states) # as opposed to comparing the state dictionary for each rank independently. - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "checkpoint*.ortcp")) aggregated_state_dict1 = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # aggregate checkpoints from the previous mixed precision zero trainer - aggregated_state_dict2 = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict2 = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(aggregated_state_dict2, aggregated_state_dict1, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1299,39 +1242,42 @@ def test_load_from_distributed_megatron_full_precision_into_distributed_zero_ful # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1339,43 +1285,43 @@ def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_fu # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1383,43 +1329,43 @@ def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mi # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1427,39 +1373,42 @@ def test_load_from_distributed_megatron_full_precision_into_distributed_zero_mix # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1467,39 +1416,42 @@ def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zer # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1507,43 +1459,43 @@ def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_ze # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1551,43 +1503,43 @@ def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_ze # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the current zero run in a pickle file (distributed_state_world_rank.pkl) @@ -1595,1834 +1547,1836 @@ def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zer # - Aggregate the checkpoint files from the previous zero run checkpoint files into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the aggregated state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(state_dict_post_checkpoint, f) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - aggregated_state_dict_loaded = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict_loaded = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) ########################################################################################################################################### # LOAD TO MEGATRON ########################################################################################################################################### + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run.. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a megatron run: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict_" + str(world_rank) + ".pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + state_dict_pre_checkpoint = state["state_dict_" + str(world_rank)] # compare all states for each rank independently assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed megatron node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict_" + str(world_rank) + ".pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + state_dict_pre_checkpoint = state["state_dict_" + str(world_rank)] # compare all states for each rank independently assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed megatron node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers : # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : world_size - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed megatron and distributed zero+megatron node trainers: # - Save the state dictionaries for each rank for the megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + ########################################################################################################################################### # LOAD TO ZERO+MEGATRON ########################################################################################################################################### + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a single node trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the single node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected single node state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run.. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel node run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict.pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict'] + state_dict_pre_checkpoint = state["state_dict"] # To compare state dictionary from a data parallel trainer to the state dictionary from a zero+megatron run: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary against the state dictionary previously saved from the data parallel run. - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the manually aggregated state dictionary with the expected data parallel state dictionary assert_all_states_close_ort(aggregated_state_dict, state_dict_pre_checkpoint) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current full precision zero trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed megatron node trainers: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict_" + str(world_rank) + ".pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + state_dict_pre_checkpoint = state["state_dict_" + str(world_rank)] # compare all states for each rank independently assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero+megatron node trainers with different precisions: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + @distributed_setup -def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, _ , _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, _, _ = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) state = None - with open(os.path.join(checkpoint_dir, 'state_dict_'+str(world_rank)+'.pkl'), 'rb') as f: + with open(os.path.join(checkpoint_dir, "state_dict_" + str(world_rank) + ".pkl"), "rb") as f: state = pickle.load(f) - state_dict_pre_checkpoint = state['state_dict_'+str(world_rank)] + state_dict_pre_checkpoint = state["state_dict_" + str(world_rank)] # compare all states for each rank independently assert_all_states_close_ort(state_dict_pre_checkpoint, state_dict_post_checkpoint) + @distributed_setup -def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision(world_rank, world_size, device, checkpoint_dir): +def test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision( + world_rank, world_size, device, checkpoint_dir +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size' : int(world_size/2), - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": int(world_size / 2), + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } # extract state dictionaries to compare - state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart(device, opts, checkpoint_dir) + state_dict_post_checkpoint, expected_state_dict, model = create_orttrainer_and_load_checkpoint_bart( + device, opts, checkpoint_dir + ) # To compare state dictionary between distributed zero+megatron and distributed zero+megatron node trainers with different precisions: # - Save the state dictionaries for each rank for the zero+megatron run (distributed_state_world_rank.ort.pt) # - On rank 0, manually load each state dictionary and aggregate all of them into a single state dictionary. # - Compare the aggregated state dictionary from the current run against the expected state dictionary from the previous run. # This is needed because of difference in model-parallel config causing different sharding of model and optimizer states - filename = 'distributed_state_' + str(world_rank) + '.ort.pt' + filename = "distributed_state_" + str(world_rank) + ".ort.pt" filepath = os.path.join(checkpoint_dir, filename) save_ort_ckpt(state_dict_post_checkpoint, filepath) dist.barrier() if world_rank == 0: # manually aggregate the states for the current trainer - checkpoint_files = glob.glob(os.path.join(checkpoint_dir, 'distributed_state*.ort.pt')) + checkpoint_files = glob.glob(os.path.join(checkpoint_dir, "distributed_state*.ort.pt")) aggregated_state_dict_loaded = checkpoint.aggregate_checkpoints(checkpoint_files, pytorch_format=False) # compare the two state dictionaries assert_all_states_close_ort(expected_state_dict, aggregated_state_dict_loaded, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.ort.pt')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".ort.pt")) + function_map = { # all config to single node config - 'test_load_from_single_node_full_precision_into_single_node_full_precision': test_load_from_single_node_full_precision_into_single_node_full_precision, - 'test_load_from_single_node_mixed_precision_into_single_node_mixed_precision': test_load_from_single_node_mixed_precision_into_single_node_mixed_precision, - 'test_load_from_single_node_mixed_precision_into_single_node_full_precision': test_load_from_single_node_mixed_precision_into_single_node_full_precision, - 'test_load_from_single_node_full_precision_into_single_node_mixed_precision': test_load_from_single_node_full_precision_into_single_node_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_single_node_full_precision': test_load_from_data_parallelism_full_precision_into_single_node_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision': test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision': test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_single_node_full_precision': test_load_from_distributed_zero_full_precision_into_single_node_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision': test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_single_node_full_precision': test_load_from_distributed_megatron_full_precision_into_single_node_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision': test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision': test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision, - + "test_load_from_single_node_full_precision_into_single_node_full_precision": test_load_from_single_node_full_precision_into_single_node_full_precision, + "test_load_from_single_node_mixed_precision_into_single_node_mixed_precision": test_load_from_single_node_mixed_precision_into_single_node_mixed_precision, + "test_load_from_single_node_mixed_precision_into_single_node_full_precision": test_load_from_single_node_mixed_precision_into_single_node_full_precision, + "test_load_from_single_node_full_precision_into_single_node_mixed_precision": test_load_from_single_node_full_precision_into_single_node_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_single_node_full_precision": test_load_from_data_parallelism_full_precision_into_single_node_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision": test_load_from_data_parallelism_mixed_precision_into_single_node_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision": test_load_from_data_parallelism_mixed_precision_into_single_node_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision": test_load_from_data_parallelism_full_precision_into_single_node_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_single_node_full_precision": test_load_from_distributed_zero_full_precision_into_single_node_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision": test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision": test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision": test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_single_node_full_precision": test_load_from_distributed_megatron_full_precision_into_single_node_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision": test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision": test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision": test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision": test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision": test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision, # all config to data parallel node config - 'test_load_from_single_node_full_precision_into_data_parallelism_full_precision': test_load_from_single_node_full_precision_into_data_parallelism_full_precision, - 'test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision': test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision, - 'test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision, - 'test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision': test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision': test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision': test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision': test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision': test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision': test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision': test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision': test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision, - + "test_load_from_single_node_full_precision_into_data_parallelism_full_precision": test_load_from_single_node_full_precision_into_data_parallelism_full_precision, + "test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision": test_load_from_single_node_mixed_precision_into_data_parallelism_full_precision, + "test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision": test_load_from_single_node_mixed_precision_into_data_parallelism_mixed_precision, + "test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision": test_load_from_single_node_full_precision_into_data_parallelism_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision": test_load_from_data_parallelism_full_precision_into_data_parallelism_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision": test_load_from_data_parallelism_mixed_precision_into_data_parallelism_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision": test_load_from_data_parallelism_mixed_precision_into_data_parallelism_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision": test_load_from_data_parallelism_full_precision_into_data_parallelism_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision": test_load_from_distributed_zero_full_precision_into_data_parallelism_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision": test_load_from_distributed_zero_mixed_precision_into_data_parallelism_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision": test_load_from_distributed_zero_mixed_precision_into_data_parallelism_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision": test_load_from_distributed_zero_full_precision_into_data_parallelism_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision": test_load_from_distributed_megatron_full_precision_into_data_parallelism_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision": test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision": test_load_from_distributed_megatron_mixed_precision_into_data_parallelism_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision": test_load_from_distributed_megatron_full_precision_into_data_parallelism_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision": test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_data_parallelism_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision": test_load_from_distributed_zero_megatron_full_precision_into_data_parallelism_mixed_precision, # all config to distributed zero node config - 'test_load_from_single_node_full_precision_into_distributed_zero_full_precision': test_load_from_single_node_full_precision_into_distributed_zero_full_precision, - 'test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision': test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision, - 'test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision, - 'test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision': test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision': test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision': test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision, - + "test_load_from_single_node_full_precision_into_distributed_zero_full_precision": test_load_from_single_node_full_precision_into_distributed_zero_full_precision, + "test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision": test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision, + "test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision": test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision, + "test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision": test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision": test_load_from_data_parallelism_full_precision_into_distributed_zero_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision": test_load_from_data_parallelism_mixed_precision_into_distributed_zero_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision": test_load_from_data_parallelism_mixed_precision_into_distributed_zero_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision": test_load_from_data_parallelism_full_precision_into_distributed_zero_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision": test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision": test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision": test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision": test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision": test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision": test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision": test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision": test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision": test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision": test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision, # all config to distributed megatron node config - 'test_load_from_single_node_full_precision_into_distributed_megatron_full_precision': test_load_from_single_node_full_precision_into_distributed_megatron_full_precision, - 'test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision': test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision, - 'test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision': test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision': test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision': test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision': test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision, - + "test_load_from_single_node_full_precision_into_distributed_megatron_full_precision": test_load_from_single_node_full_precision_into_distributed_megatron_full_precision, + "test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision": test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision, + "test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision": test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision, + "test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision": test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision": test_load_from_data_parallelism_full_precision_into_distributed_megatron_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision": test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision": test_load_from_data_parallelism_mixed_precision_into_distributed_megatron_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision": test_load_from_data_parallelism_full_precision_into_distributed_megatron_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision": test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision": test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision": test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision": test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision": test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision": test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision": test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision": test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision": test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision": test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision, # all config to distributed zero + megatron node config - 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision, - 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision, - 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision': test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision + "test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision": test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision": test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision": test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision": test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_data_parallelism_mixed_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_data_parallelism_full_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision": test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision": test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision": test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision": test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision": test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision": test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision, } -parser = argparse.ArgumentParser(description='Test saved states of trainers to loaded states') -parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test saved and loaded states', required=True) -parser.add_argument('--checkpoint_dir', help='path to the saved states directory', required=True) +parser = argparse.ArgumentParser(description="Test saved states of trainers to loaded states") +parser.add_argument( + "--scenario", choices=function_map.keys(), help="training scenario to test saved and loaded states", required=True +) +parser.add_argument("--checkpoint_dir", help="path to the saved states directory", required=True) args = parser.parse_args() function_map[args.scenario](checkpoint_dir=args.checkpoint_dir) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_optimizer_state.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_optimizer_state.py index 66580f1d6901f..2d9a54eda77bf 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_optimizer_state.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_load_optimizer_state.py @@ -16,6 +16,7 @@ import torch.distributed as dist import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import onnxruntime @@ -23,7 +24,8 @@ from _test_helpers import distributed_setup, load_model_optim_state_and_eval, aggregate_states from _test_commons import assert_all_states_close_ort -def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False): + +def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False): expected_optim_state, trainer_optim_state = load_model_optim_state_and_eval(device, opts, use_lamb) # verify optimizer states are matching by: @@ -32,113 +34,104 @@ def verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_ # - Comparing this aggregated state dictionary with the full dummy optimizer dictionary (expected_optim_state) # created by load_model_optim_state_and_eval - with open(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl'), "wb") as f: + with open(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl"), "wb") as f: pickle.dump(trainer_optim_state, f) dist.barrier() if world_rank == 0: # aggregate states and compare - aggregated_state_dict = aggregate_states(checkpoint_dir, filename_prefix='distributed_state', state_dict_key_name=None) + aggregated_state_dict = aggregate_states( + checkpoint_dir, filename_prefix="distributed_state", state_dict_key_name=None + ) # compare all states assert_all_states_close_ort(aggregated_state_dict, expected_optim_state, reshape_states=True) dist.barrier() - os.remove(os.path.join(checkpoint_dir, 'distributed_state_'+str(world_rank)+'.pkl')) + os.remove(os.path.join(checkpoint_dir, "distributed_state_" + str(world_rank) + ".pkl")) @distributed_setup -def test_optim_load_to_distributed_zero_full_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/adam/'): +def test_optim_load_to_distributed_zero_full_precision_adam( + world_rank, world_size, device, checkpoint_dir="checkpoint_dir/distributed_zero/full_precision/adam/" +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False) @distributed_setup -def test_optim_load_to_distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/adam/'): +def test_optim_load_to_distributed_zero_mixed_precision_adam( + world_rank, world_size, device, checkpoint_dir="checkpoint_dir/distributed_zero/mixed_precision/adam/" +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=False) @distributed_setup -def test_optim_load_to_distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/full_precision/lamb/'): +def test_optim_load_to_distributed_zero_full_precision_lamb( + world_rank, world_size, device, checkpoint_dir="checkpoint_dir/distributed_zero/full_precision/lamb/" +): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True) + @distributed_setup -def test_optim_load_to_distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir = 'checkpoint_dir/distributed_zero/mixed_precision/lamb/'): +def test_optim_load_to_distributed_zero_mixed_precision_lamb( + world_rank, world_size, device, checkpoint_dir="checkpoint_dir/distributed_zero/mixed_precision/lamb/" +): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + verify_optimizer_state_match(device, opts, checkpoint_dir, world_rank, use_lamb=True) function_map = { # load to zero configs - 'test_optim_load_to_distributed_zero_full_precision_adam': test_optim_load_to_distributed_zero_full_precision_adam, - 'test_optim_load_to_distributed_zero_mixed_precision_adam': test_optim_load_to_distributed_zero_mixed_precision_adam, - 'test_optim_load_to_distributed_zero_mixed_precision_lamb': test_optim_load_to_distributed_zero_mixed_precision_lamb, - 'test_optim_load_to_distributed_zero_full_precision_lamb': test_optim_load_to_distributed_zero_full_precision_lamb + "test_optim_load_to_distributed_zero_full_precision_adam": test_optim_load_to_distributed_zero_full_precision_adam, + "test_optim_load_to_distributed_zero_mixed_precision_adam": test_optim_load_to_distributed_zero_mixed_precision_adam, + "test_optim_load_to_distributed_zero_mixed_precision_lamb": test_optim_load_to_distributed_zero_mixed_precision_lamb, + "test_optim_load_to_distributed_zero_full_precision_lamb": test_optim_load_to_distributed_zero_full_precision_lamb, } -parser = argparse.ArgumentParser(description='Test loading of initial optimizer state for Zero-1') -parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test loaded states', required=True) -parser.add_argument('--checkpoint_dir', help='path to the saved states directory', required=True) +parser = argparse.ArgumentParser(description="Test loading of initial optimizer state for Zero-1") +parser.add_argument( + "--scenario", choices=function_map.keys(), help="training scenario to test loaded states", required=True +) +parser.add_argument("--checkpoint_dir", help="path to the saved states directory", required=True) args = parser.parse_args() function_map[args.scenario](checkpoint_dir=args.checkpoint_dir) diff --git a/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py b/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py index b00b77af81e13..7cd16ea48b0d7 100644 --- a/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py +++ b/orttraining/orttraining/test/python/checkpoint/orttraining_test_save_checkpoint.py @@ -9,405 +9,353 @@ import argparse import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) -from _test_helpers import distributed_setup, create_orttrainer_and_save_checkpoint, create_orttrainer_and_save_checkpoint_bart +from _test_helpers import ( + distributed_setup, + create_orttrainer_and_save_checkpoint, + create_orttrainer_and_save_checkpoint_bart, +) + -def single_node_full_precision(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} +def single_node_full_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir) -def single_node_mixed_precision(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def single_node_mixed_precision(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir) -def single_node_full_precision_bart(checkpoint_dir, device = 'cuda'): - opts = {'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True}} + +def single_node_full_precision_bart(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir) -def single_node_mixed_precision_bart(checkpoint_dir, device = 'cuda'): - opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'debug' : {'deterministic_compute': True} - } + +def single_node_mixed_precision_bart(checkpoint_dir, device="cuda"): + opts = {"device": {"id": device}, "mixed_precision": {"enabled": True}, "debug": {"deterministic_compute": True}} create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir) + @distributed_setup def data_parallelism_full_precision(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir if world_rank == 0 else None) + @distributed_setup def data_parallelism_mixed_precision(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir if world_rank == 0 else None) + @distributed_setup def data_parallelism_full_precision_bart(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir if world_rank == 0 else None) + @distributed_setup def data_parallelism_mixed_precision_bart(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True - }, - 'debug' : {'deterministic_compute': True} - } + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": {"world_rank": world_rank, "world_size": world_size, "allreduce_post_accumulation": True}, + "debug": {"deterministic_compute": True}, + } create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir if world_rank == 0 else None) + @distributed_setup def distributed_zero_full_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank), use_lamb=False + ) + @distributed_setup def distributed_zero_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank), use_lamb=False + ) + @distributed_setup def distributed_zero_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + @distributed_setup def distributed_zero_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + @distributed_setup def distributed_zero_full_precision_lamb_bart(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + @distributed_setup def distributed_zero_mixed_precision_lamb_bart(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) @distributed_setup def distributed_megatron_full_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size': world_size - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank), use_lamb=False + ) + @distributed_setup def distributed_megatron_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size': world_size - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank), use_lamb=False + ) + @distributed_setup def distributed_megatron_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size': world_size - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + @distributed_setup def distributed_megatron_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'horizontal_parallel_size': world_size - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "horizontal_parallel_size": world_size, + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + @distributed_setup def distributed_zero_megatron_full_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - }, - 'horizontal_parallel_size': int(world_size/2) - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + "horizontal_parallel_size": int(world_size / 2), + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank), use_lamb=False + ) + @distributed_setup def distributed_zero_megatron_mixed_precision_adam(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - }, - 'horizontal_parallel_size': int(world_size/2) - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank), use_lamb=False) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + "horizontal_parallel_size": int(world_size / 2), + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank), use_lamb=False + ) + @distributed_setup def distributed_zero_megatron_full_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - }, - 'horizontal_parallel_size': int(world_size/2) - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + "horizontal_parallel_size": int(world_size / 2), + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + @distributed_setup def distributed_zero_megatron_mixed_precision_lamb(world_rank, world_size, device, checkpoint_dir): opts = { - 'device' : {'id' : device}, - 'mixed_precision': - { - 'enabled': True - }, - 'distributed' : - { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - }, - 'horizontal_parallel_size': int(world_size/2) - }, - 'debug' : {'deterministic_compute': True} - } - create_orttrainer_and_save_checkpoint_bart(device, opts, checkpoint_dir, state_dict_key_name='state_dict_'+str(world_rank)) + "device": {"id": device}, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + "horizontal_parallel_size": int(world_size / 2), + }, + "debug": {"deterministic_compute": True}, + } + create_orttrainer_and_save_checkpoint_bart( + device, opts, checkpoint_dir, state_dict_key_name="state_dict_" + str(world_rank) + ) + function_map = { - 'single_node_full_precision': single_node_full_precision, - 'single_node_mixed_precision': single_node_mixed_precision, - 'single_node_full_precision_bart': single_node_full_precision_bart, - 'single_node_mixed_precision_bart': single_node_mixed_precision_bart, - 'data_parallelism_full_precision': data_parallelism_full_precision, - 'data_parallelism_mixed_precision': data_parallelism_mixed_precision, - 'data_parallelism_full_precision_bart': data_parallelism_full_precision_bart, - 'data_parallelism_mixed_precision_bart': data_parallelism_mixed_precision_bart, - 'distributed_zero_full_precision_adam': distributed_zero_full_precision_adam, - 'distributed_zero_mixed_precision_adam': distributed_zero_mixed_precision_adam, - 'distributed_zero_full_precision_lamb': distributed_zero_full_precision_lamb, - 'distributed_zero_mixed_precision_lamb': distributed_zero_mixed_precision_lamb, - 'distributed_zero_full_precision_lamb_bart': distributed_zero_full_precision_lamb_bart, - 'distributed_zero_mixed_precision_lamb_bart': distributed_zero_mixed_precision_lamb_bart, - 'distributed_megatron_full_precision_adam': distributed_megatron_full_precision_adam, - 'distributed_megatron_mixed_precision_adam': distributed_megatron_mixed_precision_adam, - 'distributed_megatron_full_precision_lamb': distributed_megatron_full_precision_lamb, - 'distributed_megatron_mixed_precision_lamb': distributed_megatron_mixed_precision_lamb, - 'distributed_zero_megatron_full_precision_adam': distributed_zero_megatron_full_precision_adam, - 'distributed_zero_megatron_mixed_precision_adam': distributed_zero_megatron_mixed_precision_adam, - 'distributed_zero_megatron_full_precision_lamb': distributed_zero_megatron_full_precision_lamb, - 'distributed_zero_megatron_mixed_precision_lamb': distributed_zero_megatron_mixed_precision_lamb + "single_node_full_precision": single_node_full_precision, + "single_node_mixed_precision": single_node_mixed_precision, + "single_node_full_precision_bart": single_node_full_precision_bart, + "single_node_mixed_precision_bart": single_node_mixed_precision_bart, + "data_parallelism_full_precision": data_parallelism_full_precision, + "data_parallelism_mixed_precision": data_parallelism_mixed_precision, + "data_parallelism_full_precision_bart": data_parallelism_full_precision_bart, + "data_parallelism_mixed_precision_bart": data_parallelism_mixed_precision_bart, + "distributed_zero_full_precision_adam": distributed_zero_full_precision_adam, + "distributed_zero_mixed_precision_adam": distributed_zero_mixed_precision_adam, + "distributed_zero_full_precision_lamb": distributed_zero_full_precision_lamb, + "distributed_zero_mixed_precision_lamb": distributed_zero_mixed_precision_lamb, + "distributed_zero_full_precision_lamb_bart": distributed_zero_full_precision_lamb_bart, + "distributed_zero_mixed_precision_lamb_bart": distributed_zero_mixed_precision_lamb_bart, + "distributed_megatron_full_precision_adam": distributed_megatron_full_precision_adam, + "distributed_megatron_mixed_precision_adam": distributed_megatron_mixed_precision_adam, + "distributed_megatron_full_precision_lamb": distributed_megatron_full_precision_lamb, + "distributed_megatron_mixed_precision_lamb": distributed_megatron_mixed_precision_lamb, + "distributed_zero_megatron_full_precision_adam": distributed_zero_megatron_full_precision_adam, + "distributed_zero_megatron_mixed_precision_adam": distributed_zero_megatron_mixed_precision_adam, + "distributed_zero_megatron_full_precision_lamb": distributed_zero_megatron_full_precision_lamb, + "distributed_zero_megatron_mixed_precision_lamb": distributed_zero_megatron_mixed_precision_lamb, } -parser = argparse.ArgumentParser(description='Save states of trainers') -parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to save states', required=True) -parser.add_argument('--checkpoint_dir', help='path to the directory where checkpoints can be saved', required=True) +parser = argparse.ArgumentParser(description="Save states of trainers") +parser.add_argument("--scenario", choices=function_map.keys(), help="training scenario to save states", required=True) +parser.add_argument("--checkpoint_dir", help="path to the directory where checkpoints can be saved", required=True) args = parser.parse_args() -function_map[args.scenario](checkpoint_dir = args.checkpoint_dir) +function_map[args.scenario](checkpoint_dir=args.checkpoint_dir) diff --git a/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model.py b/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model.py index a4ce3d8f1c4a7..0c4e031532bfd 100644 --- a/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model.py +++ b/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model.py @@ -2,6 +2,7 @@ # Otherwise, "import onnxruntime" may fail. import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import onnxruntime @@ -16,6 +17,7 @@ torch.manual_seed(0) + class Mlp(nn.Module): def __init__(self, d_in, d_hidden, d_out): super(Mlp, self).__init__() @@ -47,7 +49,7 @@ def forward(self, x): y = torch.randint(0, d_out, (n,)) # Modeling. -loss = nn.CrossEntropyLoss(reduction='sum') +loss = nn.CrossEntropyLoss(reduction="sum") model = Mlp(d_in, d_hidden, d_out) @@ -63,44 +65,43 @@ def apply_loss(p, y): # Compute batch size for micro-batches. n_slice = int(n / num_pipeline_steps) -cuda_device = 'cuda:' + str(rank) +cuda_device = "cuda:" + str(rank) # Schema used when running the original batch. -schema = {'inputs': [('x', ['n', 'd_in']), ('target', ['n'])], 'outputs': [ - ('loss', [], True), ('output', ['n', d_out])]} +schema = { + "inputs": [("x", ["n", "d_in"]), ("target", ["n"])], + "outputs": [("loss", [], True), ("output", ["n", d_out])], +} # Actual schema used when running micro-batches. -pipeline_schema = {'x': [n_slice, d_in], 'target': [ - n_slice], 'output': [n_slice, d_out], 'loss': []} +pipeline_schema = {"x": [n_slice, d_in], "target": [n_slice], "output": [n_slice, d_out], "loss": []} # Describe which axis to slice along for each sliced tensor. -sliced_axes = {'x': 0, 'target': 0, 'output': 0} +sliced_axes = {"x": 0, "target": 0, "output": 0} adam_config = optim.AdamConfig(lr=0.1) # # Specify configuration for pipeline parallel training. -trainer_config = ORTTrainerOptions({ - 'batch': { - 'gradient_accumulation_steps': num_pipeline_steps - }, - 'device': { - 'id': cuda_device - }, - 'distributed': { - 'world_size': total_ranks, - 'world_rank': rank, - 'data_parallel_size': int(total_ranks / num_pipeline_stages), - 'horizontal_parallel_size': 1, - 'pipeline_parallel': { - 'pipeline_parallel_size': int(num_pipeline_stages), - 'num_pipeline_micro_batches': num_pipeline_steps, - 'sliced_schema': pipeline_schema, - 'sliced_axes': sliced_axes, - 'sliced_tensor_names': ['x', 'target', 'output'], - # Define pipeline stage partition by specifying cut points. - # 2-stage cut. It's a cut on tensor "12". - 'pipeline_cut_info_string': '12' +trainer_config = ORTTrainerOptions( + { + "batch": {"gradient_accumulation_steps": num_pipeline_steps}, + "device": {"id": cuda_device}, + "distributed": { + "world_size": total_ranks, + "world_rank": rank, + "data_parallel_size": int(total_ranks / num_pipeline_stages), + "horizontal_parallel_size": 1, + "pipeline_parallel": { + "pipeline_parallel_size": int(num_pipeline_stages), + "num_pipeline_micro_batches": num_pipeline_steps, + "sliced_schema": pipeline_schema, + "sliced_axes": sliced_axes, + "sliced_tensor_names": ["x", "target", "output"], + # Define pipeline stage partition by specifying cut points. + # 2-stage cut. It's a cut on tensor "12". + "pipeline_cut_info_string": "12", + }, + "allreduce_post_accumulation": True, }, - 'allreduce_post_accumulation': True } -}) +) trainer = ORTTrainer(model, schema, adam_config, apply_loss, trainer_config) @@ -119,4 +120,4 @@ def apply_loss(p, y): expected_loss_history = [0.8660, 1.1219, 1.6610, 1.2641, 1.0162] if rank in last_pipeline_stage_ranks: for result, expected in zip(loss_history, expected_loss_history): - assert torch.allclose(result.cpu(), torch.Tensor([expected], device='cpu'), 1e-03) + assert torch.allclose(result.cpu(), torch.Tensor([expected], device="cpu"), 1e-03) diff --git a/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model_fp16.py b/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model_fp16.py index 988c9b38b5b90..02d1073ea1672 100644 --- a/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model_fp16.py +++ b/orttraining/orttraining/test/python/dhp_parallel/orttraining_test_parallel_train_simple_model_fp16.py @@ -2,6 +2,7 @@ # Otherwise, "import onnxruntime" may fail. import os import sys + sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import onnxruntime @@ -16,6 +17,7 @@ torch.manual_seed(0) + class Mlp(nn.Module): def __init__(self, d_in, d_hidden, d_out): super(Mlp, self).__init__() @@ -47,7 +49,7 @@ def forward(self, x): y = torch.randint(0, d_out, (n,)) # Modeling. -loss = nn.CrossEntropyLoss(reduction='sum') +loss = nn.CrossEntropyLoss(reduction="sum") model = Mlp(d_in, d_hidden, d_out) @@ -63,48 +65,44 @@ def apply_loss(p, y): # Compute batch size for micro-batches. n_slice = int(n / num_pipeline_steps) -cuda_device = 'cuda:' + str(rank) +cuda_device = "cuda:" + str(rank) # Schema used when running the original batch. -schema = {'inputs': [('x', ['n', 'd_in']), ('target', ['n'])], 'outputs': [ - ('loss', [], True), ('output', ['n', d_out])]} +schema = { + "inputs": [("x", ["n", "d_in"]), ("target", ["n"])], + "outputs": [("loss", [], True), ("output", ["n", d_out])], +} # Actual schema used when running micro-batches. -pipeline_schema = {'x': [n_slice, d_in], 'target': [ - n_slice], 'output': [n_slice, d_out], 'loss': []} +pipeline_schema = {"x": [n_slice, d_in], "target": [n_slice], "output": [n_slice, d_out], "loss": []} # Describe which axis to slice along for each sliced tensor. -sliced_axes = {'x': 0, 'target': 0, 'output': 0} +sliced_axes = {"x": 0, "target": 0, "output": 0} adam_config = optim.AdamConfig(lr=0.1) # Specify configuration for pipeline parallel training. -trainer_config = ORTTrainerOptions({ - 'batch': { - 'gradient_accumulation_steps': num_pipeline_steps - }, - 'device': { - 'id': cuda_device - }, - 'mixed_precision': { - 'enabled': True, - 'loss_scaler': amp.DynamicLossScaler() - }, - 'distributed': { - 'world_size': total_ranks, - 'world_rank': rank, - 'data_parallel_size': int(total_ranks / num_pipeline_stages), - 'horizontal_parallel_size': 1, - 'pipeline_parallel': { - 'pipeline_parallel_size': int(num_pipeline_stages), - 'num_pipeline_micro_batches': num_pipeline_steps, - 'sliced_schema': pipeline_schema, - 'sliced_axes': sliced_axes, - 'sliced_tensor_names': ['x', 'target', 'output'], - # Define pipeline stage partition by specifying cut points. - # 2-stage cut. It's a cut on tensor "12". - 'pipeline_cut_info_string': '12' +trainer_config = ORTTrainerOptions( + { + "batch": {"gradient_accumulation_steps": num_pipeline_steps}, + "device": {"id": cuda_device}, + "mixed_precision": {"enabled": True, "loss_scaler": amp.DynamicLossScaler()}, + "distributed": { + "world_size": total_ranks, + "world_rank": rank, + "data_parallel_size": int(total_ranks / num_pipeline_stages), + "horizontal_parallel_size": 1, + "pipeline_parallel": { + "pipeline_parallel_size": int(num_pipeline_stages), + "num_pipeline_micro_batches": num_pipeline_steps, + "sliced_schema": pipeline_schema, + "sliced_axes": sliced_axes, + "sliced_tensor_names": ["x", "target", "output"], + # Define pipeline stage partition by specifying cut points. + # 2-stage cut. It's a cut on tensor "12". + "pipeline_cut_info_string": "12", + }, + "allreduce_post_accumulation": True, }, - 'allreduce_post_accumulation': True } -}) +) trainer = ORTTrainer(model, schema, adam_config, apply_loss, trainer_config) @@ -113,7 +111,7 @@ def apply_loss(p, y): l, p = trainer.train_step(x.to(cuda_device), y.to(cuda_device)) loss_history.append(l) -print('loss history: ', loss_history) +print("loss history: ", loss_history) # Valid ranks are [0, 1, 2, 3]. # [0, 2] forms the 2-stage pipeline in the 1st data parallel group. @@ -125,4 +123,4 @@ def apply_loss(p, y): expected_loss_history = [0.9420, 0.6608, 0.8944, 1.2279, 1.1173] if rank in last_pipeline_stage_ranks: for result, expected in zip(loss_history, expected_loss_history): - assert torch.allclose(result.cpu(), torch.Tensor([expected], device='cpu'), 1e-03) + assert torch.allclose(result.cpu(), torch.Tensor([expected], device="cpu"), 1e-03) diff --git a/orttraining/orttraining/test/python/launch_test.py b/orttraining/orttraining/test/python/launch_test.py index 1b354fb33bb39..d183f3189511c 100755 --- a/orttraining/orttraining/test/python/launch_test.py +++ b/orttraining/orttraining/test/python/launch_test.py @@ -10,30 +10,31 @@ import logging -logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger("Build") + def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument( "--cmd_line_with_args", required=True, help="command line with arguments to be executed in a subprocess. \ - it expects a single string containing arguments separated by spaces.") + it expects a single string containing arguments separated by spaces.", + ) parser.add_argument("--cwd", help="working directory") # parser.add_argument("--env", help="env variables.") parser.add_argument("--env", help="env variables", nargs=2, action="append", default=[]) return parser.parse_args() + launch_args = parse_arguments() print("sys.executable: ", sys.executable) cmd_line_with_args = launch_args.cmd_line_with_args.split() for n, arg in enumerate(cmd_line_with_args): - if arg == 'python': + if arg == "python": cmd_line_with_args[n] = sys.executable run_subprocess(cmd_line_with_args, cwd=launch_args.cwd, env=dict(launch_args.env), log=log) diff --git a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py index b10bfcd6b0e05..e91dd480b7f8c 100644 --- a/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py +++ b/orttraining/orttraining/test/python/onnxruntime_test_postprocess.py @@ -21,22 +21,32 @@ torch.manual_seed(1) onnxruntime.set_seed(1) + class Test_PostPasses(unittest.TestCase): - def get_onnx_model(self, model, model_desc, inputs, device, - _enable_internal_postprocess=True, _extra_postprocess=None): - lr_desc = IODescription('Learning_Rate', [1,], torch.float32) - model = ORTTrainer(model, - None, - model_desc, - "LambOptimizer", - map_optimizer_attributes, - lr_desc, - device, - world_rank=0, - world_size=1, - _opset_version=14, - _enable_internal_postprocess=_enable_internal_postprocess, - _extra_postprocess=_extra_postprocess) + def get_onnx_model( + self, model, model_desc, inputs, device, _enable_internal_postprocess=True, _extra_postprocess=None + ): + lr_desc = IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ) + model = ORTTrainer( + model, + None, + model_desc, + "LambOptimizer", + map_optimizer_attributes, + lr_desc, + device, + world_rank=0, + world_size=1, + _opset_version=14, + _enable_internal_postprocess=_enable_internal_postprocess, + _extra_postprocess=_extra_postprocess, + ) train_output = model.train_step(*inputs) return model.onnx_model_ @@ -89,13 +99,13 @@ def forward(self, x): model = LayerNormNet(target) input = torch.randn(20, 5, 10, 10, dtype=torch.float32).to(device) - input_desc = IODescription('input', [], "float32") - output0_desc = IODescription('output0', [], "float32") - output1_desc = IODescription('output1', [20, 5, 10, 10], "float32") + input_desc = IODescription("input", [], "float32") + output0_desc = IODescription("output0", [], "float32") + output1_desc = IODescription("output1", [20, 5, 10, 10], "float32") model_desc = ModelDescription([input_desc], [output0_desc, output1_desc]) - learning_rate = torch.tensor([1.0000000e+00]).to(device) - input_args=[input, learning_rate] + learning_rate = torch.tensor([1.0000000e00]).to(device) + input_args = [input, learning_rate] onnx_model = self.get_onnx_model(model, model_desc, input_args, device) @@ -127,13 +137,13 @@ def forward(self, x, x1): x = torch.randn(5, 3, 1, 2, dtype=torch.float32).to(device) x1 = torch.randn(5, 3, 5, 2, dtype=torch.float32).to(device) - input0_desc = IODescription('x', [5, 3, 1, 2], "float32") - input1_desc = IODescription('x1', [5, 3, 5, 2], "float32") - output0_desc = IODescription('output0', [], "float32") - output1_desc = IODescription('output1', [5, 3, 5, 2], "float32") + input0_desc = IODescription("x", [5, 3, 1, 2], "float32") + input1_desc = IODescription("x1", [5, 3, 5, 2], "float32") + output0_desc = IODescription("output0", [], "float32") + output1_desc = IODescription("output1", [5, 3, 5, 2], "float32") model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - learning_rate = torch.tensor([1.0000000e+00]).to(device) + learning_rate = torch.tensor([1.0000000e00]).to(device) input_args = [x, x1, learning_rate] onnx_model = self.get_onnx_model(model, model_desc, input_args, device) @@ -150,49 +160,63 @@ def test_bert(self): device = torch.device("cpu") model_tester = BertModelTest.BertModelTester(self) - config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels = model_tester.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = model_tester.prepare_config_and_inputs() model = BertForPreTraining(config=config) model.eval() - loss, prediction_scores, seq_relationship_score = model(input_ids, - attention_mask=input_mask, - token_type_ids=token_type_ids, - masked_lm_labels=token_labels, - next_sentence_label=sequence_labels) - - model_desc = ModelDescription([model_tester.input_ids_desc, - model_tester.attention_mask_desc, - model_tester.token_type_ids_desc, - model_tester.masked_lm_labels_desc, - model_tester.next_sentence_label_desc], - [model_tester.loss_desc, - model_tester.prediction_scores_desc, - model_tester.seq_relationship_scores_desc]) + loss, prediction_scores, seq_relationship_score = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + masked_lm_labels=token_labels, + next_sentence_label=sequence_labels, + ) + + model_desc = ModelDescription( + [ + model_tester.input_ids_desc, + model_tester.attention_mask_desc, + model_tester.token_type_ids_desc, + model_tester.masked_lm_labels_desc, + model_tester.next_sentence_label_desc, + ], + [model_tester.loss_desc, model_tester.prediction_scores_desc, model_tester.seq_relationship_scores_desc], + ) from collections import namedtuple - MyArgs = namedtuple("MyArgs", - "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len") - args = MyArgs(local_rank=0, - world_size=1, - max_steps=100, - learning_rate=0.00001, - warmup_proportion=0.01, - batch_size=13, - seq_len=7) - + + MyArgs = namedtuple( + "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" + ) + args = MyArgs( + local_rank=0, + world_size=1, + max_steps=100, + learning_rate=0.00001, + warmup_proportion=0.01, + batch_size=13, + seq_len=7, + ) + dataset_len = 100 - dataloader = create_ort_test_dataloader(model_desc.inputs_, - args.batch_size, - args.seq_len, - dataset_len, - device) - learning_rate = torch.tensor(1.0e+0, dtype=torch.float32).to(device) + dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) + learning_rate = torch.tensor(1.0e0, dtype=torch.float32).to(device) for b in dataloader: batch = b break - learning_rate = torch.tensor([1.00e+00]).to(device) - inputs = batch + [learning_rate,] + learning_rate = torch.tensor([1.00e00]).to(device) + inputs = batch + [ + learning_rate, + ] onnx_model = self.get_onnx_model(model, model_desc, inputs, device, _extra_postprocess=postprocess_model) @@ -233,7 +257,7 @@ def postpass_replace_first_add_with_sub(model): # Sub # | # (subgraph 3) - add_nodes = [n for n in model.graph.node if n.op_type == 'Add'] + add_nodes = [n for n in model.graph.node if n.op_type == "Add"] add_nodes[0].op_type = "Sub" class MultiAdd(nn.Module): @@ -258,17 +282,18 @@ def forward(self, x, x1): x = torch.randn(5, 5, 2, dtype=torch.float32).to(device) x1 = torch.randn(5, 5, 2, dtype=torch.float32).to(device) - input0_desc = IODescription('x', [5, 5, 2], "float32") - input1_desc = IODescription('x1', [5, 5, 2], "float32") - output0_desc = IODescription('output0', [], "float32") - output1_desc = IODescription('output1', [5, 5, 2], "float32") + input0_desc = IODescription("x", [5, 5, 2], "float32") + input1_desc = IODescription("x1", [5, 5, 2], "float32") + output0_desc = IODescription("output0", [], "float32") + output1_desc = IODescription("output1", [5, 5, 2], "float32") model_desc = ModelDescription([input0_desc, input1_desc], [output0_desc, output1_desc]) - learning_rate = torch.tensor([1.0000000e+00]).to(device) + learning_rate = torch.tensor([1.0000000e00]).to(device) input_args = [x, x1, learning_rate] - onnx_model = self.get_onnx_model(model, model_desc, input_args, device, - _extra_postprocess=postpass_replace_first_add_with_sub) + onnx_model = self.get_onnx_model( + model, model_desc, input_args, device, _extra_postprocess=postpass_replace_first_add_with_sub + ) # check that extra postpass is called, and called only once. add_nodes = self.find_nodes(onnx_model, "Add") @@ -276,17 +301,22 @@ def forward(self, x, x1): assert len(add_nodes) == 2 assert len(sub_nodes) == 1 - - unprocessed_onnx_model = self.get_onnx_model(model, model_desc, input_args, device, - _extra_postprocess=None, _enable_internal_postprocess=False) + unprocessed_onnx_model = self.get_onnx_model( + model, model_desc, input_args, device, _extra_postprocess=None, _enable_internal_postprocess=False + ) # check that the model is unchanged. add_nodes = self.find_nodes(unprocessed_onnx_model, "Add") sub_nodes = self.find_nodes(unprocessed_onnx_model, "Sub") assert len(add_nodes) == 3 assert len(sub_nodes) == 0 - processed_onnx_model = self.get_onnx_model(unprocessed_onnx_model, model_desc, input_args, device, - _extra_postprocess=postpass_replace_first_add_with_sub) + processed_onnx_model = self.get_onnx_model( + unprocessed_onnx_model, + model_desc, + input_args, + device, + _extra_postprocess=postpass_replace_first_add_with_sub, + ) # check that extra postpass is called, and called only once. add_nodes = self.find_nodes(processed_onnx_model, "Add") sub_nodes = self.find_nodes(processed_onnx_model, "Sub") @@ -294,5 +324,5 @@ def forward(self, x, x1): assert len(sub_nodes) == 1 -if __name__ == '__main__': +if __name__ == "__main__": unittest.main(module=__name__, buffer=True) diff --git a/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py b/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py index 7b9d87dcc6512..5f71125cff413 100644 --- a/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py +++ b/orttraining/orttraining/test/python/onnxruntime_test_register_ep.py @@ -2,28 +2,28 @@ import onnxruntime_pybind11_state as C import os + class EPRegistrationTests(unittest.TestCase): - def get_test_execution_provider_path(self): - return os.path.join('.', 'libtest_execution_provider.so') - - def test_register_custom_eps(self): - C._register_provider_lib('TestExecutionProvider', self.get_test_execution_provider_path(), {'some_config':'val'}) - - assert 'TestExecutionProvider' in C.get_available_providers() - - this = os.path.dirname(__file__) - custom_op_model = os.path.join(this, "testdata", "custom_execution_provider_library", "test_model.onnx") - if not os.path.exists(custom_op_model): - raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model)) - - session_options = C.get_default_session_options() - sess = C.InferenceSession(session_options, custom_op_model, True, True) - sess.initialize_session(['TestExecutionProvider'], - [{'device_id':'0'}], - set()) - print("Created session with customize execution provider successfully!") - - -if __name__ == '__main__': - unittest.main() + def get_test_execution_provider_path(self): + return os.path.join(".", "libtest_execution_provider.so") + + def test_register_custom_eps(self): + C._register_provider_lib( + "TestExecutionProvider", self.get_test_execution_provider_path(), {"some_config": "val"} + ) + + assert "TestExecutionProvider" in C.get_available_providers() + + this = os.path.dirname(__file__) + custom_op_model = os.path.join(this, "testdata", "custom_execution_provider_library", "test_model.onnx") + if not os.path.exists(custom_op_model): + raise FileNotFoundError("Unable to find '{0}'".format(custom_op_model)) + + session_options = C.get_default_session_options() + sess = C.InferenceSession(session_options, custom_op_model, True, True) + sess.initialize_session(["TestExecutionProvider"], [{"device_id": "0"}], set()) + print("Created session with customize execution provider successfully!") + +if __name__ == "__main__": + unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_distributed_tests.py b/orttraining/orttraining/test/python/orttraining_distributed_tests.py index f4d2da6a98f94..f8e1dca370a1d 100644 --- a/orttraining/orttraining/test/python/orttraining_distributed_tests.py +++ b/orttraining/orttraining/test/python/orttraining_distributed_tests.py @@ -9,9 +9,7 @@ import logging -logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger("DistributedTests") @@ -22,28 +20,32 @@ def parse_arguments(): def run_checkpoint_tests(cwd, log): - log.debug('Running: Checkpoint tests') + log.debug("Running: Checkpoint tests") - command = [sys.executable, 'orttraining_test_checkpoint.py'] + command = [sys.executable, "orttraining_test_checkpoint.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() + def run_distributed_allreduce_tests(cwd, log): - log.debug('Running: distributed allreduce tests') + log.debug("Running: distributed allreduce tests") - command = [sys.executable, 'orttraining_test_allreduce.py'] + command = [sys.executable, "orttraining_test_allreduce.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() + def run_pipeline_parallel_tests(cwd, log): - log.debug('Running: pipeline parallel tests') + log.debug("Running: pipeline parallel tests") - command = [sys.executable, 'orttraining_test_dhp_parallel_tests.py'] + command = [sys.executable, "orttraining_test_dhp_parallel_tests.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() + def main(): import torch + ngpus = torch.cuda.device_count() if ngpus < 2: diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py index b03af9099aa8f..4f778444b88f0 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_distributed_tests.py @@ -9,9 +9,7 @@ import logging -logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger("ORTModuleDistributedTests") @@ -23,40 +21,56 @@ def parse_arguments(): def run_ortmodule_deepspeed_zero_stage_1_tests(cwd, log, data_dir): - log.debug('Running: ORTModule deepspeed zero stage 1 tests') + log.debug("Running: ORTModule deepspeed zero stage 1 tests") - command = ['deepspeed', 'orttraining_test_ortmodule_deepspeed_zero_stage_1.py', - '--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json'] + command = [ + "deepspeed", + "orttraining_test_ortmodule_deepspeed_zero_stage_1.py", + "--deepspeed_config", + "orttraining_test_ortmodule_deepspeed_zero_stage_1_config.json", + ] if data_dir: - command.extend(['--data-dir', data_dir]) + command.extend(["--data-dir", data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() + def run_pytorch_ddp_tests(cwd, log): - log.debug('Running: ORTModule Pytorch DDP tests') + log.debug("Running: ORTModule Pytorch DDP tests") - command = [sys.executable, 'orttraining_test_ortmodule_pytorch_ddp.py', '--use_ort_module'] + command = [sys.executable, "orttraining_test_ortmodule_pytorch_ddp.py", "--use_ort_module"] run_subprocess(command, cwd=cwd, log=log).check_returncode() + def run_ortmodule_deepspeed_pipeline_parallel_tests(cwd, log): - log.debug('Running: ORTModule deepspeed pipeline parallel tests') + log.debug("Running: ORTModule deepspeed pipeline parallel tests") - command = ['deepspeed', 'orttraining_test_ortmodule_deepspeed_pipeline_parallel.py', - '--deepspeed_config', 'orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json'] + command = [ + "deepspeed", + "orttraining_test_ortmodule_deepspeed_pipeline_parallel.py", + "--deepspeed_config", + "orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json", + ] run_subprocess(command, cwd=cwd, log=log).check_returncode() + def run_ortmodule_fairscale_sharded_optimizer_tests(cwd, log, data_dir): - log.debug('Running: ORTModule fairscale sharded optimizer tests') - command = ['python3', 'orttraining_test_ortmodule_fairscale_sharded_optimizer.py', - '--use_sharded_optimizer', '--use_ortmodule'] + log.debug("Running: ORTModule fairscale sharded optimizer tests") + command = [ + "python3", + "orttraining_test_ortmodule_fairscale_sharded_optimizer.py", + "--use_sharded_optimizer", + "--use_ortmodule", + ] if data_dir: - command.extend(['--data-dir', data_dir]) + command.extend(["--data-dir", data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() + def main(): args = parse_arguments() cwd = args.cwd diff --git a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py index 3e6dcc1cfe643..3b03829cc25e8 100644 --- a/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py +++ b/orttraining/orttraining/test/python/orttraining_ortmodule_tests.py @@ -9,9 +9,7 @@ import logging -logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger("ORTModuleTests") @@ -20,103 +18,111 @@ def parse_arguments(): parser.add_argument("--cwd", help="Path to the current working directory") parser.add_argument("--mnist", help="Path to the mnist data directory", type=str, default=None) parser.add_argument("--bert_data", help="Path to the bert data directory", type=str, default=None) - parser.add_argument("--transformers_cache", help="Path to the transformers model cache directory", type=str, default=None) + parser.add_argument( + "--transformers_cache", help="Path to the transformers model cache directory", type=str, default=None + ) return parser.parse_args() def get_env_with_transformers_cache(transformers_cache): - return {'TRANSFORMERS_CACHE': transformers_cache} if transformers_cache else {} + return {"TRANSFORMERS_CACHE": transformers_cache} if transformers_cache else {} def run_ortmodule_api_tests(cwd, log, transformers_cache): - log.debug('Running: ORTModule-API tests') + log.debug("Running: ORTModule-API tests") env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_api.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_api.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() def run_ortmodule_ops_tests(cwd, log, transformers_cache): - log.debug('Running: ORTModule-OPS tests') + log.debug("Running: ORTModule-OPS tests") env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_onnx_ops_ortmodule.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_onnx_ops_ortmodule.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() def run_ortmodule_fallback_tests(cwd, log, transformers_cache): - log.debug('Running: ORTModule-API tests') + log.debug("Running: ORTModule-API tests") env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_fallback.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_fallback.py"] run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() + def run_ortmodule_poc_net(cwd, log, no_cuda, data_dir): - log.debug('Running: ORTModule POCNet for MNIST with --no-cuda arg {}.'.format(no_cuda)) + log.debug("Running: ORTModule POCNet for MNIST with --no-cuda arg {}.".format(no_cuda)) - command = [sys.executable, 'orttraining_test_ortmodule_poc.py'] + command = [sys.executable, "orttraining_test_ortmodule_poc.py"] if no_cuda: - command.extend(['--no-cuda', '--epochs', str(3)]) + command.extend(["--no-cuda", "--epochs", str(3)]) if data_dir: - command.extend(['--data-dir', data_dir]) + command.extend(["--data-dir", data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_ortmodule_torch_lightning(cwd, log, data_dir): - log.debug('Running: ORTModule PyTorch Lightning sample .') + log.debug("Running: ORTModule PyTorch Lightning sample .") - command = [sys.executable, 'orttraining_test_ortmodule_torch_lightning_basic.py', '--train-steps=470', - '--epochs=2', '--batch-size=256'] + command = [ + sys.executable, + "orttraining_test_ortmodule_torch_lightning_basic.py", + "--train-steps=470", + "--epochs=2", + "--batch-size=256", + ] if data_dir: - command.extend(['--data-dir', data_dir]) + command.extend(["--data-dir", data_dir]) run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda, data_dir, transformers_cache): - log.debug('Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.'.format(no_cuda)) + log.debug("Running: ORTModule HuggingFace BERT for sequence classification with --no-cuda arg {}.".format(no_cuda)) env = get_env_with_transformers_cache(transformers_cache) - command = [sys.executable, 'orttraining_test_ortmodule_bert_classifier.py'] + command = [sys.executable, "orttraining_test_ortmodule_bert_classifier.py"] if no_cuda: - command.extend(['--no-cuda', '--epochs', str(3)]) + command.extend(["--no-cuda", "--epochs", str(3)]) if data_dir: - command.extend(['--data-dir', data_dir]) + command.extend(["--data-dir", data_dir]) run_subprocess(command, cwd=cwd, log=log, env=env).check_returncode() def run_ortmodule_custom_autograd_tests(cwd, log): - log.debug('Running: ORTModule-Custom AutoGrad Functions tests') + log.debug("Running: ORTModule-Custom AutoGrad Functions tests") - command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_autograd.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_autograd.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_ortmodule_hierarchical_ortmodule_tests(cwd, log): - log.debug('Running: ORTModule-Hierarchical model tests') + log.debug("Running: ORTModule-Hierarchical model tests") - command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_hierarchical_ortmodule.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_hierarchical_ortmodule.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_ortmodule_experimental_json_config_tests(cwd, log): - log.debug('Running: ORTModule Experimental Load Config tests') + log.debug("Running: ORTModule Experimental Load Config tests") - command = [sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_ortmodule_experimental_json_config.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_ortmodule_experimental_json_config.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -124,17 +130,15 @@ def run_ortmodule_experimental_json_config_tests(cwd, log): def run_experimental_gradient_graph_tests(cwd, log): log.debug("Running: Experimental Gradient Graph Export Tests") - command = [sys.executable, '-m', 'pytest', '-sv', - 'orttraining_test_experimental_gradient_graph.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_experimental_gradient_graph.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() def run_data_sampler_tests(cwd, log): - log.debug('Running: Data sampler tests') + log.debug("Running: Data sampler tests") - command = [sys.executable, '-m', 'pytest', - '-sv', 'orttraining_test_sampler.py'] + command = [sys.executable, "-m", "pytest", "-sv", "orttraining_test_sampler.py"] run_subprocess(command, cwd=cwd, log=log).check_returncode() @@ -153,11 +157,13 @@ def main(): run_ortmodule_poc_net(cwd, log, no_cuda=True, data_dir=args.mnist) - run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=False, - data_dir=args.bert_data, transformers_cache=args.transformers_cache) + run_ortmodule_hf_bert_for_sequence_classification_from_pretrained( + cwd, log, no_cuda=False, data_dir=args.bert_data, transformers_cache=args.transformers_cache + ) - run_ortmodule_hf_bert_for_sequence_classification_from_pretrained(cwd, log, no_cuda=True, - data_dir=args.bert_data, transformers_cache=args.transformers_cache) + run_ortmodule_hf_bert_for_sequence_classification_from_pretrained( + cwd, log, no_cuda=True, data_dir=args.bert_data, transformers_cache=args.transformers_cache + ) run_ortmodule_torch_lightning(cwd, log, args.mnist) diff --git a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py b/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py index 337fd624ec5b6..a087a97da5a54 100644 --- a/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py +++ b/orttraining/orttraining/test/python/orttraining_run_bert_pretrain.py @@ -45,11 +45,12 @@ # terminate the training before the pipeline run hit its timeout. force_to_stop_max_steps = 2500 -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) logger = logging.getLogger(__name__) + def get_rank(): if not dist.is_available(): return 0 @@ -57,27 +58,53 @@ def get_rank(): return 0 return dist.get_rank() + def is_main_process(args): - if hasattr(args, 'world_rank'): + if hasattr(args, "world_rank"): return args.world_rank in [-1, 0] else: return get_rank() == 0 + def bert_model_description(config): vocab_size = config.vocab_size new_model_desc = { - 'inputs': [ - ('input_ids', ['batch', 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', 'max_seq_len_in_batch'],), - ('token_type_ids', ['batch', 'max_seq_len_in_batch'],), - ('masked_lm_labels', ['batch', 'max_seq_len_in_batch'],), - ('next_sentence_label', ['batch', ],) - ], - 'outputs': [ - ('loss', [], True), - ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size],), - ('seq_relationship_scores', ['batch', 2],) - ]} + "inputs": [ + ( + "input_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "attention_mask", + ["batch", "max_seq_len_in_batch"], + ), + ( + "token_type_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "masked_lm_labels", + ["batch", "max_seq_len_in_batch"], + ), + ( + "next_sentence_label", + [ + "batch", + ], + ), + ], + "outputs": [ + ("loss", [], True), + ( + "prediction_scores", + ["batch", "max_seq_len_in_batch", vocab_size], + ), + ( + "seq_relationship_scores", + ["batch", 2], + ), + ], + } return new_model_desc @@ -85,32 +112,40 @@ def create_pretraining_dataset(input_file, max_pred_length, args): train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length) train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader(train_data, sampler=train_sampler, - batch_size=args.train_batch_size * args.n_gpu, num_workers=0, - pin_memory=True) + train_dataloader = DataLoader( + train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=0, pin_memory=True + ) return train_dataloader, input_file class pretraining_dataset(Dataset): - def __init__(self, input_file, max_pred_length): logger.info("pretraining_dataset: %s, max_pred_length: %d", input_file, max_pred_length) self.input_file = input_file self.max_pred_length = max_pred_length f = h5py.File(input_file, "r") - keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions', 'masked_lm_ids', - 'next_sentence_labels'] + keys = [ + "input_ids", + "input_mask", + "segment_ids", + "masked_lm_positions", + "masked_lm_ids", + "next_sentence_labels", + ] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() def __len__(self): - 'Denotes the total number of samples' + "Denotes the total number of samples" return len(self.inputs[0]) def __getitem__(self, index): [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( - np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)] + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) + ] # HF model use default ignore_index value (-100) for CrossEntropyLoss masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -100 @@ -120,50 +155,47 @@ def __getitem__(self, index): if len(padded_mask_indices) != 0: index = padded_mask_indices[0].item() masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, - masked_lm_labels, next_sentence_labels] + return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] + import argparse + + def parse_arguments(): parser = argparse.ArgumentParser() # batch size test config parameters - parser.add_argument("--enable_mixed_precision", - default=False, - action='store_true', - help="Whether to use 16-bit float precision instead of 32-bit") - - parser.add_argument("--sequence_length", - default=512, - type=int, - help="The maximum total input sequence length after WordPiece tokenization. \n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.") - parser.add_argument("--max_predictions_per_seq", - default=80, - type=int, - help="The maximum total of masked tokens in input sequence") - parser.add_argument("--max_batch_size", - default=32, - type=int, - help="Total batch size for training.") - - parser.add_argument("--gelu_recompute", - default=False, - action='store_true') - - parser.add_argument("--attn_dropout_recompute", - default=False, - action='store_true') - - parser.add_argument("--transformer_layer_recompute", - default=False, - action='store_true') + parser.add_argument( + "--enable_mixed_precision", + default=False, + action="store_true", + help="Whether to use 16-bit float precision instead of 32-bit", + ) + + parser.add_argument( + "--sequence_length", + default=512, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. \n" + "Sequences longer than this will be truncated, and sequences shorter \n" + "than this will be padded.", + ) + parser.add_argument( + "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence" + ) + parser.add_argument("--max_batch_size", default=32, type=int, help="Total batch size for training.") + + parser.add_argument("--gelu_recompute", default=False, action="store_true") + + parser.add_argument("--attn_dropout_recompute", default=False, action="store_true") + + parser.add_argument("--transformer_layer_recompute", default=False, action="store_true") args = parser.parse_args() return args + @dataclass class PretrainArguments: """ @@ -175,8 +207,11 @@ class PretrainArguments: ) bert_model: str = field( - default=None, metadata={"help": "Bert pre-trained model selected in the list: bert-base-uncased, \ - bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."} + default=None, + metadata={ + "help": "Bert pre-trained model selected in the list: bert-base-uncased, \ + bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese." + }, ) output_dir: str = field( @@ -184,155 +219,103 @@ class PretrainArguments: ) cache_dir: str = field( - default='/tmp/bert_pretrain/', - metadata={"help": "The output directory where the model checkpoints will be written."} + default="/tmp/bert_pretrain/", + metadata={"help": "The output directory where the model checkpoints will be written."}, ) max_seq_length: Optional[int] = field( default=512, - metadata={"help": "The maximum total input sequence length after tokenization. Sequences longer \ - than this will be truncated, sequences shorter will be padded."} + metadata={ + "help": "The maximum total input sequence length after tokenization. Sequences longer \ + than this will be truncated, sequences shorter will be padded." + }, ) max_predictions_per_seq: Optional[int] = field( - default=80, - metadata={"help": "The maximum total of masked tokens in input sequence."} + default=80, metadata={"help": "The maximum total of masked tokens in input sequence."} ) - train_batch_size: Optional[int] = field( - default=32, - metadata={"help": "Batch size for training."} - ) + train_batch_size: Optional[int] = field(default=32, metadata={"help": "Batch size for training."}) - learning_rate: Optional[float] = field( - default=5e-5, - metadata={"help": "The initial learning rate for Lamb."} - ) + learning_rate: Optional[float] = field(default=5e-5, metadata={"help": "The initial learning rate for Lamb."}) num_train_epochs: Optional[float] = field( - default=3.0, - metadata={"help": "Total number of training epochs to perform."} + default=3.0, metadata={"help": "Total number of training epochs to perform."} ) - max_steps: Optional[float] = field( - default=1000, - metadata={"help": "Total number of training steps to perform."} - ) + max_steps: Optional[float] = field(default=1000, metadata={"help": "Total number of training steps to perform."}) warmup_proportion: Optional[float] = field( default=0.01, - metadata={"help": "Proportion of training to perform linear learning rate warmup for. \ - E.g., 0.1 = 10%% of training."} + metadata={ + "help": "Proportion of training to perform linear learning rate warmup for. \ + E.g., 0.1 = 10%% of training." + }, ) - local_rank: Optional[int] = field( - default=-1, - metadata={"help": "local_rank for distributed training on gpus."} - ) + local_rank: Optional[int] = field(default=-1, metadata={"help": "local_rank for distributed training on gpus."}) - world_rank: Optional[int] = field( - default=-1 - ) + world_rank: Optional[int] = field(default=-1) - world_size: Optional[int] = field( - default=1 - ) + world_size: Optional[int] = field(default=1) - seed: Optional[int] = field( - default=42, - metadata={"help": "random seed for initialization."} - ) + seed: Optional[int] = field(default=42, metadata={"help": "random seed for initialization."}) gradient_accumulation_steps: Optional[int] = field( - default=1, - metadata={"help": "Number of updates steps to accumualte before performing a backward/update pass."} + default=1, metadata={"help": "Number of updates steps to accumualte before performing a backward/update pass."} ) - fp16: bool = field( - default=False, - metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."} - ) + fp16: bool = field(default=False, metadata={"help": "Whether to use 16-bit float precision instead of 32-bit."}) gelu_recompute: bool = field( - default=False, - metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."} + default=False, metadata={"help": "Whether to enable recomputing Gelu activation output to save memory."} ) attn_dropout_recompute: bool = field( - default=False, - metadata={"help": "Whether to enable recomputing attention dropout to save memory."} + default=False, metadata={"help": "Whether to enable recomputing attention dropout to save memory."} ) transformer_layer_recompute: bool = field( - default=False, - metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."} + default=False, metadata={"help": "Whether to enable recomputing transformer layerwise to save memory."} ) loss_scale: Optional[float] = field( - default=0.0, - metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."} + default=0.0, metadata={"help": "Loss scaling, positive power of 2 values can improve fp16 convergence."} ) - deepspeed_zero_stage: Optional[int] = field( - default=0, - metadata={"help": "Deepspeed Zero Stage. 0 => disabled"} - ) + deepspeed_zero_stage: Optional[int] = field(default=0, metadata={"help": "Deepspeed Zero Stage. 0 => disabled"}) - log_freq: Optional[float] = field( - default=1.0, - metadata={"help": "frequency of logging loss."} - ) + log_freq: Optional[float] = field(default=1.0, metadata={"help": "frequency of logging loss."}) - checkpoint_activations: bool = field( - default=False, - metadata={"help": "Whether to use gradient checkpointing."} - ) + checkpoint_activations: bool = field(default=False, metadata={"help": "Whether to use gradient checkpointing."}) resume_from_checkpoint: bool = field( - default=False, - metadata={"help": "Whether to resume training from checkpoint."} + default=False, metadata={"help": "Whether to resume training from checkpoint."} ) - resume_step: Optional[int] = field( - default=-1, - metadata={"help": "Step to resume training from."} - ) + resume_step: Optional[int] = field(default=-1, metadata={"help": "Step to resume training from."}) num_steps_per_checkpoint: Optional[int] = field( - default=100, - metadata={"help": "Number of update steps until a model checkpoint is saved to disk."} + default=100, metadata={"help": "Number of update steps until a model checkpoint is saved to disk."} ) save_checkpoint: Optional[bool] = field( - default=False, - metadata={"help": "Enable for saving a model checkpoint to disk."} + default=False, metadata={"help": "Enable for saving a model checkpoint to disk."} ) - init_state_dict: Optional[dict] = field( - default=None, - metadata={"help": "State to load before training."} - ) + init_state_dict: Optional[dict] = field(default=None, metadata={"help": "State to load before training."}) - phase2: bool = field( - default=False, - metadata={"help": "Whether to train with seq len 512."} - ) + phase2: bool = field(default=False, metadata={"help": "Whether to train with seq len 512."}) allreduce_post_accumulation: bool = field( - default=False, - metadata={"help": "Whether to do allreduces during gradient accumulation steps."} + default=False, metadata={"help": "Whether to do allreduces during gradient accumulation steps."} ) allreduce_post_accumulation_fp16: bool = field( - default=False, - metadata={"help": "Whether to do fp16 allreduce post accumulation."} + default=False, metadata={"help": "Whether to do fp16 allreduce post accumulation."} ) - accumulate_into_fp16: bool = field( - default=False, - metadata={"help": "Whether to use fp16 gradient accumulators."} - ) + accumulate_into_fp16: bool = field(default=False, metadata={"help": "Whether to use fp16 gradient accumulators."}) phase1_end_step: Optional[int] = field( - default=7038, - metadata={"help": "Whether to use fp16 gradient accumulators."} + default=7038, metadata={"help": "Whether to use fp16 gradient accumulators."} ) tensorboard_dir: Optional[str] = field( @@ -340,14 +323,13 @@ class PretrainArguments: ) schedule: Optional[str] = field( - default='warmup_poly', + default="warmup_poly", ) # this argument is test specific. to run a full bert model will take too long to run. instead, we reduce # number of hidden layers so that it can show convergence to an extend to help detect any regression. force_num_hidden_layers: Optional[int] = field( - default=None, - metadata={"help": "Whether to use fp16 gradient accumulators."} + default=None, metadata={"help": "Whether to use fp16 gradient accumulators."} ) def to_json_string(self): @@ -367,7 +349,7 @@ def to_sanitized_dict(self) -> Dict[str, Any]: def setup_training(args): - assert (torch.cuda.is_available()) + assert torch.cuda.is_available() if args.local_rank == -1: args.local_rank = 0 @@ -379,11 +361,15 @@ def setup_training(args): args.n_gpu = 1 if args.gradient_accumulation_steps < 1: - raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( - args.gradient_accumulation_steps)) + raise ValueError( + "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(args.gradient_accumulation_steps) + ) if args.train_batch_size % args.gradient_accumulation_steps != 0: - raise ValueError("Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( - args.gradient_accumulation_steps, args.train_batch_size)) + raise ValueError( + "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( + args.gradient_accumulation_steps, args.train_batch_size + ) + ) # args.train_batch_size is per global step (optimization step) batch size # now make it a per gpu batch size @@ -393,15 +379,16 @@ def setup_training(args): logger.info("setup_training: args.train_batch_size = %d", args.train_batch_size) return device, args + def setup_torch_distributed(world_rank, world_size): - os.environ['RANK'] = str(world_rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = str('12345') - torch.distributed.init_process_group(backend='nccl', world_size=world_size, - rank=world_rank) + os.environ["RANK"] = str(world_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str("12345") + torch.distributed.init_process_group(backend="nccl", world_size=world_size, rank=world_rank) return + def prepare_model(args, device): config = BertConfig.from_pretrained(args.bert_model, cache_dir=args.cache_dir) @@ -419,43 +406,57 @@ def prepare_model(args, device): loss_scaler = amp.DynamicLossScaler() if args.fp16 else None - options = orttrainer.ORTTrainerOptions({'batch': { - 'gradient_accumulation_steps': args.gradient_accumulation_steps}, - 'device': {'id': str(device)}, - 'mixed_precision': { - 'enabled': args.fp16, - 'loss_scaler': loss_scaler}, - 'graph_transformer': { - 'attn_dropout_recompute': args.attn_dropout_recompute, - 'gelu_recompute': args.gelu_recompute, - 'transformer_layer_recompute': args.transformer_layer_recompute, - }, - 'debug': {'deterministic_compute': True, }, - 'utils': { - 'grad_norm_clip': True}, - 'distributed': { - 'world_rank': max(0, args.local_rank), - 'world_size': args.world_size, - 'local_rank': max(0, args.local_rank), - 'allreduce_post_accumulation': args.allreduce_post_accumulation, - 'deepspeed_zero_optimization': {'stage': args.deepspeed_zero_stage}, - 'enable_adasum': False}, - 'lr_scheduler': lr_scheduler - }) + options = orttrainer.ORTTrainerOptions( + { + "batch": {"gradient_accumulation_steps": args.gradient_accumulation_steps}, + "device": {"id": str(device)}, + "mixed_precision": {"enabled": args.fp16, "loss_scaler": loss_scaler}, + "graph_transformer": { + "attn_dropout_recompute": args.attn_dropout_recompute, + "gelu_recompute": args.gelu_recompute, + "transformer_layer_recompute": args.transformer_layer_recompute, + }, + "debug": { + "deterministic_compute": True, + }, + "utils": {"grad_norm_clip": True}, + "distributed": { + "world_rank": max(0, args.local_rank), + "world_size": args.world_size, + "local_rank": max(0, args.local_rank), + "allreduce_post_accumulation": args.allreduce_post_accumulation, + "deepspeed_zero_optimization": {"stage": args.deepspeed_zero_stage}, + "enable_adasum": False, + }, + "lr_scheduler": lr_scheduler, + } + ) param_optimizer = list(model.named_parameters()) no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] - params = [{ - 'params': [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, { - 'params': [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)], - "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}] + params = [ + { + "params": [n for n, p in param_optimizer if any(no_decay_key in n for no_decay_key in no_decay_keys)], + "alpha": 0.9, + "beta": 0.999, + "lambda": 0.0, + "epsilon": 1e-6, + }, + { + "params": [n for n, p in param_optimizer if not any(no_decay_key in n for no_decay_key in no_decay_keys)], + "alpha": 0.9, + "beta": 0.999, + "lambda": 0.0, + "epsilon": 1e-6, + }, + ] optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) model = orttrainer.ORTTrainer(model, model_desc, optim_config, options=options) return model + def get_data_file(f_id, world_rank, world_size, files): num_files = len(files) if world_size > num_files: @@ -500,25 +501,27 @@ def do_pretrain(args): pool = ProcessPoolExecutor(1) while True: files = [ - os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) - if os.path.isfile(os.path.join(args.input_dir, f)) and 'training' in f] + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if os.path.isfile(os.path.join(args.input_dir, f)) and "training" in f + ] files.sort() random.shuffle(files) f_id = 0 train_dataloader, data_file = create_pretraining_dataset( - get_data_file(f_id, args.world_rank, args.world_size, files), - args.max_predictions_per_seq, - args) + get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, args + ) - for f_id in range(1 , len(files)): + for f_id in range(1, len(files)): logger.info("data file %s" % (data_file)) dataset_future = pool.submit( create_pretraining_dataset, get_data_file(f_id, args.world_rank, args.world_size, files), args.max_predictions_per_seq, - args) + args, + ) train_iter = tqdm(train_dataloader, desc="Iteration") if is_main_process(args) else train_dataloader for step, batch in enumerate(train_iter): @@ -526,7 +529,9 @@ def do_pretrain(args): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch - loss, _, _ = model.train_step(input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels) + loss, _, _ = model.train_step( + input_ids, input_mask, segment_ids, masked_lm_labels, next_sentence_labels + ) average_loss += loss.item() global_step = model._train_step_info.optimization_step @@ -535,13 +540,13 @@ def do_pretrain(args): divisor = args.log_freq * args.gradient_accumulation_steps if tb_writer: lr = model.options.lr_scheduler.get_last_lr()[0] - tb_writer.add_scalar('train/summary/scalar/Learning_Rate', lr, global_step) + tb_writer.add_scalar("train/summary/scalar/Learning_Rate", lr, global_step) if args.fp16: - tb_writer.add_scalar('train/summary/scalar/loss_scale_25', loss, global_step) + tb_writer.add_scalar("train/summary/scalar/loss_scale_25", loss, global_step) # TODO: ORTTrainer to expose all_finite # tb_writer.add_scalar('train/summary/scalar/all_fp16_gradients_finite_859', all_finite, global_step) - tb_writer.add_scalar('train/summary/total_loss', average_loss / divisor, global_step) - + tb_writer.add_scalar("train/summary/total_loss", average_loss / divisor, global_step) + print("Step:{} Average Loss = {}".format(global_step, average_loss / divisor)) if global_step >= args.max_steps or global_step >= force_to_stop_max_steps: @@ -550,7 +555,9 @@ def do_pretrain(args): if global_step >= args.max_steps: if args.save_checkpoint: - model.save_checkpoint(os.path.join(args.output_dir, 'checkpoint-{}.ortcp'.format(args.world_rank))) + model.save_checkpoint( + os.path.join(args.output_dir, "checkpoint-{}.ortcp".format(args.world_rank)) + ) final_loss = average_loss / (args.log_freq * args.gradient_accumulation_steps) return final_loss @@ -563,17 +570,17 @@ def do_pretrain(args): epoch += 1 -def generate_tensorboard_logdir(root_dir): +def generate_tensorboard_logdir(root_dir): current_date_time = datetime.datetime.today() - dt_string = current_date_time.strftime('BERT_pretrain_%y_%m_%d_%I_%M_%S') + dt_string = current_date_time.strftime("BERT_pretrain_%y_%m_%d_%I_%M_%S") return os.path.join(root_dir, dt_string) class ORTBertPretrainTest(unittest.TestCase): def setUp(self): - self.output_dir = '/bert_data/hf_data/test_out/bert_pretrain_results' - self.bert_model = 'bert-base-uncased' + self.output_dir = "/bert_data/hf_data/test_out/bert_pretrain_results" + self.bert_model = "bert-base-uncased" self.local_rank = -1 self.world_rank = -1 self.world_size = 1 @@ -581,18 +588,18 @@ def setUp(self): self.learning_rate = 5e-4 self.max_seq_length = 512 self.max_predictions_per_seq = 20 - self.input_dir = '/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train' + self.input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" self.train_batch_size = 4096 self.gradient_accumulation_steps = 64 self.fp16 = True self.allreduce_post_accumulation = True - self.tensorboard_dir = '/bert_data/hf_data/test_out' + self.tensorboard_dir = "/bert_data/hf_data/test_out" def test_pretrain_throughput(self, process_args=None): if process_args.sequence_length == 128: - input_dir = '/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train' + input_dir = "/bert_data/hdf5_lower_case_1_seq_len_128_max_pred_20_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" else: - input_dir = '/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train' + input_dir = "/bert_data/hdf5_lower_case_1_seq_len_512_max_pred_80_masked_lm_prob_0.15_random_seed_12345_dupe_factor_5/books_wiki_en_corpus/train" print("process_args.enable_mixed_precision: ", process_args.enable_mixed_precision) print("process_args.sequence_length: ", process_args.sequence_length) @@ -604,8 +611,8 @@ def test_pretrain_throughput(self, process_args=None): args = PretrainArguments( input_dir=input_dir, - output_dir='/bert_data/hf_data/test_out/bert_pretrain_results', - bert_model='bert-large-uncased', + output_dir="/bert_data/hf_data/test_out/bert_pretrain_results", + bert_model="bert-large-uncased", local_rank=self.local_rank, world_rank=self.world_rank, world_size=self.world_size, @@ -642,31 +649,32 @@ def test_pretrain_convergence(self): fp16=self.fp16, allreduce_post_accumulation=self.allreduce_post_accumulation, force_num_hidden_layers=self.force_num_hidden_layers, - tensorboard_dir=generate_tensorboard_logdir('/bert_data/hf_data/test_out/')) + tensorboard_dir=generate_tensorboard_logdir("/bert_data/hf_data/test_out/"), + ) final_loss = do_pretrain(args) return final_loss - + def test_pretrain_zero(self): - assert self.world_size >0, "ZeRO test requires a distributed run." + assert self.world_size > 0, "ZeRO test requires a distributed run." setup_torch_distributed(self.world_rank, self.world_size) per_gpu_batch_size = 32 - optimization_batch_size = per_gpu_batch_size*self.world_size # set to disable grad accumulation - + optimization_batch_size = per_gpu_batch_size * self.world_size # set to disable grad accumulation + self.train_batch_size = optimization_batch_size self.gradient_accumulation_steps = 1 self.deepspeed_zero_stage = 1 self.force_num_hidden_layers = 2 self.max_seq_length = 32 - self.output_dir = './bert_pretrain_ckpt' - if self.world_rank == 0: + self.output_dir = "./bert_pretrain_ckpt" + if self.world_rank == 0: if os.path.isdir(self.output_dir): shutil.rmtree(self.output_dir) - os.makedirs(self.output_dir, exist_ok = True) - + os.makedirs(self.output_dir, exist_ok=True) + torch.distributed.barrier() - assert os.path.exists(self.output_dir) - + assert os.path.exists(self.output_dir) + # run a few optimization steps self.max_steps = 200 args = PretrainArguments( @@ -686,7 +694,8 @@ def test_pretrain_zero(self): allreduce_post_accumulation=self.allreduce_post_accumulation, force_num_hidden_layers=self.force_num_hidden_layers, deepspeed_zero_stage=self.deepspeed_zero_stage, - save_checkpoint=True) + save_checkpoint=True, + ) train_loss = do_pretrain(args) # ensure all workers reach this point before loading the checkpointed state @@ -694,7 +703,7 @@ def test_pretrain_zero(self): # on rank 0, load the trained state if args.world_rank == 0: - checkpoint_files = glob.glob(os.path.join(self.output_dir, 'checkpoint*.ortcp')) + checkpoint_files = glob.glob(os.path.join(self.output_dir, "checkpoint*.ortcp")) args.init_state_dict = aggregate_checkpoints(checkpoint_files, pytorch_format=True) torch.distributed.barrier() @@ -709,10 +718,11 @@ def test_pretrain_zero(self): if __name__ == "__main__": import sys + logger.warning("sys.argv: %s", sys.argv) # usage: # data parallel training - # mpirun -n 4 python orttraining_run_bert_pretrain.py + # mpirun -n 4 python orttraining_run_bert_pretrain.py # # single gpu: # python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_throughput @@ -723,19 +733,24 @@ def test_pretrain_zero(self): # calling unpublished get_mpi_context_xxx to get rank/size numbers. try: # In case ORT is not built with MPI/NCCL, there are no get_mpi_context_xxx internal apis. - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size,\ - get_mpi_context_world_rank, get_mpi_context_world_size + from onnxruntime.capi._pybind_state import ( + get_mpi_context_local_rank, + get_mpi_context_local_size, + get_mpi_context_world_rank, + get_mpi_context_world_size, + ) + has_get_mpi_context_internal_api = True except ImportError: has_get_mpi_context_internal_api = False pass if has_get_mpi_context_internal_api and get_mpi_context_world_size() > 1: world_size = get_mpi_context_world_size() - print('get_mpi_context_world_size(): ', world_size) + print("get_mpi_context_world_size(): ", world_size) local_rank = get_mpi_context_local_rank() if local_rank == 0: - print('================================================================> os.getpid() = ', os.getpid()) + print("================================================================> os.getpid() = ", os.getpid()) test = ORTBertPretrainTest() test.setUp() @@ -743,7 +758,7 @@ def test_pretrain_zero(self): test.world_rank = local_rank test.world_size = world_size - if len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_zero': + if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_zero": logger.info("running ORTBertPretrainTest.test_pretrain_zero()...") final_loss = test.test_pretrain_zero() logger.info("ORTBertPretrainTest.test_pretrain_zero() rank = %i final loss = %f", local_rank, final_loss) @@ -752,7 +767,7 @@ def test_pretrain_zero(self): else: test.assertGreater(final_loss, 11.0) logger.info("ORTBertPretrainTest.test_pretrain_zero() passed") - elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_convergence': + elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": logger.info("running ORTBertPretrainTest.test_pretrain_convergence()...") test.max_steps = 200 test.force_num_hidden_layers = 8 @@ -782,12 +797,12 @@ def test_pretrain_zero(self): else: # unittest does not accept user defined arguments # we need to run this script with user defined arguments - if len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_throughput': + if len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_throughput": run_test_pretrain_throughput, run_test_pretrain_convergence = True, False - sys.argv.remove('ORTBertPretrainTest.test_pretrain_throughput') - elif len(sys.argv) >= 2 and sys.argv[1] == 'ORTBertPretrainTest.test_pretrain_convergence': + sys.argv.remove("ORTBertPretrainTest.test_pretrain_throughput") + elif len(sys.argv) >= 2 and sys.argv[1] == "ORTBertPretrainTest.test_pretrain_convergence": run_test_pretrain_throughput, run_test_pretrain_convergence = False, True - sys.argv.remove('ORTBertPretrainTest.test_pretrain_convergence') + sys.argv.remove("ORTBertPretrainTest.test_pretrain_convergence") else: run_test_pretrain_throughput, run_test_pretrain_convergence = True, True process_args = parse_arguments() diff --git a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py b/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py index b49fc3c95cc6d..db03a636d046e 100644 --- a/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py +++ b/orttraining/orttraining/test/python/orttraining_run_frontend_batch_size_test.py @@ -11,7 +11,9 @@ "max_predictions_per_seq", "gelu_recompute", "attn_dropout_recompute", - "transformer_layer_recompute"]) + "transformer_layer_recompute", + ], +) configs = [ Config(True, 128, 46, 20, False, False, False), @@ -26,20 +28,28 @@ Config(True, 512, 15, 80, False, False, True), ] + def run_with_config(config): - print("##### testing name - {}-{} #####".format("fp16" if config.enable_mixed_precision else "fp32", - config.sequence_length)) + print( + "##### testing name - {}-{} #####".format( + "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length + ) + ) print("gelu_recompute: ", config.gelu_recompute) print("attn_dropout_recompute: ", config.attn_dropout_recompute) print("transformer_layer_recompute: ", config.transformer_layer_recompute) cmds = [ sys.executable, - 'orttraining_run_bert_pretrain.py', + "orttraining_run_bert_pretrain.py", "ORTBertPretrainTest.test_pretrain_throughput", - "--sequence_length", str(config.sequence_length), - "--max_batch_size", str(config.max_batch_size), - "--max_predictions_per_seq", str(config.max_predictions_per_seq)] + "--sequence_length", + str(config.sequence_length), + "--max_batch_size", + str(config.max_batch_size), + "--max_predictions_per_seq", + str(config.max_predictions_per_seq), + ] if config.enable_mixed_precision: cmds.append("--enable_mixed_precision") if config.gelu_recompute: @@ -52,7 +62,6 @@ def run_with_config(config): # access to azure storage shared disk is much slower so we need a longer timeout. subprocess.run(cmds, timeout=1200).check_returncode() + for config in configs: run_with_config(config) - - diff --git a/orttraining/orttraining/test/python/orttraining_run_glue.py b/orttraining/orttraining/test/python/orttraining_run_glue.py index 3ab3c1f9409dc..a9b514599fb78 100644 --- a/orttraining/orttraining/test/python/orttraining_run_glue.py +++ b/orttraining/orttraining/test/python/orttraining_run_glue.py @@ -27,8 +27,13 @@ from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription try: - from onnxruntime.capi._pybind_state import get_mpi_context_local_rank, get_mpi_context_local_size,\ - get_mpi_context_world_rank, get_mpi_context_world_size + from onnxruntime.capi._pybind_state import ( + get_mpi_context_local_rank, + get_mpi_context_local_size, + get_mpi_context_world_rank, + get_mpi_context_world_size, + ) + has_get_mpi_context_internal_api = True except ImportError: has_get_mpi_context_internal_api = False @@ -41,21 +46,21 @@ logger = logging.getLogger(__name__) + def verify_old_and_new_api_are_equal(results_per_api): new_api_results = results_per_api[True] old_api_results = results_per_api[False] for key in new_api_results.keys(): assert_allclose(new_api_results[key], old_api_results[key]) + @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. """ - model_name_or_path: str = field( - metadata={"help": "model identifier from huggingface.co/models"} - ) + model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -66,8 +71,8 @@ class ModelArguments: default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) -class ORTGlueTest(unittest.TestCase): +class ORTGlueTest(unittest.TestCase): def setUp(self): # configurations not to be changed accoss tests self.max_seq_length = 128 @@ -80,7 +85,7 @@ def setUp(self): self.gradient_accumulation_steps = 1 self.data_dir = "/bert_data/hf_data/glue_data/" self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "glue_test_output/") - self.cache_dir = '/tmp/glue/' + self.cache_dir = "/tmp/glue/" self.logging_steps = 10 def test_roberta_with_mrpc(self): @@ -89,9 +94,9 @@ def test_roberta_with_mrpc(self): expected_loss = 0.35 results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=False) - assert(results['acc'] >= expected_acc) - assert(results['f1'] >= expected_f1) - assert(results['loss'] <= expected_loss) + assert results["acc"] >= expected_acc + assert results["f1"] >= expected_f1 + assert results["loss"] <= expected_loss def test_roberta_fp16_with_mrpc(self): expected_acc = 0.87 @@ -100,9 +105,9 @@ def test_roberta_fp16_with_mrpc(self): results = self.run_glue(model_name="roberta-base", task_name="MRPC", fp16=True) - assert(results['acc'] >= expected_acc) - assert(results['f1'] >= expected_f1) - assert(results['loss'] <= expected_loss) + assert results["acc"] >= expected_acc + assert results["f1"] >= expected_f1 + assert results["loss"] <= expected_loss def test_bert_with_mrpc(self): if self.local_rank == -1: @@ -117,9 +122,9 @@ def test_bert_with_mrpc(self): results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=False) if self.local_rank in [-1, 0]: - assert(results['acc'] >= expected_acc) - assert(results['f1'] >= expected_f1) - assert(results['loss'] <= expected_loss) + assert results["acc"] >= expected_acc + assert results["f1"] >= expected_f1 + assert results["loss"] <= expected_loss def test_bert_fp16_with_mrpc(self): expected_acc = 0.84 @@ -128,28 +133,55 @@ def test_bert_fp16_with_mrpc(self): results = self.run_glue(model_name="bert-base-cased", task_name="MRPC", fp16=True) - assert(results['acc'] >= expected_acc) - assert(results['f1'] >= expected_f1) - assert(results['loss'] <= expected_loss) + assert results["acc"] >= expected_acc + assert results["f1"] >= expected_f1 + assert results["loss"] <= expected_loss def model_to_desc(self, model_name, model): - if model_name.startswith('bert') or model_name.startswith('xlnet'): + if model_name.startswith("bert") or model_name.startswith("xlnet"): model_desc = { - 'inputs': [ - ('input_ids', ['batch', 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', 'max_seq_len_in_batch'],), - ('token_type_ids', ['batch', 'max_seq_len_in_batch'],), - ('labels', ['batch', ],)], - 'outputs': [('loss', [], True), - ('logits', ['batch', 2])]} - elif model_name.startswith('roberta'): + "inputs": [ + ( + "input_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "attention_mask", + ["batch", "max_seq_len_in_batch"], + ), + ( + "token_type_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "labels", + [ + "batch", + ], + ), + ], + "outputs": [("loss", [], True), ("logits", ["batch", 2])], + } + elif model_name.startswith("roberta"): model_desc = { - 'inputs': [ - ('input_ids', ['batch', 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', 'max_seq_len_in_batch'],), - ('labels', ['batch', ],)], - 'outputs': [('loss', [], True), - ('logits', ['batch', 2])]} + "inputs": [ + ( + "input_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "attention_mask", + ["batch", "max_seq_len_in_batch"], + ), + ( + "labels", + [ + "batch", + ], + ), + ], + "outputs": [("loss", [], True), ("logits", ["batch", 2])], + } else: raise RuntimeError("unsupported base model name {}.".format(model_name)) @@ -158,16 +190,22 @@ def model_to_desc(self, model_name, model): def run_glue(self, model_name, task_name, fp16): model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) data_args = GlueDataTrainingArguments( - task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), - max_seq_length=self.max_seq_length) + task_name=task_name, data_dir=os.path.join(self.data_dir, task_name), max_seq_length=self.max_seq_length + ) training_args = TrainingArguments( - output_dir=os.path.join(self.output_dir, task_name), do_train=True, do_eval=True, + output_dir=os.path.join(self.output_dir, task_name), + do_train=True, + do_eval=True, per_gpu_train_batch_size=self.train_batch_size, - learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs, + learning_rate=self.learning_rate, + num_train_epochs=self.num_train_epochs, local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, logging_steps=self.logging_steps) + overwrite_output_dir=self.overwrite_output_dir, + gradient_accumulation_steps=self.gradient_accumulation_steps, + fp16=fp16, + logging_steps=self.logging_steps, + ) # Setup logging logging.basicConfig( @@ -212,17 +250,9 @@ def run_glue(self, model_name, task_name, fp16): cache_dir=model_args.cache_dir, ) - train_dataset = ( - GlueDataset(data_args, tokenizer=tokenizer) - if training_args.do_train - else None - ) + train_dataset = GlueDataset(data_args, tokenizer=tokenizer) if training_args.do_train else None - eval_dataset = ( - GlueDataset(data_args, tokenizer=tokenizer, mode="dev") - if training_args.do_eval - else None - ) + eval_dataset = GlueDataset(data_args, tokenizer=tokenizer, mode="dev") if training_args.do_eval else None def compute_metrics(p: EvalPrediction) -> Dict: if output_mode == "classification": @@ -257,12 +287,13 @@ def compute_metrics(p: EvalPrediction) -> Dict: logger.info("***** Eval results {} *****".format(data_args.task_name)) for key, value in result.items(): - logger.info(" %s = %s", key, value) + logger.info(" %s = %s", key, value) results.update(result) return results + if __name__ == "__main__": if has_get_mpi_context_internal_api: local_rank = get_mpi_context_local_rank() @@ -278,12 +309,13 @@ def compute_metrics(p: EvalPrediction) -> Dict: # TrainingArguments._setup_devices will call torch.distributed.init_process_group(backend="nccl") # pytorch expects following environment settings (which would be set if launched with torch.distributed.launch). - os.environ['RANK'] = str(local_rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29500' + os.environ["RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" from onnxruntime.capi._pybind_state import set_cuda_device_id + set_cuda_device_id(local_rank) test = ORTGlueTest() diff --git a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py index 256f4a59ed1d8..a4c069c683e1c 100644 --- a/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py +++ b/orttraining/orttraining/test/python/orttraining_run_multiple_choice.py @@ -33,18 +33,18 @@ logger = logging.getLogger(__name__) + def simple_accuracy(preds, labels): return (preds == labels).mean() + @dataclass class ModelArguments: """ Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. """ - model_name_or_path: str = field( - metadata={"help": "model identifier from huggingface.co/models"} - ) + model_name_or_path: str = field(metadata={"help": "model identifier from huggingface.co/models"}) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -55,6 +55,7 @@ class ModelArguments: default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"} ) + @dataclass class DataTrainingArguments: """ @@ -70,12 +71,10 @@ class DataTrainingArguments: "than this will be truncated, sequences shorter will be padded." }, ) - overwrite_cache: bool = field( - default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} - ) + overwrite_cache: bool = field(default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}) -class ORTMultipleChoiceTest(unittest.TestCase): +class ORTMultipleChoiceTest(unittest.TestCase): def setUp(self): # configurations not to be changed accoss tests self.max_seq_length = 80 @@ -88,7 +87,7 @@ def setUp(self): self.gradient_accumulation_steps = 8 self.data_dir = "/bert_data/hf_data/swag/swagaf/data" self.output_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "multiple_choice_test_output/") - self.cache_dir = '/tmp/multiple_choice/' + self.cache_dir = "/tmp/multiple_choice/" self.logging_steps = 10 self.rtol = 2e-01 @@ -97,8 +96,8 @@ def test_bert_with_swag(self): expected_loss = 0.64 results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=False) - assert(results['acc'] >= expected_acc) - assert(results['loss'] <= expected_loss) + assert results["acc"] >= expected_acc + assert results["loss"] <= expected_loss def test_bert_fp16_with_swag(self): # larger batch can be handled with mixed precision @@ -108,20 +107,29 @@ def test_bert_fp16_with_swag(self): expected_loss = 0.68 results = self.run_multiple_choice(model_name="bert-base-cased", task_name="swag", fp16=True) - assert(results['acc'] >= expected_acc) - assert(results['loss'] <= expected_loss) + assert results["acc"] >= expected_acc + assert results["loss"] <= expected_loss def run_multiple_choice(self, model_name, task_name, fp16): model_args = ModelArguments(model_name_or_path=model_name, cache_dir=self.cache_dir) - data_args = DataTrainingArguments(task_name=task_name, data_dir=self.data_dir, - max_seq_length=self.max_seq_length) + data_args = DataTrainingArguments( + task_name=task_name, data_dir=self.data_dir, max_seq_length=self.max_seq_length + ) - training_args = TrainingArguments(output_dir=os.path.join(self.output_dir, task_name), do_train=True, do_eval=True, + training_args = TrainingArguments( + output_dir=os.path.join(self.output_dir, task_name), + do_train=True, + do_eval=True, per_gpu_train_batch_size=self.train_batch_size, per_gpu_eval_batch_size=self.eval_batch_size, - learning_rate=self.learning_rate, num_train_epochs=self.num_train_epochs,local_rank=self.local_rank, - overwrite_output_dir=self.overwrite_output_dir, gradient_accumulation_steps=self.gradient_accumulation_steps, - fp16=fp16, logging_steps=self.logging_steps) + learning_rate=self.learning_rate, + num_train_epochs=self.num_train_epochs, + local_rank=self.local_rank, + overwrite_output_dir=self.overwrite_output_dir, + gradient_accumulation_steps=self.gradient_accumulation_steps, + fp16=fp16, + logging_steps=self.logging_steps, + ) # Setup logging logging.basicConfig( @@ -200,23 +208,46 @@ def compute_metrics(p: EvalPrediction) -> Dict: preds = np.argmax(p.predictions, axis=1) return {"acc": simple_accuracy(preds, p.label_ids)} - if model_name.startswith('bert'): + if model_name.startswith("bert"): model_desc = { - 'inputs': [ - ('input_ids', ['batch', num_labels, 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', num_labels, 'max_seq_len_in_batch'],), - ('token_type_ids', ['batch', num_labels, 'max_seq_len_in_batch'],), - ('labels', ['batch', num_labels],)], - 'outputs': [('loss', [], True), - ('reshaped_logits', ['batch', num_labels])]} + "inputs": [ + ( + "input_ids", + ["batch", num_labels, "max_seq_len_in_batch"], + ), + ( + "attention_mask", + ["batch", num_labels, "max_seq_len_in_batch"], + ), + ( + "token_type_ids", + ["batch", num_labels, "max_seq_len_in_batch"], + ), + ( + "labels", + ["batch", num_labels], + ), + ], + "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], + } else: model_desc = { - 'inputs': [ - ('input_ids', ['batch', num_labels, 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', num_labels, 'max_seq_len_in_batch'],), - ('labels', ['batch', num_labels],)], - 'outputs': [('loss', [], True), - ('reshaped_logits', ['batch', num_labels])]} + "inputs": [ + ( + "input_ids", + ["batch", num_labels, "max_seq_len_in_batch"], + ), + ( + "attention_mask", + ["batch", num_labels, "max_seq_len_in_batch"], + ), + ( + "labels", + ["batch", num_labels], + ), + ], + "outputs": [("loss", [], True), ("reshaped_logits", ["batch", num_labels])], + } # Initialize the ORTTrainer within ORTTransformerTrainer trainer = ORTTransformerTrainer( @@ -242,11 +273,12 @@ def compute_metrics(p: EvalPrediction) -> Dict: logger.info("***** Eval results {} *****".format(data_args.task_name)) for key, value in result.items(): - logger.info(" %s = %s", key, value) + logger.info(" %s = %s", key, value) results.update(result) return results + if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_allreduce_adasum.py b/orttraining/orttraining/test/python/orttraining_test_allreduce_adasum.py index 75abc9f2aacbc..aa966b31a8d19 100644 --- a/orttraining/orttraining/test/python/orttraining_test_allreduce_adasum.py +++ b/orttraining/orttraining/test/python/orttraining_test_allreduce_adasum.py @@ -12,6 +12,7 @@ from _test_commons import _load_pytorch_transformer_model from onnxruntime import set_seed + def _run_adasum_tests(opts): # Common setup seed = 42 @@ -26,58 +27,59 @@ def _run_adasum_tests(opts): result = trainer.train_step(data, targets) assert result is not None + def test_single_precision_adasum_on_gpu(): # Common setup world_rank = get_mpi_context_world_rank() world_size = get_mpi_context_world_size() set_cuda_device_id(world_rank) - device = 'cuda:' + str(world_rank) - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device' : { - 'id' : device, - }, - 'distributed': { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'enable_adasum': True, - } - - }) + device = "cuda:" + str(world_rank) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "enable_adasum": True, + }, + } + ) _run_adasum_tests(opts) + def test_half_precision_adasum_on_gpu(): # Common setup world_rank = get_mpi_context_world_rank() world_size = get_mpi_context_world_size() set_cuda_device_id(world_rank) - device = 'cuda:' + str(world_rank) - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device' : { - 'id' : device, - }, - 'mixed_precision': { - 'enabled': True - }, - 'distributed': { - 'world_rank' : world_rank, - 'world_size' : world_size, - 'enable_adasum': True, - } - - }) + device = "cuda:" + str(world_rank) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "mixed_precision": {"enabled": True}, + "distributed": { + "world_rank": world_rank, + "world_size": world_size, + "enable_adasum": True, + }, + } + ) _run_adasum_tests(opts) + function_map = { - 'test_single_precision_adasum_on_gpu': test_single_precision_adasum_on_gpu, - 'test_half_precision_adasum_on_gpu': test_half_precision_adasum_on_gpu, + "test_single_precision_adasum_on_gpu": test_single_precision_adasum_on_gpu, + "test_half_precision_adasum_on_gpu": test_half_precision_adasum_on_gpu, } -parser = argparse.ArgumentParser(description='Test adasum allreduce') -parser.add_argument('--scenario', choices=function_map.keys(), help='training scenario to test adasum allreduce', required=True) +parser = argparse.ArgumentParser(description="Test adasum allreduce") +parser.add_argument( + "--scenario", choices=function_map.keys(), help="training scenario to test adasum allreduce", required=True +) args = parser.parse_args() function_map[args.scenario]() diff --git a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py index bfcbaee289929..66de14dce6852 100644 --- a/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py +++ b/orttraining/orttraining/test/python/orttraining_test_bert_postprocess.py @@ -1,5 +1,6 @@ from orttraining_test_model_transform import add_name, fix_transpose, add_expand_shape from orttraining_test_layer_norm_transform import layer_norm_transform + def postprocess_model(model): add_name(model) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint.py index 1bbcd41221b4a..f737f3d4131b5 100644 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint.py +++ b/orttraining/orttraining/test/python/orttraining_test_checkpoint.py @@ -9,7 +9,7 @@ from checkpoint._test_helpers import makedir from _test_commons import _single_run, _distributed_run -checkpoint_dir = os.path.abspath('checkpoint/checkpoint_dir/') +checkpoint_dir = os.path.abspath("checkpoint/checkpoint_dir/") makedir(checkpoint_dir) # test workflow: @@ -51,116 +51,408 @@ # - Load all states and aggregate them into 1 state dictionary fpr both the configs. # - Compare this aggregated state dictionaries against one another. -save_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_save_checkpoint.py') -load_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_load_checkpoint.py') -aggregate_checkpoint_file = os.path.join('checkpoint', 'orttraining_test_checkpoint_aggregation.py') -optim_state_file = os.path.join('checkpoint', 'orttraining_test_load_optimizer_state.py') -backend_api_file = os.path.join('checkpoint', 'orttraining_test_backend_api.py') +save_checkpoint_file = os.path.join("checkpoint", "orttraining_test_save_checkpoint.py") +load_checkpoint_file = os.path.join("checkpoint", "orttraining_test_load_checkpoint.py") +aggregate_checkpoint_file = os.path.join("checkpoint", "orttraining_test_checkpoint_aggregation.py") +optim_state_file = os.path.join("checkpoint", "orttraining_test_load_optimizer_state.py") +backend_api_file = os.path.join("checkpoint", "orttraining_test_backend_api.py") -single_node_full_precision_path = os.path.join(checkpoint_dir, 'single_node', 'full_precision') -single_node_mixed_precision_path = os.path.join(checkpoint_dir, 'single_node', 'mixed_precision') -distributed_zero_full_precision_lamb_path = os.path.join(checkpoint_dir, 'distributed_zero', 'full_precision', 'lamb') -distributed_zero_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'distributed_zero', 'mixed_precision', 'lamb') +single_node_full_precision_path = os.path.join(checkpoint_dir, "single_node", "full_precision") +single_node_mixed_precision_path = os.path.join(checkpoint_dir, "single_node", "mixed_precision") +distributed_zero_full_precision_lamb_path = os.path.join(checkpoint_dir, "distributed_zero", "full_precision", "lamb") +distributed_zero_mixed_precision_lamb_path = os.path.join(checkpoint_dir, "distributed_zero", "mixed_precision", "lamb") # megatron saving and loading uses a different model -single_node_full_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'single_node', 'full_precision') -single_node_mixed_precision_bart_path = os.path.join(checkpoint_dir, 'bart', 'single_node', 'mixed_precision') -distributed_zero_full_precision_lamb_bart_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero', 'full_precision', 'lamb') -distributed_zero_mixed_precision_lamb_bart_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero', 'mixed_precision', 'lamb') -distributed_megatron_full_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'full_precision', 'lamb') -distributed_megatron_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_megatron', 'mixed_precision', 'lamb') -distributed_zero_megatron_full_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'full_precision', 'adam') -distributed_zero_megatron_mixed_precision_adam_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'mixed_precision', 'adam') -distributed_zero_megatron_full_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'full_precision', 'lamb') -distributed_zero_megatron_mixed_precision_lamb_path = os.path.join(checkpoint_dir, 'bart', 'distributed_zero_megatron', 'mixed_precision', 'lamb') +single_node_full_precision_bart_path = os.path.join(checkpoint_dir, "bart", "single_node", "full_precision") +single_node_mixed_precision_bart_path = os.path.join(checkpoint_dir, "bart", "single_node", "mixed_precision") +distributed_zero_full_precision_lamb_bart_path = os.path.join( + checkpoint_dir, "bart", "distributed_zero", "full_precision", "lamb" +) +distributed_zero_mixed_precision_lamb_bart_path = os.path.join( + checkpoint_dir, "bart", "distributed_zero", "mixed_precision", "lamb" +) +distributed_megatron_full_precision_lamb_path = os.path.join( + checkpoint_dir, "bart", "distributed_megatron", "full_precision", "lamb" +) +distributed_megatron_mixed_precision_lamb_path = os.path.join( + checkpoint_dir, "bart", "distributed_megatron", "mixed_precision", "lamb" +) +distributed_zero_megatron_full_precision_adam_path = os.path.join( + checkpoint_dir, "bart", "distributed_zero_megatron", "full_precision", "adam" +) +distributed_zero_megatron_mixed_precision_adam_path = os.path.join( + checkpoint_dir, "bart", "distributed_zero_megatron", "mixed_precision", "adam" +) +distributed_zero_megatron_full_precision_lamb_path = os.path.join( + checkpoint_dir, "bart", "distributed_zero_megatron", "full_precision", "lamb" +) +distributed_zero_megatron_mixed_precision_lamb_path = os.path.join( + checkpoint_dir, "bart", "distributed_zero_megatron", "mixed_precision", "lamb" +) # save all checkpoint files (pre-checkpoint) -_single_run(save_checkpoint_file, 'single_node_full_precision', single_node_full_precision_path) -_single_run(save_checkpoint_file, 'single_node_mixed_precision', single_node_mixed_precision_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_lamb', distributed_zero_full_precision_lamb_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_lamb', distributed_zero_mixed_precision_lamb_path) +_single_run(save_checkpoint_file, "single_node_full_precision", single_node_full_precision_path) +_single_run(save_checkpoint_file, "single_node_mixed_precision", single_node_mixed_precision_path) +_distributed_run( + save_checkpoint_file, "distributed_zero_full_precision_lamb", distributed_zero_full_precision_lamb_path +) +_distributed_run( + save_checkpoint_file, "distributed_zero_mixed_precision_lamb", distributed_zero_mixed_precision_lamb_path +) -_single_run(save_checkpoint_file, 'single_node_full_precision_bart', single_node_full_precision_bart_path) -_single_run(save_checkpoint_file, 'single_node_mixed_precision_bart', single_node_mixed_precision_bart_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_full_precision_lamb_bart', distributed_zero_full_precision_lamb_bart_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_mixed_precision_lamb_bart', distributed_zero_mixed_precision_lamb_bart_path) +_single_run(save_checkpoint_file, "single_node_full_precision_bart", single_node_full_precision_bart_path) +_single_run(save_checkpoint_file, "single_node_mixed_precision_bart", single_node_mixed_precision_bart_path) +_distributed_run( + save_checkpoint_file, "distributed_zero_full_precision_lamb_bart", distributed_zero_full_precision_lamb_bart_path +) +_distributed_run( + save_checkpoint_file, "distributed_zero_mixed_precision_lamb_bart", distributed_zero_mixed_precision_lamb_bart_path +) -_distributed_run(save_checkpoint_file, 'distributed_megatron_full_precision_lamb', distributed_megatron_full_precision_lamb_path) -_distributed_run(save_checkpoint_file, 'distributed_megatron_mixed_precision_lamb', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_full_precision_lamb', distributed_zero_megatron_full_precision_lamb_path) -_distributed_run(save_checkpoint_file, 'distributed_zero_megatron_mixed_precision_lamb', distributed_zero_megatron_mixed_precision_lamb_path) +_distributed_run( + save_checkpoint_file, "distributed_megatron_full_precision_lamb", distributed_megatron_full_precision_lamb_path +) +_distributed_run( + save_checkpoint_file, "distributed_megatron_mixed_precision_lamb", distributed_megatron_mixed_precision_lamb_path +) +_distributed_run( + save_checkpoint_file, + "distributed_zero_megatron_full_precision_lamb", + distributed_zero_megatron_full_precision_lamb_path, +) +_distributed_run( + save_checkpoint_file, + "distributed_zero_megatron_mixed_precision_lamb", + distributed_zero_megatron_mixed_precision_lamb_path, +) # load checkpoint files (post-checkpoint) # going to single node trainer -_single_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_single_node_full_precision', single_node_full_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_single_node_full_precision', single_node_mixed_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_single_node_mixed_precision', single_node_mixed_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_single_node_mixed_precision', single_node_full_precision_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_single_node_full_precision', distributed_zero_full_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision', distributed_zero_mixed_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision', distributed_zero_mixed_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision', distributed_zero_full_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_single_node_full_precision', distributed_megatron_full_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision', distributed_megatron_mixed_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision', distributed_megatron_mixed_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision', distributed_megatron_full_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision', distributed_zero_megatron_full_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_single_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) +_single_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_single_node_full_precision", + single_node_full_precision_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_single_node_full_precision", + single_node_mixed_precision_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_single_node_mixed_precision", + single_node_mixed_precision_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_single_node_mixed_precision", + single_node_full_precision_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_single_node_full_precision", + distributed_zero_full_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_single_node_full_precision", + distributed_zero_mixed_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_single_node_mixed_precision", + distributed_zero_mixed_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_single_node_mixed_precision", + distributed_zero_full_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_single_node_full_precision", + distributed_megatron_full_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_single_node_full_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_single_node_mixed_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_single_node_mixed_precision", + distributed_megatron_full_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_single_node_full_precision", + distributed_zero_megatron_full_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_full_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_single_node_mixed_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_single_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_single_node_mixed_precision", + distributed_zero_megatron_full_precision_lamb_path, +) # going to distributed zero trainer -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_full_precision', single_node_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision', single_node_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision', single_node_mixed_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision', single_node_full_precision_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision', distributed_zero_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision', distributed_zero_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision', distributed_zero_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision', distributed_zero_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision', distributed_zero_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_distributed_zero_full_precision", + single_node_full_precision_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_distributed_zero_full_precision", + single_node_mixed_precision_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_distributed_zero_mixed_precision", + single_node_mixed_precision_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_distributed_zero_mixed_precision", + single_node_full_precision_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_full_precision", + distributed_zero_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_full_precision", + distributed_zero_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_mixed_precision", + distributed_zero_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_mixed_precision", + distributed_zero_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_full_precision", + distributed_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_full_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_mixed_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_mixed_precision", + distributed_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_full_precision", + distributed_zero_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_full_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_mixed_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_mixed_precision", + distributed_zero_megatron_full_precision_lamb_path, +) # going to distributed zero+megatron trainer -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_megatron_full_precision', single_node_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision', single_node_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision', single_node_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision', single_node_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision', distributed_zero_full_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision', distributed_zero_mixed_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision', distributed_zero_mixed_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision', distributed_zero_full_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision', distributed_zero_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_distributed_megatron_full_precision", + single_node_full_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_distributed_megatron_full_precision", + single_node_mixed_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_distributed_megatron_mixed_precision", + single_node_mixed_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_distributed_megatron_mixed_precision", + single_node_full_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_distributed_megatron_full_precision", + distributed_zero_full_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_full_precision", + distributed_zero_mixed_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_distributed_megatron_mixed_precision", + distributed_zero_mixed_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_distributed_megatron_mixed_precision", + distributed_zero_full_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_distributed_megatron_full_precision", + distributed_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_full_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_megatron_mixed_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_distributed_megatron_mixed_precision", + distributed_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_full_precision", + distributed_zero_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_full_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_megatron_mixed_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_megatron_mixed_precision", + distributed_zero_megatron_full_precision_lamb_path, +) # going to distributed zero+megatron trainer -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision', single_node_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision', single_node_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision', single_node_mixed_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision', single_node_full_precision_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision', distributed_zero_full_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_zero_mixed_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_mixed_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_full_precision_lamb_bart_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision', distributed_zero_megatron_full_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_megatron_mixed_precision_lamb_path) -_distributed_run(load_checkpoint_file, 'test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision', distributed_zero_megatron_full_precision_lamb_path) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_distributed_zero_megatron_full_precision", + single_node_full_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_full_precision", + single_node_mixed_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_mixed_precision_into_distributed_zero_megatron_mixed_precision", + single_node_mixed_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_single_node_full_precision_into_distributed_zero_megatron_mixed_precision", + single_node_full_precision_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_full_precision", + distributed_zero_full_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_full_precision", + distributed_zero_mixed_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_mixed_precision_into_distributed_zero_megatron_mixed_precision", + distributed_zero_mixed_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_full_precision_into_distributed_zero_megatron_mixed_precision", + distributed_zero_full_precision_lamb_bart_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_full_precision", + distributed_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_full_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision", + distributed_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_megatron_full_precision_into_distributed_zero_megatron_mixed_precision", + distributed_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_full_precision", + distributed_zero_megatron_full_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_full_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_mixed_precision_into_distributed_zero_megatron_mixed_precision", + distributed_zero_megatron_mixed_precision_lamb_path, +) +_distributed_run( + load_checkpoint_file, + "test_load_from_distributed_zero_megatron_full_precision_into_distributed_zero_megatron_mixed_precision", + distributed_zero_megatron_full_precision_lamb_path, +) shutil.rmtree(checkpoint_dir) diff --git a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py index 6beceab19b580..2ef8322bd9cfd 100644 --- a/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py +++ b/orttraining/orttraining/test/python/orttraining_test_checkpoint_storage.py @@ -13,6 +13,7 @@ # Helper functions + def _equals(a, b): """Checks recursively if two dictionaries are equal""" if isinstance(a, dict): @@ -30,6 +31,7 @@ def _equals(a, b): return False + def _numpy_types(obj_value): """Return a bool indicating whether or not the input obj_value is a numpy type object @@ -47,6 +49,7 @@ def _numpy_types(obj_value): return False return True + def _get_dict(separated_key): """Create dummy dictionary with different datatypes @@ -77,26 +80,23 @@ def _get_dict(separated_key): (original_dict, {'key': 'dict1/str1', "onnxruntime") """ test_dict = { - 'int1':1, - 'int2': 2, - 'int_list': [1,2,3,5,6], - 'dict1': { - 'np_array': np.arange(100), - 'dict2': {'int3': 3, 'int4': 4}, - 'str1': "onnxruntime" - }, - 'bool1': bool(True), - 'int5': 5, - 'float1': 2.345, - 'np_array_float': np.array([1.234, 2.345, 3.456]), - 'np_array_float_3_dim': np.array([[[1,2],[3,4]], [[5,6],[7,8]]]) + "int1": 1, + "int2": 2, + "int_list": [1, 2, 3, 5, 6], + "dict1": {"np_array": np.arange(100), "dict2": {"int3": 3, "int4": 4}, "str1": "onnxruntime"}, + "bool1": bool(True), + "int5": 5, + "float1": 2.345, + "np_array_float": np.array([1.234, 2.345, 3.456]), + "np_array_float_3_dim": np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]), } - key = '' + key = "" expected_val = test_dict for single_key in separated_key: - key += single_key + '/' + key += single_key + "/" expected_val = expected_val[single_key] - return test_dict, {'key': key} if len(separated_key) > 0 else dict(), expected_val + return test_dict, {"key": key} if len(separated_key) > 0 else dict(), expected_val + class _CustomClass(object): """Custom object that encpsulates dummy values for loss, epoch and train_step""" @@ -110,33 +110,43 @@ def __eq__(self, other): if isinstance(other, _CustomClass): return self._loss == other._loss and self._epoch == other._epoch and self._train_step == other._train_step + # Test fixtures + @pytest.yield_fixture(scope="function") def checkpoint_storage_test_setup(): - checkpoint_dir = os.path.abspath('checkpoint_dir/') + checkpoint_dir = os.path.abspath("checkpoint_dir/") if not os.path.exists(checkpoint_dir): - os.makedirs(checkpoint_dir, exist_ok = True) - pytest.checkpoint_path = os.path.join(checkpoint_dir, 'checkpoint.ortcp') - yield 'checkpoint_storage_test_setup' + os.makedirs(checkpoint_dir, exist_ok=True) + pytest.checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.ortcp") + yield "checkpoint_storage_test_setup" shutil.rmtree(checkpoint_dir) + @pytest.yield_fixture(scope="function") def checkpoint_storage_test_parameterized_setup(request, checkpoint_storage_test_setup): yield request.param + # Tests -@pytest.mark.parametrize("checkpoint_storage_test_parameterized_setup", [ - _get_dict([]), - _get_dict(['int1']), - _get_dict(['dict1']), - _get_dict(['dict1', 'dict2']), - _get_dict(['dict1', 'dict2', 'int4']), - _get_dict(['dict1', 'str1']), - _get_dict(['bool1']), - _get_dict(['float1']), - _get_dict(['np_array_float'])], indirect=True) + +@pytest.mark.parametrize( + "checkpoint_storage_test_parameterized_setup", + [ + _get_dict([]), + _get_dict(["int1"]), + _get_dict(["dict1"]), + _get_dict(["dict1", "dict2"]), + _get_dict(["dict1", "dict2", "int4"]), + _get_dict(["dict1", "str1"]), + _get_dict(["bool1"]), + _get_dict(["float1"]), + _get_dict(["np_array_float"]), + ], + indirect=True, +) def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_parameterized_setup): to_save = checkpoint_storage_test_parameterized_setup[0] key_arg = checkpoint_storage_test_parameterized_setup[1] @@ -146,24 +156,41 @@ def test_checkpoint_storage_saved_dict_matches_loaded(checkpoint_storage_test_pa assert _equals(loaded, expected) assert _numpy_types(loaded) -@pytest.mark.parametrize("checkpoint_storage_test_parameterized_setup", [ - {'int_set': {1, 2, 3, 4, 5}}, - {'str_set': {'one', 'two'}}, - [1, 2, 3], - 2.352], indirect=True) + +@pytest.mark.parametrize( + "checkpoint_storage_test_parameterized_setup", + [{"int_set": {1, 2, 3, 4, 5}}, {"str_set": {"one", "two"}}, [1, 2, 3], 2.352], + indirect=True, +) def test_checkpoint_storage_saving_non_supported_types_fails(checkpoint_storage_test_parameterized_setup): to_save = checkpoint_storage_test_parameterized_setup with pytest.raises(Exception): _checkpoint_storage.save(to_save, pytest.checkpoint_path) -@pytest.mark.parametrize("checkpoint_storage_test_parameterized_setup", [ - ({'int64_tensor': torch.tensor(np.arange(100))}, 'int64_tensor', torch.int64, np.int64), - ({'int32_tensor': torch.tensor(np.arange(100), dtype=torch.int32)}, 'int32_tensor', torch.int32, np.int32), - ({'int16_tensor': torch.tensor(np.arange(100), dtype=torch.int16)}, 'int16_tensor', torch.int16, np.int16), - ({'int8_tensor': torch.tensor(np.arange(100), dtype=torch.int8)}, 'int8_tensor', torch.int8, np.int8), - ({'float64_tensor': torch.tensor(np.array([1.0, 2.0]))}, 'float64_tensor', torch.float64, np.float64), - ({'float32_tensor': torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)}, 'float32_tensor', torch.float32, np.float32), - ({'float16_tensor': torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)}, 'float16_tensor', torch.float16, np.float16)], indirect=True) + +@pytest.mark.parametrize( + "checkpoint_storage_test_parameterized_setup", + [ + ({"int64_tensor": torch.tensor(np.arange(100))}, "int64_tensor", torch.int64, np.int64), + ({"int32_tensor": torch.tensor(np.arange(100), dtype=torch.int32)}, "int32_tensor", torch.int32, np.int32), + ({"int16_tensor": torch.tensor(np.arange(100), dtype=torch.int16)}, "int16_tensor", torch.int16, np.int16), + ({"int8_tensor": torch.tensor(np.arange(100), dtype=torch.int8)}, "int8_tensor", torch.int8, np.int8), + ({"float64_tensor": torch.tensor(np.array([1.0, 2.0]))}, "float64_tensor", torch.float64, np.float64), + ( + {"float32_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32)}, + "float32_tensor", + torch.float32, + np.float32, + ), + ( + {"float16_tensor": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float16)}, + "float16_tensor", + torch.float16, + np.float16, + ), + ], + indirect=True, +) def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_parameterized_setup): tensor_dict = checkpoint_storage_test_parameterized_setup[0] tensor_name = checkpoint_storage_test_parameterized_setup[1] @@ -178,10 +205,16 @@ def test_checkpoint_storage_saving_tensor_datatype(checkpoint_storage_test_param assert loaded[tensor_name].dtype == np_dtype assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() -@pytest.mark.parametrize("checkpoint_storage_test_parameterized_setup", [ - ({'two_dim': torch.ones([2, 4], dtype=torch.float64)}, 'two_dim'), - ({'three_dim': torch.ones([2, 4, 6], dtype=torch.float64)}, 'three_dim'), - ({'four_dim': torch.ones([2, 4, 6, 8], dtype=torch.float64)}, 'four_dim')], indirect=True) + +@pytest.mark.parametrize( + "checkpoint_storage_test_parameterized_setup", + [ + ({"two_dim": torch.ones([2, 4], dtype=torch.float64)}, "two_dim"), + ({"three_dim": torch.ones([2, 4, 6], dtype=torch.float64)}, "three_dim"), + ({"four_dim": torch.ones([2, 4, 6, 8], dtype=torch.float64)}, "four_dim"), + ], + indirect=True, +) def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage_test_parameterized_setup): tensor_dict = checkpoint_storage_test_parameterized_setup[0] tensor_name = checkpoint_storage_test_parameterized_setup[1] @@ -192,10 +225,10 @@ def test_checkpoint_storage_saving_multiple_dimension_tensors(checkpoint_storage assert isinstance(loaded[tensor_name], np.ndarray) assert (tensor_dict[tensor_name].numpy() == loaded[tensor_name]).all() -@pytest.mark.parametrize("checkpoint_storage_test_parameterized_setup", [ - {}, - {'a': {}}, - {'a': {'b': {}}}], indirect=True) + +@pytest.mark.parametrize( + "checkpoint_storage_test_parameterized_setup", [{}, {"a": {}}, {"a": {"b": {}}}], indirect=True +) def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(checkpoint_storage_test_parameterized_setup): saved = checkpoint_storage_test_parameterized_setup _checkpoint_storage.save(saved, pytest.checkpoint_path) @@ -203,31 +236,27 @@ def test_checkpoint_storage_saving_and_loading_empty_dictionaries_succeeds(check loaded = _checkpoint_storage.load(pytest.checkpoint_path) assert _equals(saved, loaded) + def test_checkpoint_storage_load_file_that_does_not_exist_fails(checkpoint_storage_test_setup): with pytest.raises(Exception): _checkpoint_storage.load(pytest.checkpoint_path) + def test_checkpoint_storage_for_custom_user_dict_succeeds(checkpoint_storage_test_setup): custom_class = _CustomClass() - user_dict = { - 'tensor1': torch.tensor(np.arange(100), dtype=torch.float32), - 'custom_class': custom_class - } + user_dict = {"tensor1": torch.tensor(np.arange(100), dtype=torch.float32), "custom_class": custom_class} pickled_bytes = pickle.dumps(user_dict).hex() - to_save = { - 'a': torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), - 'user_dict': pickled_bytes - } + to_save = {"a": torch.tensor(np.array([1.0, 2.0]), dtype=torch.float32), "user_dict": pickled_bytes} _checkpoint_storage.save(to_save, pytest.checkpoint_path) loaded_dict = _checkpoint_storage.load(pytest.checkpoint_path) - assert (loaded_dict['a'] == to_save['a'].numpy()).all() + assert (loaded_dict["a"] == to_save["a"].numpy()).all() try: - loaded_dict['user_dict'] = loaded_dict['user_dict'].decode() + loaded_dict["user_dict"] = loaded_dict["user_dict"].decode() except AttributeError: pass - loaded_obj = pickle.loads(bytes.fromhex(loaded_dict['user_dict'])) + loaded_obj = pickle.loads(bytes.fromhex(loaded_dict["user_dict"])) - assert torch.all(loaded_obj['tensor1'].eq(user_dict['tensor1'])) - assert loaded_obj['custom_class'] == custom_class + assert torch.all(loaded_obj["tensor1"].eq(user_dict["tensor1"])) + assert loaded_obj["custom_class"] == custom_class diff --git a/orttraining/orttraining/test/python/orttraining_test_data_loader.py b/orttraining/orttraining/test/python/orttraining_test_data_loader.py index 54a12ecab08eb..2df5a3964bc94 100644 --- a/orttraining/orttraining/test/python/orttraining_test_data_loader.py +++ b/orttraining/orttraining/test/python/orttraining_test_data_loader.py @@ -6,6 +6,7 @@ global_rng = random.Random() + def ids_tensor(shape, vocab_size, rng=None, name=None): """Creates a random int32 tensor of the shape within the vocab size.""" if rng is None: @@ -41,13 +42,16 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): class OrtTestDataset(Dataset): def __init__(self, input_desc, seq_len, dataset_len, device): import copy + self.input_desc_ = copy.deepcopy(input_desc) for input_desc in self.input_desc_: shape_ = [] for i, axis in enumerate(input_desc.shape_): - if axis == 'max_seq_len_in_batch': - shape_ = shape_ + [seq_len, ] - elif axis != 'batch': + if axis == "max_seq_len_in_batch": + shape_ = shape_ + [ + seq_len, + ] + elif axis != "batch": shape_ = input_desc.shape_[i] input_desc.shape_ = shape_ self.dataset_len_ = dataset_len @@ -63,20 +67,23 @@ def __getitem__(self, item): input_batch.append(input_sample) return input_batch + def create_ort_test_dataloader(input_desc, batch_size, seq_len, dataset_len, device): dataset = OrtTestDataset(input_desc, seq_len, dataset_len, device) return DataLoader(dataset, batch_size=batch_size) + class BatchArgsOption(Enum): List = 1 Dict = 2 ListAndDict = 3 + def split_batch(batch, input_desc, args_count): total_argument_count = len(input_desc) - # batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]] - args = [] # (input_ids[batch, seglen], attention_mask[batch, seglen]) - kwargs = {} # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]} + # batch=[input_ids[batch, seglen], attention_mask[batch, seglen], token_type_ids[batch,seglen], token_type_ids[batch, seglen]] + args = [] # (input_ids[batch, seglen], attention_mask[batch, seglen]) + kwargs = {} # {'token_type_ids': token_type_ids[batch,seglen], 'position_ids': token_type_ids[batch, seglen]} for i in range(args_count): args = args + [batch[i]] diff --git a/orttraining/orttraining/test/python/orttraining_test_debuggability.py b/orttraining/orttraining/test/python/orttraining_test_debuggability.py index 25db0c60a3171..d3d6987f47c2a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_debuggability.py +++ b/orttraining/orttraining/test/python/orttraining_test_debuggability.py @@ -8,13 +8,21 @@ from numpy.testing import assert_allclose from onnxruntime import set_seed -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ - ModelDescription as Legacy_ModelDescription,\ - LossScaler as Legacy_LossScaler,\ - ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import _utils, amp, optim, orttrainer, TrainStepInfo,\ - model_desc_validation as md_val,\ - orttrainer_options as orttrainer_options +from onnxruntime.capi.ort_trainer import ( + IODescription as Legacy_IODescription, + ModelDescription as Legacy_ModelDescription, + LossScaler as Legacy_LossScaler, + ORTTrainer as Legacy_ORTTrainer, +) +from onnxruntime.training import ( + _utils, + amp, + optim, + orttrainer, + TrainStepInfo, + model_desc_validation as md_val, + orttrainer_options as orttrainer_options, +) from _test_commons import _load_pytorch_transformer_model @@ -26,20 +34,25 @@ ############################################################################### -@pytest.mark.parametrize("seed, device", [ - (24, 'cuda'), -]) +@pytest.mark.parametrize( + "seed, device", + [ + (24, "cuda"), + ], +) def testORTTransformerModelExport(seed, device): # Common setup optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'check_model_export': True, - }, - 'device' : { - 'id' : device, + opts = orttrainer.ORTTrainerOptions( + { + "debug": { + "check_model_export": True, + }, + "device": { + "id": device, + }, } - }) + ) # Setup for the first ORTTRainer run torch.manual_seed(seed) @@ -49,4 +62,3 @@ def testORTTransformerModelExport(seed, device): data, targets = batcher_fn(train_data, 0) _ = first_trainer.train_step(data, targets) assert first_trainer._onnx_model is not None - diff --git a/orttraining/orttraining/test/python/orttraining_test_dhp_parallel_tests.py b/orttraining/orttraining/test/python/orttraining_test_dhp_parallel_tests.py index b0fa605e26bde..807cff93ebf9e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dhp_parallel_tests.py +++ b/orttraining/orttraining/test/python/orttraining_test_dhp_parallel_tests.py @@ -8,10 +8,8 @@ import torch from _test_commons import run_subprocess -logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) -log = logging.getLogger('DistributedTests') +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) +log = logging.getLogger("DistributedTests") # This function should be used to call all DxHxP parallel test scripts. def main(): @@ -19,23 +17,21 @@ def main(): # Declare test scripts for parallel tests. # New test scripts should be added to "dhp_parallel" folder. - distributed_test_files = [os.path.join('dhp_parallel', - 'orttraining_test_parallel_train_simple_model.py'), - os.path.join('dhp_parallel', - 'orttraining_test_parallel_train_simple_model_fp16.py')] + distributed_test_files = [ + os.path.join("dhp_parallel", "orttraining_test_parallel_train_simple_model.py"), + os.path.join("dhp_parallel", "orttraining_test_parallel_train_simple_model_fp16.py"), + ] # parallel_test_process_number[i] is the number of processes needed to run distributed_test_files[i]. distributed_test_process_counts = [4, 4] - log.info('Running parallel training tests.') + log.info("Running parallel training tests.") for test_file, process_count in zip(distributed_test_files, distributed_test_process_counts): if ngpus < process_count: - log.error( - 'Machine Configuration Error. More GPUs are needed to run ' + test_file) + log.error("Machine Configuration Error. More GPUs are needed to run " + test_file) return 1 - log.debug('RUN: ' + test_file) + log.debug("RUN: " + test_file) - command = ['mpirun', '-n', - str(process_count), sys.executable, test_file] + command = ["mpirun", "-n", str(process_count), sys.executable, test_file] # The current working directory is set in # onnxruntime/orttraining/orttraining/test/python/orttraining_distributed_tests.py @@ -44,5 +40,5 @@ def main(): return 0 -if __name__ == '__main__': +if __name__ == "__main__": sys.exit(main()) diff --git a/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py b/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py index e4bbf76c9be0a..c67de052753ad 100644 --- a/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py +++ b/orttraining/orttraining/test/python/orttraining_test_experimental_gradient_graph.py @@ -14,15 +14,10 @@ class NeuralNet(torch.nn.Module): Simple example model. """ - def __init__(self, - input_size: int, - embedding_size: int, - hidden_size: int, - num_classes: int): + def __init__(self, input_size: int, embedding_size: int, hidden_size: int, num_classes: int): super(NeuralNet, self).__init__() - self.frozen_layer = torch.nn.Linear( - input_size, embedding_size, bias=False) + self.frozen_layer = torch.nn.Linear(input_size, embedding_size, bias=False) # Freeze a layer (mainly to test that gradients don't get output for it). self.frozen_layer.requires_grad_(False) @@ -43,8 +38,7 @@ def to_numpy(tensor): def binary_cross_entropy_loss(inp, target): - loss = -torch.sum(target * torch.log2(inp[:, 0]) + - (1-target) * torch.log2(inp[:, 1])) + loss = -torch.sum(target * torch.log2(inp[:, 0]) + (1 - target) * torch.log2(inp[:, 1])) return loss @@ -54,30 +48,27 @@ def test_save(self): # You can still make the gradient graph with torch.nn.CrossEntropyLoss() and this test will pass. loss_fn = binary_cross_entropy_loss input_size = 10 - model = NeuralNet(input_size=input_size, embedding_size=20, hidden_size=5, - num_classes=2) + model = NeuralNet(input_size=input_size, embedding_size=20, hidden_size=5, num_classes=2) directory_path = Path(os.path.dirname(__file__)).resolve() - gradient_graph_path = directory_path/'gradient_graph_model.onnx' + gradient_graph_path = directory_path / "gradient_graph_model.onnx" batch_size = 1 - example_input = torch.randn( - batch_size, input_size, requires_grad=True) + example_input = torch.randn(batch_size, input_size, requires_grad=True) example_labels = torch.tensor([1]) - export_gradient_graph( - model, loss_fn, example_input, example_labels, gradient_graph_path) + export_gradient_graph(model, loss_fn, example_input, example_labels, gradient_graph_path) onnx_model = onnx.load(str(gradient_graph_path)) onnx.checker.check_model(onnx_model) # Expected inputs: input, labels, models parameters. - self.assertEqual( - 1 + 1 + sum(1 for _ in model.parameters()), len(onnx_model.graph.input)) - + self.assertEqual(1 + 1 + sum(1 for _ in model.parameters()), len(onnx_model.graph.input)) + # Expected outputs: prediction, loss, and parameters with gradients. self.assertEqual( - 1 + 1 + sum(1 if p.requires_grad else 0 for p in model.parameters()), len(onnx_model.graph.output)) + 1 + 1 + sum(1 if p.requires_grad else 0 for p in model.parameters()), len(onnx_model.graph.output) + ) torch_out = model(example_input) @@ -86,6 +77,7 @@ def test_save(self): except ValueError: # Sometimes it is required to pass the available providers. from onnxruntime.capi import _pybind_state as C + available_providers = C.get_available_providers() ort_session = onnxruntime.InferenceSession(str(gradient_graph_path), providers=available_providers) @@ -101,25 +93,22 @@ def test_save(self): onnx_output_names = [node.name for node in onnx_model.graph.output] onnx_name_to_output = dict(zip(onnx_output_names, ort_outs)) - ort_output = onnx_name_to_output['output'] - np.testing.assert_allclose( - to_numpy(torch_out), ort_output, rtol=1e-03, atol=1e-05) + ort_output = onnx_name_to_output["output"] + np.testing.assert_allclose(to_numpy(torch_out), ort_output, rtol=1e-03, atol=1e-05) torch_loss = loss_fn(torch_out, example_labels) - ort_loss = onnx_name_to_output['loss'] - np.testing.assert_allclose( - to_numpy(torch_loss), ort_loss, rtol=1e-03, atol=1e-05) + ort_loss = onnx_name_to_output["loss"] + np.testing.assert_allclose(to_numpy(torch_loss), ort_loss, rtol=1e-03, atol=1e-05) # Make sure the gradients have the right shape. - model_param_names = tuple( - name for name, param in model.named_parameters() if param.requires_grad) + model_param_names = tuple(name for name, param in model.named_parameters() if param.requires_grad) self.assertEqual(4, len(model_param_names)) for name, param in model.named_parameters(): if param.requires_grad: - grad = onnx_name_to_output[name + '_grad'] + grad = onnx_name_to_output[name + "_grad"] self.assertEqual(param.size(), grad.shape) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py index a387fc60d02f4..7eba32402cc4a 100644 --- a/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py +++ b/orttraining/orttraining/test/python/orttraining_test_hierarchical_ortmodule.py @@ -30,11 +30,12 @@ def __init__(self): self.a = A() def forward(self, x): - def custom(): def custom_forward(x_): return self.a(x_) + return custom_forward + z = self.l1(checkpoint(custom(), x)) return z @@ -88,7 +89,7 @@ def __init__(self): self.d = D() def forward(self, x, case): - if case == 'reverse': + if case == "reverse": z = self.alpha * self.a(self.b(self.c(self.d(x)))) else: z = self.alpha * self.d(self.c(self.b(self.a(x)))) @@ -192,13 +193,10 @@ def trial(module_to_wrap, args, expected_num_ortmodule): for _ in range(num_trials): trial(Main(), [torch.rand(2).requires_grad_()], 6) trial(MainWithModuleList(), [torch.rand(2).requires_grad_()], 12) - trial(MainWithMultiModuleOutputs(), [ - torch.rand(2).requires_grad_()], 10) - trial(MainWithNonTensorInput(), [ - torch.rand(2).requires_grad_(), 'reverse'], 6) - trial(MainWithNonTensorInput(), [ - torch.rand(2).requires_grad_(), 'normal'], 6) + trial(MainWithMultiModuleOutputs(), [torch.rand(2).requires_grad_()], 10) + trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "reverse"], 6) + trial(MainWithNonTensorInput(), [torch.rand(2).requires_grad_(), "normal"], 6) -if __name__ == '__main__': +if __name__ == "__main__": test_hierarchical_ortmodule() diff --git a/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py b/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py index 883d7386e768a..241a963e28498 100644 --- a/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py +++ b/orttraining/orttraining/test/python/orttraining_test_layer_norm_transform.py @@ -1,16 +1,18 @@ import onnx + def find_node(graph_proto, op_type): nodes = [] map_input_node = {} for node in graph_proto.node: if node.op_type == op_type: map_input_node[node.input[0]] = node - if op_type == 'Div' or op_type == 'Mul': + if op_type == "Div" or op_type == "Mul": map_input_node[node.input[1]] = node nodes.append(node) return nodes, map_input_node + def gen_attribute(key, value): attr = AttributeProto() attr.name = key @@ -18,6 +20,7 @@ def gen_attribute(key, value): attr.type = AttributeProto.INTS return attr + def layer_norm_transform(model_proto): # a layer norm subgraph # input @@ -45,23 +48,23 @@ def layer_norm_transform(model_proto): graph_proto = model_proto.graph - _, map_input_Div = find_node(graph_proto, 'Div') + _, map_input_Div = find_node(graph_proto, "Div") - _, map_input_Sqrt = find_node(graph_proto, 'Sqrt') + _, map_input_Sqrt = find_node(graph_proto, "Sqrt") - _, map_input_Add = find_node(graph_proto, 'Add') + _, map_input_Add = find_node(graph_proto, "Add") - nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, 'ReduceMean') + nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, "ReduceMean") - _, map_input_Pow = find_node(graph_proto, 'Pow') + _, map_input_Pow = find_node(graph_proto, "Pow") - _, map_input_Mul = find_node(graph_proto, 'Mul') + _, map_input_Mul = find_node(graph_proto, "Mul") # find right side Sub (see the layer norm subgrapg) nodes_Sub = [] map_input_Sub = {} for node in graph_proto.node: - if node.op_type == 'Sub': + if node.op_type == "Sub": if node.output[0] in map_input_Pow: nodes_Sub.append(node) map_input_Sub[node.input[1]] = node @@ -78,7 +81,7 @@ def layer_norm_transform(model_proto): nodes_Constant = [] map_output_Constant = {} for node in graph_proto.node: - if node.op_type == 'Constant': + if node.op_type == "Constant": nodes_Constant.append(node) map_output_Constant[node.output[0]] = node @@ -115,7 +118,7 @@ def layer_norm_transform(model_proto): node_Sqrt = map_input_Sqrt[node_Add.output[0]] if node_Sqrt.output[0] not in map_input_Div: continue - + node_Div = map_input_Div[node_Sqrt.output[0]] if node_Div.output[0] not in map_input_Mul: continue @@ -147,21 +150,23 @@ def layer_norm_transform(model_proto): removed_nodes.append(map_output_Constant[node_Add.input[1]]) layer_norm_output.append(node_Add1.output[0]) id = id + 1 - layer_norm_output.append('saved_mean_' + str(id)) + layer_norm_output.append("saved_mean_" + str(id)) id = id + 1 - layer_norm_output.append('saved_inv_std_var_' + str(id)) - layer_norm = onnx.helper.make_node("LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + str(id), - None, - axis = node_reduce.attribute[0].ints[0], - epsilon = 9.999999960041972e-13) + layer_norm_output.append("saved_inv_std_var_" + str(id)) + layer_norm = onnx.helper.make_node( + "LayerNormalization", + layer_norm_input, + layer_norm_output, + "LayerNormalization_" + str(id), + None, + axis=node_reduce.attribute[0].ints[0], + epsilon=9.999999960041972e-13, + ) layer_norm_nodes.append(layer_norm) # remove left side Subs for node in graph_proto.node: - if node.op_type == 'Sub': + if node.op_type == "Sub": if node.input[1] in first_ReduceMean_outputs: removed_nodes.append(node) diff --git a/orttraining/orttraining/test/python/orttraining_test_model_transform.py b/orttraining/orttraining/test/python/orttraining_test_model_transform.py index 9ef92aabcfac9..d6984dcf08425 100644 --- a/orttraining/orttraining/test/python/orttraining_test_model_transform.py +++ b/orttraining/orttraining/test/python/orttraining_test_model_transform.py @@ -1,10 +1,12 @@ from onnx import numpy_helper + def add_name(model): i = 0 for node in model.graph.node: - node.name = '%s_%d' %(node.op_type, i) - i += 1 + node.name = "%s_%d" % (node.op_type, i) + i += 1 + def find_single_output_node(model, arg): result = [] @@ -14,24 +16,28 @@ def find_single_output_node(model, arg): result.append(node) return result[0] if len(result) == 1 else None + def find_input_as_initializer(model, arg): for initializer in model.graph.initializer: if initializer.name == arg: return initializer return None + def get_node_index(model, node): for i, n in enumerate(model.graph.node): if n == node: return i return None + def replace_input_arg(model, arg, new_arg): for node in model.graph.node: for i in range(len(node.input)): if node.input[i] == arg: node.input[i] = new_arg + def find_weight_index(model, name): for index, w in enumerate(model.graph.initializer): if w.name == name: @@ -39,17 +45,18 @@ def find_weight_index(model, name): index += 1 return None + def fix_transpose(model): """ remove transpose node if its input is a 2d weight which only feeds to the node. """ - # Find transpose nodes with initializer weight as input. + # Find transpose nodes with initializer weight as input. # The input weight needs to be only feeded into the transpose node. # Collect these nodes and weights. transpose = [] for node in model.graph.node: - if node.op_type == 'Transpose': + if node.op_type == "Transpose": weight = find_input_as_initializer(model, node.input[0]) if weight is not None: result = [] @@ -60,12 +67,12 @@ def fix_transpose(model): if len(result) > 1: continue perm = node.attribute[0] - assert perm.name == 'perm' + assert perm.name == "perm" perm = perm.ints assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0 transpose.append((get_node_index(model, node), weight)) - # Transpose collected weights and add it to the model initializers. + # Transpose collected weights and add it to the model initializers. # The transposed weight initializers become inputs to the transpose nodes' recipient nodes. for t in transpose: node = model.graph.node[t[0]] @@ -81,7 +88,7 @@ def fix_transpose(model): for t in transpose: del model.graph.node[t[0]] - # the original weight initializer can be removed. + # the original weight initializer can be removed. # (remember that a wight needs only to be feeded into the transpose node when collecting wights) old_ws = [] for t in transpose: @@ -91,16 +98,17 @@ def fix_transpose(model): for w_i in old_ws: del model.graph.initializer[w_i] + def add_expand_shape(model): """ this method is very specific to the Bert model where there is a solo Expand op. training backend requires the op's output shape. it is the same as the shape of the model (single) input. """ - expand_node = [n for n in model.graph.node if n.op_type == 'Expand'] + expand_node = [n for n in model.graph.node if n.op_type == "Expand"] if len(expand_node) != 1: raise "cannot find the single expand node in the BERT model." return expand_out = model.graph.value_info.add() - expand_out.name = expand_node[0].output[0] # base: '421' # tiny: '85' - expand_out.type.CopyFrom(model.graph.input[0].type) \ No newline at end of file + expand_out.name = expand_node[0].output[0] # base: '421' # tiny: '85' + expand_out.type.CopyFrom(model.graph.input[0].type) diff --git a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py b/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py index 12c2bbe155f56..1411d43be70eb 100644 --- a/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py +++ b/orttraining/orttraining/test/python/orttraining_test_onnx_ops_ortmodule.py @@ -9,7 +9,6 @@ class TestOnnxOpsOrtModule(unittest.TestCase): - def assert_values_are_close(self, tensor, other, rtol=1e-05, atol=1e-06): are_close = torch.allclose(tensor, other, rtol=rtol, atol=atol) if not are_close: @@ -17,13 +16,11 @@ def assert_values_are_close(self, tensor, other, rtol=1e-05, atol=1e-06): abs_other = torch.abs(other) max_atol = torch.max((abs_diff - rtol * abs_other)) max_rtol = torch.max((abs_diff - atol) / abs_other) - raise AssertionError( - "The maximum atol is %r, maximum rtol is %r." % ( - max_atol, max_rtol)) + raise AssertionError("The maximum atol is %r, maximum rtol is %r." % (max_atol, max_rtol)) def assert_gradients_match_and_reset_gradient( - self, ort_model, pt_model, none_pt_params=None, - reset_gradient=True, rtol=1e-05, atol=1e-06): + self, ort_model, pt_model, none_pt_params=None, reset_gradient=True, rtol=1e-05, atol=1e-06 + ): if none_pt_params is None: none_pt_params = [] ort_named_params = list(ort_model.named_parameters()) @@ -38,11 +35,9 @@ def assert_gradients_match_and_reset_gradient( if pt_name in none_pt_params: self.assertNotEmpty(pt_param.grad) if ort_param is not None: - self.assertFalse(torch.is_nonzero( - torch.count_nonzero(ort_param.grad))) + self.assertFalse(torch.is_nonzero(torch.count_nonzero(ort_param.grad))) else: - self.assert_values_are_close( - ort_param.grad, pt_param.grad, rtol=rtol, atol=atol) + self.assert_values_are_close(ort_param.grad, pt_param.grad, rtol=rtol, atol=atol) if reset_gradient: ort_param.grad = None @@ -68,10 +63,8 @@ def run_step(model, x): pt_prediction = run_step(pt_model, x) ort_prediction = run_step(ort_model, x) - self.assert_values_are_close( - ort_prediction, pt_prediction, **kwargs) - self.assert_gradients_match_and_reset_gradient( - ort_model, pt_model, **kwargs) + self.assert_values_are_close(ort_prediction, pt_prediction, **kwargs) + self.assert_gradients_match_and_reset_gradient(ort_model, pt_model, **kwargs) onnx_graph_inf = ort_model._torch_module._execution_manager._training_manager._onnx_models.exported_model onnx_graph_train = ort_model._torch_module._execution_manager._training_manager._onnx_models.optimized_model @@ -83,20 +76,17 @@ def run_step(model, x): self.assertIn('op_type: "%s"' % name, str(onnx_graph_inf)) for onnx_model in [onnx_graph_inf, onnx_graph_train]: for oimp in onnx_model.opset_import: - if oimp.domain == '': + if oimp.domain == "": self.assertEqual(oimp.version, 14) if op_grad_type is not None: if isinstance(op_grad_type, tuple): text = str(onnx_graph_train) if all(map(lambda op: ('op_type: "%s"' % op) not in text, op_grad_type)): - raise AssertionError( - "Operator %s not found in %s." % (" or ".join(op_grad_type), text)) + raise AssertionError("Operator %s not found in %s." % (" or ".join(op_grad_type), text)) else: - self.assertIn('op_type: "%s"' % - op_grad_type, str(onnx_graph_train)) + self.assertIn('op_type: "%s"' % op_grad_type, str(onnx_graph_train)) def get_torch_model_name(self, name, device): - def from_numpy(v, device=None, requires_grad=False): v = torch.from_numpy(v) if device is not None: @@ -104,7 +94,7 @@ def from_numpy(v, device=None, requires_grad=False): v.requires_grad_(requires_grad) return v - if name == 'Softmax': + if name == "Softmax": class TestSoftmax(torch.nn.Module): def __init__(self, input_size=128, hidden_size=500, num_classes=100): @@ -119,18 +109,16 @@ def forward(self, input1): out = self.fc2(out) return out - return TestSoftmax, ('SoftmaxGrad', 'SoftmaxGrad_13'), None + return TestSoftmax, ("SoftmaxGrad", "SoftmaxGrad_13"), None - if name == 'GatherElements': + if name == "GatherElements": class TestGatherElement(torch.nn.Module): def __init__(self, input_size=32, hidden_size=500, num_classes=100): torch.nn.Module.__init__(self) self.fc1 = torch.nn.Linear(input_size, hidden_size) - rev_idx = np.array(list(np.arange(hidden_size)[::-1]), - dtype=np.int64) - idx = np.empty( - (input_size, hidden_size), dtype=np.int64) + rev_idx = np.array(list(np.arange(hidden_size)[::-1]), dtype=np.int64) + idx = np.empty((input_size, hidden_size), dtype=np.int64) for i in range(idx.shape[0]): idx[i, :] = rev_idx self.indices = from_numpy(idx, device=device) @@ -142,14 +130,14 @@ def forward(self, input1): out = self.fc2(out) return out - return TestGatherElement, 'GatherElementsGrad', dict(rtol=1e-04, atol=1e-05) + return TestGatherElement, "GatherElementsGrad", dict(rtol=1e-04, atol=1e-05) raise AssertionError("Unexpected name=%r." % name) def test_onnx_ops(self): - for name in ['GatherElements', 'Softmax']: - for device_name in ['cuda:0', 'cpu']: - if device_name == 'cuda:0' and not torch.cuda.is_available(): + for name in ["GatherElements", "Softmax"]: + for device_name in ["cuda:0", "cpu"]: + if device_name == "cuda:0" and not torch.cuda.is_available(): continue with self.subTest(name=name, device=device_name): device = torch.device(device_name) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 5d3b19cdb25fa..880de49018c83 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -21,13 +21,15 @@ import pickle from distutils.version import LooseVersion from onnxruntime.training.ortmodule._custom_gradient_registry import register_gradient -from onnxruntime.training.ortmodule import (ORTModule, - _utils, - _io, - DebugOptions, - LogLevel, - _fallback, - _graph_execution_manager) +from onnxruntime.training.ortmodule import ( + ORTModule, + _utils, + _io, + DebugOptions, + LogLevel, + _fallback, + _graph_execution_manager, +) import onnxruntime.training.ortmodule as ortmodule_module from onnxruntime.training.optim import FusedAdam, AdamWMode @@ -44,6 +46,7 @@ # PyTorch model definitions for tests + class NeuralNetSinglePositionalArgument(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetSinglePositionalArgument, self).__init__() @@ -58,6 +61,7 @@ def forward(self, input1): out = self.fc2(out) return out + class NeuralNetMultiplePositionalArgumentsMultiOutputsWithoutDependency(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetMultiplePositionalArgumentsMultiOutputsWithoutDependency, self).__init__() @@ -79,6 +83,7 @@ def forward(self, input1, input2): out2 = self.relu2(out2) return out1, out2 + class NeuralNetMultiplePositionalArgumentsMultiOutputsWithDependency(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetMultiplePositionalArgumentsMultiOutputsWithDependency, self).__init__() @@ -97,6 +102,7 @@ def forward(self, input1, input2): out2 = self.fc2(out1) return out1, out2 + class NeuralNetMultiplePositionalArguments(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetMultiplePositionalArguments, self).__init__() @@ -112,6 +118,7 @@ def forward(self, input1, input2): out = self.fc2(out) return out + class NeuralNetMultiplePositionalArgumentsVarKeyword(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetMultiplePositionalArgumentsVarKeyword, self).__init__() @@ -127,6 +134,7 @@ def forward(self, input1, input2, **kwargs): out = self.fc2(out) return out + class NeuralNetPositionalArguments(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetPositionalArguments, self).__init__() @@ -142,6 +150,7 @@ def forward(self, *model_inputs): out = self.fc2(out) return out + class NeuralNetKeywordArguments(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetKeywordArguments, self).__init__() @@ -157,6 +166,7 @@ def forward(self, x=None, y=None, z=None): out = self.fc2(out) return out + class NeuralNetPositionalAndKeywordArguments(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetPositionalAndKeywordArguments, self).__init__() @@ -172,10 +182,12 @@ def forward(self, model_input, x=None, y=None, z=None): out = self.fc2(out) return out + class NeuralNetSimplePositionalAndKeywordArguments(torch.nn.Module): def __init__(self): super(NeuralNetSimplePositionalAndKeywordArguments, self).__init__() - self.a = torch.nn.Parameter(torch.FloatTensor([-1., 1.])) + self.a = torch.nn.Parameter(torch.FloatTensor([-1.0, 1.0])) + def forward(self, x, y=None, z=None): if z is not None: return torch.mean(self.a) + x + 4 * z @@ -183,6 +195,7 @@ def forward(self, x, y=None, z=None): return torch.mean(self.a) + 3 * y return torch.mean(self.a) + x + class NeuralNetNonDifferentiableOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetNonDifferentiableOutput, self).__init__() @@ -195,13 +208,14 @@ def forward(self, input1): out1 = self.relu(out) out2 = self.fc2(out1) mask1 = torch.gt(out1, 0.01) - mask1 = mask1.long() # TODO: Casting from bool to float or int will cause the UT failure - # True is casted to 1065353216 for Cast(from=bool, to=int), whereas pytorch would give 1 - # True is casted to -1 for Cast(from=bool, to=float), where as pytorch would give 1.0f + mask1 = mask1.long() # TODO: Casting from bool to float or int will cause the UT failure + # True is casted to 1065353216 for Cast(from=bool, to=int), whereas pytorch would give 1 + # True is casted to -1 for Cast(from=bool, to=float), where as pytorch would give 1.0f mask2 = torch.lt(out2, 0.02) mask2 = mask2.long() - return out1, mask1, out2, mask2 # intentionally place the non-differentiable output in the middle + return out1, mask1, out2, mask2 # intentionally place the non-differentiable output in the middle + class NeuralNetChainedLayersWithNonDifferentiableOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -219,6 +233,7 @@ def forward(self, input1, mask1): return out2, mask + class NeuralNetPartialNoGradModel(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetPartialNoGradModel, self).__init__() @@ -232,6 +247,7 @@ def forward(self, model_input): out = self.fc2(out) return out + class UnusedEndParameterNet(torch.nn.Module): def __init__(self, input_size, hidden_size1, hidden_size2, num_classes): super(UnusedEndParameterNet, self).__init__() @@ -249,6 +265,7 @@ def forward(self, input1): out = out + self.buffer return out + class UnusedBeginParameterNet(torch.nn.Module): def __init__(self, input_size, hidden_size1, hidden_size2, num_classes): super(UnusedBeginParameterNet, self).__init__() @@ -266,6 +283,7 @@ def forward(self, input1): out = out + self.buffer return out + class UnusedMiddleParameterNet(torch.nn.Module): def __init__(self, input_size, hidden_size1, hidden_size2, num_classes): super(UnusedMiddleParameterNet, self).__init__() @@ -285,6 +303,7 @@ def forward(self, input1): out = out + self.buffer return out + class StatelessModel(torch.nn.Module): def __init__(self): super(StatelessModel, self).__init__() @@ -292,6 +311,7 @@ def __init__(self): def forward(self, x): return x + class NeuralNetCustomClassOutput(torch.nn.Module): class CustomClass(object): def __init__(self, out1, out2, out3): @@ -320,12 +340,14 @@ def forward(self, input1, input2, input3): out3 = self.fc3_2(self.relu3(self.fc3_1(input3))) return NeuralNetCustomClassOutput.CustomClass(out1, out2, out3) + class MyStrNet(torch.nn.Module): def forward(self, x, my_str): - if my_str.lower() == 'hello': - return x+1 + if my_str.lower() == "hello": + return x + 1 return x + class SerializationNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(SerializationNet, self).__init__() @@ -347,7 +369,8 @@ def train_step(self, input): return out -@pytest.fixture(scope='session', autouse=True) + +@pytest.fixture(scope="session", autouse=True) def run_before_test_session(request): def insert_disable_fallback_in_env(): os.environ["ORTMODULE_FALLBACK_POLICY"] = "FALLBACK_DISABLE" @@ -358,6 +381,7 @@ def remove_disable_fallback_from_env(): insert_disable_fallback_in_env() request.addfinalizer(remove_disable_fallback_from_env) + # TODO: This is a workaround for the problem that pytest is still cleaning up the previous test # while the next task already start. @pytest.fixture(autouse=True) @@ -365,18 +389,25 @@ def run_before_tests(): # wait for 50ms before starting the next test sleep(0.05) -def _get_bert_for_sequence_classification_model(device, output_attentions = False, \ - output_hidden_states = False, return_dict = True, hidden_dropout_prob = 0.1, attention_probs_dropout_prob = 0.1): + +def _get_bert_for_sequence_classification_model( + device, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, +): """Returns the BertForSequenceClassification pretrained model""" config = AutoConfig.from_pretrained( - "bert-base-uncased", - num_labels=2, - num_hidden_layers=1, - output_attentions = output_attentions, - output_hidden_states = output_hidden_states, - hidden_dropout_prob = hidden_dropout_prob, - attention_probs_dropout_prob = attention_probs_dropout_prob, + "bert-base-uncased", + num_labels=2, + num_hidden_layers=1, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + hidden_dropout_prob=hidden_dropout_prob, + attention_probs_dropout_prob=attention_probs_dropout_prob, ) config.return_dict = return_dict @@ -387,6 +418,7 @@ def _get_bert_for_sequence_classification_model(device, output_attentions = Fals return model + def _get_bert_for_sequence_classification_sample_data(device): """Returns sample data to be used with BertForSequenceClassification model""" @@ -396,21 +428,24 @@ def _get_bert_for_sequence_classification_sample_data(device): return input_ids, input_mask, labels + def _get_bert_for_sequence_classification_sample_data_with_random_shapes(device): """Returns sample data with random shape to be used with BertForSequenceClassification model""" - x = random.randint(1,100) - y = random.randint(1,100) + x = random.randint(1, 100) + y = random.randint(1, 100) input_ids = torch.randint(0, 100, (x, y), dtype=torch.long, device=device) input_mask = torch.randint(0, 100, (x, y), dtype=torch.long, device=device) labels = torch.randint(0, 1, (x,), dtype=torch.long, device=device) return input_ids, input_mask, labels + # ORTModule-API tests + def test_forward_call_single_positional_argument(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -424,8 +459,9 @@ def test_forward_call_single_positional_argument(): prediction = prediction.sum() prediction.backward() + def test_forward_call_multiple_positional_arguments(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetMultiplePositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) @@ -441,13 +477,18 @@ def test_forward_call_multiple_positional_arguments(): prediction = prediction.sum() prediction.backward() + def test_forward_call_positional_arguments(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) model = ORTModule(model) - args = [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)] + args = [ + torch.randn(N, D_in, device=device), + torch.randn(N, D_in, device=device), + torch.randn(N, D_in, device=device), + ] # Make sure model runs without any exception prediction = model(*args) @@ -455,8 +496,9 @@ def test_forward_call_positional_arguments(): prediction = prediction.sum() prediction.backward() + def test_forward_call_keyword_arguments(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetKeywordArguments(D_in, H, D_out).to(device) @@ -471,8 +513,9 @@ def test_forward_call_keyword_arguments(): prediction = prediction.sum() prediction.backward() + def test_forward_call_positional_and_keyword_arguments(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetPositionalAndKeywordArguments(D_in, H, D_out).to(device) @@ -488,24 +531,28 @@ def test_forward_call_positional_and_keyword_arguments(): prediction = prediction.sum() prediction.backward() -@pytest.mark.parametrize("forward_statement", [ - "model(one)", - "model(x=one)", - "model(one, None, None)", - "model(one, None, z=None)", - "model(one, None)", - "model(x=one, y=one)", - "model(y=one, x=one)", - "model(y=one, z=None, x=one)", - "model(one, None, z=one)", - "model(x=one, z=one)", - "model(one, z=one)", - "model(one, z=one, y=one)", - "model(one, one, one)", - "model(one, None, one)", - "model(z=one, x=one, y=one)", - "model(z=one, x=one, y=None)" -]) + +@pytest.mark.parametrize( + "forward_statement", + [ + "model(one)", + "model(x=one)", + "model(one, None, None)", + "model(one, None, z=None)", + "model(one, None)", + "model(x=one, y=one)", + "model(y=one, x=one)", + "model(y=one, z=None, x=one)", + "model(one, None, z=one)", + "model(x=one, z=one)", + "model(one, z=one)", + "model(one, z=one, y=one)", + "model(one, one, one)", + "model(one, None, one)", + "model(z=one, x=one, y=one)", + "model(z=one, x=one, y=None)", + ], +) def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_statement): one = torch.FloatTensor([1]) @@ -522,9 +569,10 @@ def test_compare_pytorch_forward_call_positional_and_keyword_arguments(forward_s prediction = eval(forward_statement).sum() prediction.backward() + def test_torch_nn_module_cuda_method(): - original_device = 'cpu' - to_device = 'cuda' + original_device = "cpu" + to_device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out) @@ -539,10 +587,11 @@ def test_torch_nn_module_cuda_method(): for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == to_device + @pytest.mark.parametrize("set_gpu_on_original_module", [True, False]) def test_torch_nn_module_cpu_method(set_gpu_on_original_module): - original_device = 'cuda' - to_device = 'cpu' + original_device = "cuda" + to_device = "cpu" N, D_in, H, D_out = 64, 784, 500, 10 if set_gpu_on_original_module: @@ -560,8 +609,9 @@ def test_torch_nn_module_cpu_method(set_gpu_on_original_module): for _, parameter_value in model.named_parameters(): assert parameter_value.device.type == to_device -@pytest.mark.parametrize("original_device", ['cpu', 'cuda']) -@pytest.mark.parametrize("to_argument", ['cpu', 'cuda', 'cuda:0', torch.device('cpu'), torch.device('cuda')]) + +@pytest.mark.parametrize("original_device", ["cpu", "cuda"]) +@pytest.mark.parametrize("to_argument", ["cpu", "cuda", "cuda:0", torch.device("cpu"), torch.device("cuda")]) def test_torch_nn_module_to_api(original_device, to_argument): N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(original_device) @@ -573,8 +623,10 @@ def test_torch_nn_module_to_api(original_device, to_argument): model = model.to(to_argument) x = x.to(to_argument) model(x) - assert _utils.get_device_str(model._torch_module._execution_manager(model._is_training())._device) == \ - _utils.get_device_str(torch.device(to_argument)) + assert _utils.get_device_str( + model._torch_module._execution_manager(model._is_training())._device + ) == _utils.get_device_str(torch.device(to_argument)) + def test_model_without_device(): # Model doesn't have device (CPU is assumed) @@ -583,15 +635,20 @@ def test_model_without_device(): model = ORTModule(model) # User input is on GPU - input_device='cuda' + input_device = "cuda" x = torch.randn(N, D_in).to(input_device) # ORTModule and PyTorch does not move model to where user input is hosted with pytest.raises(RuntimeError) as type_error: model(x) - assert \ - ("Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" in str(type_error.value)) \ - or ("Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!" in str(type_error.value)) + assert ( + "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" + in str(type_error.value) + ) or ( + "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!" + in str(type_error.value) + ) + def test_model_and_input_without_device(): N, D_in, H, D_out = 64, 784, 500, 10 @@ -603,8 +660,9 @@ def test_model_and_input_without_device(): out = model(x) out is not None + def test_model_with_different_devices_same_session(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out) @@ -612,26 +670,28 @@ def test_model_with_different_devices_same_session(): for i in range(5): if i % 2 == 0: - device = 'cpu' + device = "cpu" else: - device = 'cuda' + device = "cuda" model.to(device) x = torch.randn(N, D_in, device=device) y = model(x) - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] -@pytest.mark.parametrize("device", ['cuda', 'cpu']) + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_input_requires_grad_saved(device): N, D_in, H, D_out = 32, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) model = ORTModule(model) x = torch.randn(N, D_in, device=device, requires_grad=True) + 1 model(x) - assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ['input1'] + assert model._torch_module._execution_manager(model._is_training())._input_info.require_grad_names == ["input1"] + -@pytest.mark.parametrize("device", ['cuda', 'cpu']) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_input_requires_grad_backward_creates_input_grad(device): N, D_in, H, D_out = 32, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -643,8 +703,9 @@ def test_input_requires_grad_backward_creates_input_grad(device): s.backward() assert x.grad is not None + def test_gradient_correctness(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 128, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -663,9 +724,9 @@ def run_step(model, x): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) -@pytest.mark.parametrize("device", ['cpu', 'cuda']) -@pytest.mark.parametrize("indices", ([[ 2, 3, -1, -1],[0, 1, -1, -1]], - [[ 2, 3, 4, 4],[ 0, 1, 4, 4]])) + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) +@pytest.mark.parametrize("indices", ([[2, 3, -1, -1], [0, 1, -1, -1]], [[2, 3, 4, 4], [0, 1, 4, 4]])) def test_scatternd_correctness(device, indices): class NeuralNetScatterND(torch.nn.Module): def __init__(self): @@ -682,9 +743,11 @@ def run_step(model, rerouted_output, dispatch_mask, expert_output): prediction = model(rerouted_output, dispatch_mask, expert_output) return prediction - rerouted_output = torch.tensor([[0.],[0.],[0.],[0.],[0.]], device=device) + rerouted_output = torch.tensor([[0.0], [0.0], [0.0], [0.0], [0.0]], device=device) dispatch_mask = torch.tensor(indices, device=device) - expert_output = torch.tensor([[[0.3817],[0.9625],[0.9625],[0.9625]],[[0.3817],[0.9625],[0.9625],[0.9625]]], device=device) + expert_output = torch.tensor( + [[[0.3817], [0.9625], [0.9625], [0.9625]], [[0.3817], [0.9625], [0.9625], [0.9625]]], device=device + ) pt_prediction = run_step(pt_model, rerouted_output, dispatch_mask, expert_output) ort_prediction = run_step(ort_model, rerouted_output, dispatch_mask, expert_output) @@ -709,7 +772,7 @@ def forward(self, input): if torch.cuda.get_device_capability()[0] < 7: return - device = 'cuda' + device = "cuda" N, seq_len, C_in, C_out, kernel_size = 32, 128, 1536, 1536, 3 pt_model = NeuralNetConv1D(C_in, C_out, kernel_size, padding=1).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -736,6 +799,7 @@ def run_step(model, x): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-5) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-2, atol=4e-2) + def _run_gradient_correctness_transpose(perm, shape): class NeuralNetTranspose(torch.nn.Module): def __init__(self, perm): @@ -746,7 +810,7 @@ def forward(self, input): out = torch.sin(input.permute(*self.perm)) return out - device = 'cuda' + device = "cuda" pt_model = NeuralNetTranspose(perm).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -763,115 +827,130 @@ def run_step(model, x): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) -@pytest.mark.parametrize("perm", [ - [0,1,2], # no-op - [0,2,1], # special handle by Transpose021 - [1,0,2], # handled as [0,2,1,3] - [1,2,0], # coalesced to [1,0] - [2,0,1], # coalesced to [1,0] - [2,1,0], # handled as [0,3,2,1] -]) -@pytest.mark.parametrize("shape", [ - [245,1024,32], - [255,2272,32], - [246,2080,32], - [254,128,256], - [260,245,256], - [284,254,256], - [245,260,256], - [1024,1024,256], - [254,284,256], - [4,5,2944], - [4,28,3136], - [4,312,768], - [3,224,224], - [17,5,4], - [8,2080,32], - [8,2272,32], - [2,2,2], - [1024,245,32], - [2080,246,32], - [1024,254,32], - [2272,255,32], - [4,5,736], - [4,111,160], - [8,246,32], - [8,255,32], - [4,1,2], - [1,2,2], - [2,1,2], - [2,2,1], - [2,1,4], - [4,2,1], -]) + +@pytest.mark.parametrize( + "perm", + [ + [0, 1, 2], # no-op + [0, 2, 1], # special handle by Transpose021 + [1, 0, 2], # handled as [0,2,1,3] + [1, 2, 0], # coalesced to [1,0] + [2, 0, 1], # coalesced to [1,0] + [2, 1, 0], # handled as [0,3,2,1] + ], +) +@pytest.mark.parametrize( + "shape", + [ + [245, 1024, 32], + [255, 2272, 32], + [246, 2080, 32], + [254, 128, 256], + [260, 245, 256], + [284, 254, 256], + [245, 260, 256], + [1024, 1024, 256], + [254, 284, 256], + [4, 5, 2944], + [4, 28, 3136], + [4, 312, 768], + [3, 224, 224], + [17, 5, 4], + [8, 2080, 32], + [8, 2272, 32], + [2, 2, 2], + [1024, 245, 32], + [2080, 246, 32], + [1024, 254, 32], + [2272, 255, 32], + [4, 5, 736], + [4, 111, 160], + [8, 246, 32], + [8, 255, 32], + [4, 1, 2], + [1, 2, 2], + [2, 1, 2], + [2, 2, 1], + [2, 1, 4], + [4, 2, 1], + ], +) def test_gradient_correctness_transpose3d(perm, shape): _run_gradient_correctness_transpose(perm, shape) -@pytest.mark.parametrize("perm", [ - [0,1,2,3], - [0,1,3,2], - [0,2,1,3], - [0,2,3,1], - [0,3,1,2], - [0,3,2,1], - [1,0,2,3], - [1,0,3,2], - [1,2,0,3], - [1,2,3,0], - [1,3,0,2], - [1,3,2,0], - [2,0,1,3], - [2,0,3,1], - [2,1,0,3], - [2,1,3,0], - [2,3,0,1], - [2,3,1,0], - [3,0,1,2], - [3,0,2,1], - [3,1,0,2], - [3,1,2,0], - [3,2,0,1], - [3,2,1,0], -]) -@pytest.mark.parametrize("shape", [ - [1,245,1024,32], - [1,255,2272,32], - [1,246,2080,32], - [1,254,128,256], - [1,260,245,256], - [1,284,254,256], - [1,245,260,256], - [1,1024,1024,256], - [1,254,284,256], - [1,4,5,2944], - [1,4,28,3136], - [1,4,312,768], - [1,3,224,224], - [1,17,5,4], - [260,8,2080,32], - [284,8,2272,32], - [1,2,2,2], - [1,1024,245,32], - [1,2080,246,32], - [1,1024,254,32], - [1,2272,255,32], - [1,4,5,736], - [1,4,111,160], - [260,8,246,32], - [284,8,255,32], - [4,1,2,1], - [1,1,2,2], - [1,2,1,2], - [1,2,2,1], - [2,1,4,1], - [2,2,2,1], - [2,1,2,1], - [1,4,2,1], -]) + +@pytest.mark.parametrize( + "perm", + [ + [0, 1, 2, 3], + [0, 1, 3, 2], + [0, 2, 1, 3], + [0, 2, 3, 1], + [0, 3, 1, 2], + [0, 3, 2, 1], + [1, 0, 2, 3], + [1, 0, 3, 2], + [1, 2, 0, 3], + [1, 2, 3, 0], + [1, 3, 0, 2], + [1, 3, 2, 0], + [2, 0, 1, 3], + [2, 0, 3, 1], + [2, 1, 0, 3], + [2, 1, 3, 0], + [2, 3, 0, 1], + [2, 3, 1, 0], + [3, 0, 1, 2], + [3, 0, 2, 1], + [3, 1, 0, 2], + [3, 1, 2, 0], + [3, 2, 0, 1], + [3, 2, 1, 0], + ], +) +@pytest.mark.parametrize( + "shape", + [ + [1, 245, 1024, 32], + [1, 255, 2272, 32], + [1, 246, 2080, 32], + [1, 254, 128, 256], + [1, 260, 245, 256], + [1, 284, 254, 256], + [1, 245, 260, 256], + [1, 1024, 1024, 256], + [1, 254, 284, 256], + [1, 4, 5, 2944], + [1, 4, 28, 3136], + [1, 4, 312, 768], + [1, 3, 224, 224], + [1, 17, 5, 4], + [260, 8, 2080, 32], + [284, 8, 2272, 32], + [1, 2, 2, 2], + [1, 1024, 245, 32], + [1, 2080, 246, 32], + [1, 1024, 254, 32], + [1, 2272, 255, 32], + [1, 4, 5, 736], + [1, 4, 111, 160], + [260, 8, 246, 32], + [284, 8, 255, 32], + [4, 1, 2, 1], + [1, 1, 2, 2], + [1, 2, 1, 2], + [1, 2, 2, 1], + [2, 1, 4, 1], + [2, 2, 2, 1], + [2, 1, 2, 1], + [1, 4, 2, 1], + ], +) def test_gradient_correctness_transpose4d(perm, shape): _run_gradient_correctness_transpose(perm, shape) -@pytest.mark.parametrize("device", ['cuda', 'cpu']) + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) @pytest.mark.parametrize("padding_idx", [None, 1]) def test_gradient_correctness_embedding(device, padding_idx): class NeuralNetEmbedding(torch.nn.Module): @@ -901,6 +980,7 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-5) + @pytest.mark.parametrize("use_fp16", [False, True]) def test_gradient_correctness_cross_entropy_loss(use_fp16): class NeuralNetCrossEntropyLoss(torch.nn.Module): @@ -914,7 +994,7 @@ def forward(self, input, positions): loss_fct = torch.nn.CrossEntropyLoss(ignore_index=ignored_index) return loss_fct(output, positions) - device = 'cuda' + device = "cuda" num_embeddings, embedding_dim = 32, 128 pt_model = NeuralNetCrossEntropyLoss(num_embeddings, embedding_dim).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -935,15 +1015,16 @@ def run_step(model, input, positions): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-5) -@pytest.mark.parametrize("pool_type", ['MaxPool', 'AvgPool', 'AdaptiveAvgPool']) + +@pytest.mark.parametrize("pool_type", ["MaxPool", "AvgPool", "AdaptiveAvgPool"]) def test_gradient_correctness_pool2d(pool_type): class NeuralNetPool2d(torch.nn.Module): def __init__(self): super(NeuralNetPool2d, self).__init__() self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) - if pool_type == 'MaxPool': + if pool_type == "MaxPool": self.pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - elif pool_type == 'AvgPool': + elif pool_type == "AvgPool": self.pool = torch.nn.AvgPool2d(kernel_size=3, stride=2, padding=1) else: self.pool = torch.nn.AdaptiveAvgPool2d((5, 7)) @@ -952,7 +1033,7 @@ def forward(self, input): return self.pool(self.conv(input)) N, C, H, W = 8, 3, 224, 224 - device = 'cuda' + device = "cuda" pt_model = NeuralNetPool2d().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -970,7 +1051,8 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, rtol=5e-3, atol=4e-3) -@pytest.mark.parametrize("pool_type", ['MaxPool', 'AvgPool']) + +@pytest.mark.parametrize("pool_type", ["MaxPool", "AvgPool"]) @pytest.mark.parametrize("stride", [None, 2]) def test_export_correctness_pool2d(pool_type, stride): class NeuralNetPool2d(torch.nn.Module): @@ -981,14 +1063,14 @@ def __init__(self): def forward(self, input): x = self.conv(input) - if pool_type == 'MaxPool': + if pool_type == "MaxPool": output = torch.nn.functional.max_pool2d(x, kernel_size=3, stride=stride) - elif pool_type == 'AvgPool': + elif pool_type == "AvgPool": output = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=stride) return output N, C, H, W = 8, 3, 224, 224 - device = 'cuda' + device = "cuda" pt_model = NeuralNetPool2d().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1003,21 +1085,23 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + def test_gradient_correctness_argmax_unfold(): class NeuralNetUnfold(torch.nn.Module): def __init__(self, input_size, hidden_size, unfold_dim, unfold_size, unfold_step): super(NeuralNetUnfold, self).__init__() - self.linear= torch.nn.Linear(input_size, hidden_size) + self.linear = torch.nn.Linear(input_size, hidden_size) self.unfold_dim = unfold_dim self.unfold_size = unfold_size self.unfold_step = unfold_step def forward(self, input): return self.linear(input.argmax(-1).to(torch.float) * input.argmax().to(torch.float)).unfold( - dimension=self.unfold_dim, size=self.unfold_size, step=self.unfold_step) + dimension=self.unfold_dim, size=self.unfold_size, step=self.unfold_step + ) N, D, H = 16, 256, 128 - device = 'cuda' + device = "cuda" pt_model = NeuralNetUnfold(D, H, 1, 50, 30).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1035,10 +1119,11 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + @pytest.mark.parametrize("high", [1, 2, 10]) def test_correctness_argmax_bitwise_or(high): N, D, H, M = 16, 256, 128, 4 - device = 'cuda' + device = "cuda" class NeuralNetBitwiseOr(torch.nn.Module): def __init__(self, high): @@ -1064,6 +1149,7 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + @pytest.mark.parametrize("offset", [-1, 0, 1]) @pytest.mark.parametrize("dim1, dim2", ([0, 1], [0, 2], [1, 2], [2, 0])) def test_gradient_correctness_argmax_diagonal(offset, dim1, dim2): @@ -1078,7 +1164,7 @@ def forward(self, input): return torch.diagonal(input, self.offset, self.dim1, self.dim2) N, D, H = 16, 256, 128 - device = 'cuda' + device = "cuda" pt_model = NeuralNetDiagonal(offset, dim1, dim2).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1097,6 +1183,7 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + @pytest.mark.parametrize("dim", [None, 0, 1, (0, 1), (-1, 0), (0, 1, 2)]) @pytest.mark.parametrize("keepdim", [True, False]) def test_gradient_correctness_reducesum(dim, keepdim): @@ -1115,7 +1202,7 @@ def forward(self, input): return torch.sum(t, self.dim, keepdim=self.keepdim) N, D, H, W = 16, 256, 128, 64 - device = 'cuda' + device = "cuda" pt_model = NeuralNetReduceSum(H, W, dim, keepdim).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1134,36 +1221,49 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + # In PyTorch 1.11.0, there is issue during reduce node shape handling for exporter, so any sub-graph that # contains ReduceProd will fail to run, for example, "sec,sm->ecm", "sec,ecm->sm". # Currently skip these cases and test_gradient_correctness_einsum_2, # will enable these tests again once the issue in PyTorch is fixed. -skip_torch_1_11 = pytest.mark.skipif(LooseVersion(torch.__version__) >= LooseVersion('1.11.0'), reason="PyTorch 1.11 incompatible") -@pytest.mark.parametrize("equation", [ - "s,se->se", "se,sc->sec", "se,se->s", "ks,ksm->sm", "kes,ems->mek", "kes,ksm->ms", - pytest.param("sec,sm->ecm", marks=[skip_torch_1_11]), - pytest.param("sec,ecm->sm", marks=[skip_torch_1_11]) -]) +skip_torch_1_11 = pytest.mark.skipif( + LooseVersion(torch.__version__) >= LooseVersion("1.11.0"), reason="PyTorch 1.11 incompatible" +) + + +@pytest.mark.parametrize( + "equation", + [ + "s,se->se", + "se,sc->sec", + "se,se->s", + "ks,ksm->sm", + "kes,ems->mek", + "kes,ksm->ms", + pytest.param("sec,sm->ecm", marks=[skip_torch_1_11]), + pytest.param("sec,ecm->sm", marks=[skip_torch_1_11]), + ], +) def test_gradient_correctness_einsum(equation): class NeuralNetEinsum(torch.nn.Module): def __init__(self, bias_size): super(NeuralNetEinsum, self).__init__() - self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(bias_size))) + self.register_parameter(name="bias", param=torch.nn.Parameter(torch.randn(bias_size))) def forward(self, left, right): left = left + self.bias return torch.einsum(equation, left, right) - device = 'cuda' + device = "cuda" K, S, M, E = 16, 1024, 768, 64 - C = int(S/E*2) + C = int(S / E * 2) - SIZE_MAP = { 'K': K, 'S': S, 'E': E, 'C': C, 'M': M } + SIZE_MAP = {"K": K, "S": S, "E": E, "C": C, "M": M} - pos1 = equation.find(',') - pos2 = equation.find('->') + pos1 = equation.find(",") + pos2 = equation.find("->") lhs_op = equation[0:pos1] - rhs_op = equation[pos1 + 1:pos2] + rhs_op = equation[pos1 + 1 : pos2] lhs_shape = [] for c in lhs_op: lhs_shape.append(SIZE_MAP[c.upper()]) @@ -1191,31 +1291,36 @@ def run_step(model, input_left, input_right): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-3, rtol=1e-3) + @skip_torch_1_11 def test_gradient_correctness_einsum_2(): class NeuralNetEinsum(torch.nn.Module): def __init__(self, bias_size): super(NeuralNetEinsum, self).__init__() - self.register_parameter(name='bias', param=torch.nn.Parameter(torch.randn(bias_size))) + self.register_parameter(name="bias", param=torch.nn.Parameter(torch.randn(bias_size))) def forward(self, left, right): left = left + self.bias return torch.einsum(equation, left, right) - device = 'cuda' + device = "cuda" A, B, C, D = 16, 32, 8, 64 - SIZE_MAP = { 'A': A, 'B': B, 'C': C, 'D': D } + SIZE_MAP = {"A": A, "B": B, "C": C, "D": D} def to_string(perm): - result = '' + result = "" for v in perm: - result += chr(ord('a') + v) + result += chr(ord("a") + v) return result - lhs_candidates = [[0], [0,1], [0,1,2]] - perm = [0,1,2,3] - combs = list(itertools.combinations(perm, 1)) + list(itertools.combinations(perm, 2)) + list(itertools.combinations(perm, 3)) + lhs_candidates = [[0], [0, 1], [0, 1, 2]] + perm = [0, 1, 2, 3] + combs = ( + list(itertools.combinations(perm, 1)) + + list(itertools.combinations(perm, 2)) + + list(itertools.combinations(perm, 3)) + ) rhs_candidates = [] for comb in combs: rhs_candidates += list(itertools.permutations(comb)) @@ -1244,11 +1349,11 @@ def to_string(perm): all_cases.append((lhs_candidate, rhs_candidate, output_candidate)) for case in all_cases: - equation = to_string(case[0]) + ',' + to_string(case[1]) + '->' + to_string(case[2]) - pos1 = equation.find(',') - pos2 = equation.find('->') + equation = to_string(case[0]) + "," + to_string(case[1]) + "->" + to_string(case[2]) + pos1 = equation.find(",") + pos2 = equation.find("->") lhs_op = equation[0:pos1] - rhs_op = equation[pos1 + 1:pos2] + rhs_op = equation[pos1 + 1 : pos2] lhs_shape = [] for c in lhs_op: lhs_shape.append(SIZE_MAP[c.upper()]) @@ -1276,11 +1381,12 @@ def run_step(model, input_left, input_right): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction, atol=1e-3, rtol=1e-3) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-3, rtol=1e-3) + # Since multinomial is a generator function, we do not have to test for gradient # Two consecutive calls on the torch.multinomail on a probability distribution with more # than one index with non-zero probability(eg, [0, 10, 3, 0]) will not result in # the same output. Thus we reset the seed before each call to the op torch.multinomial. -@pytest.mark.parametrize("input_shape", ([5], [2,5])) +@pytest.mark.parametrize("input_shape", ([5], [2, 5])) @pytest.mark.parametrize("num_samples, replacement", ((1, False), (2, True))) def test_aten_multinomial(input_shape, num_samples, replacement): class NeuralNetDiagonal(torch.nn.Module): @@ -1293,7 +1399,7 @@ def forward(self, input): return torch.multinomial(input, self.num_samples, self.replacement) torch.backends.cudnn.deterministic = True - device = 'cuda' + device = "cuda" pt_model = NeuralNetDiagonal(num_samples, replacement).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1314,18 +1420,20 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) -@pytest.mark.parametrize("input_shape", ([4,2],)) + +@pytest.mark.parametrize("input_shape", ([4, 2],)) def test_aten_argmax(input_shape): import torch.nn.functional as F + class TopKGate(torch.nn.Module): def forward(self, input: torch.Tensor): - indices = torch.argmax(input, dim = 1) - device = 'cpu' if indices.get_device() < 0 else indices.get_device() + indices = torch.argmax(input, dim=1) + device = "cpu" if indices.get_device() < 0 else indices.get_device() ret = torch.zeros(indices.shape[0], 2, dtype=torch.int64, device=device) ret = ret.scatter(-1, indices.unsqueeze(-1), 1) + input return ret - device = 'cuda' + device = "cuda" pt_model = TopKGate() ort_model = ORTModule(copy.deepcopy(pt_model)) pt_input = torch.rand(input_shape, dtype=torch.float, device=device, requires_grad=True) @@ -1337,7 +1445,8 @@ def forward(self, input: torch.Tensor): _test_helpers.assert_values_are_close(ort_output, pt_output) -@pytest.mark.parametrize("input_shape", ([], [5], [2,5], [3,2,5])) + +@pytest.mark.parametrize("input_shape", ([], [5], [2, 5], [3, 2, 5])) def test_numpy_T(input_shape): class NeuralNet(torch.nn.Module): def __init__(self): @@ -1347,7 +1456,7 @@ def forward(self, input): return input.T torch.backends.cudnn.deterministic = True - device = 'cuda' + device = "cuda" pt_model = NeuralNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(log_level=LogLevel.VERBOSE)) @@ -1364,6 +1473,7 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + def test_gradient_correctness_bce_with_logits(): class NeuralNetBCEWithLogitsLoss(torch.nn.Module): def __init__(self, input_size, hidden_size): @@ -1375,7 +1485,7 @@ def forward(self, input, target): return loss_fct(self.linear(input), target) N, D, H = 16, 256, 128 - device = 'cuda' + device = "cuda" pt_model = NeuralNetBCEWithLogitsLoss(D, H).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1396,18 +1506,20 @@ def run_step(model, input, target): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + def test_gradient_correctness_cast_chain(): class NeuralNetCast(torch.nn.Module): def __init__(self, D): super(NeuralNetCast, self).__init__() self.a = torch.nn.parameter.Parameter(torch.rand(D)) + def forward(self, b): mask = self.a.bool().float() output = self.a + b + mask return output - D=16 - device = 'cuda' + D = 16 + device = "cuda" pt_model = NeuralNetCast(D).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1427,8 +1539,9 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) _test_helpers.assert_values_are_close(ort_model.a.grad, pt_model.a.grad) + def test_module_with_non_differential_output(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 128, 64, 10 pt_model = NeuralNetNonDifferentiableOutput(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1445,15 +1558,16 @@ def run_step(model, x): ort_prediction1, ort_mask1, ort_prediction2, ort_mask2 = run_step(ort_model, x) # _test_helpers.assert_values_are_close(ort_prediction1, pt_prediction1) # TODO: this is failing, need to investigate! - # This will be no reproducible if we change the model forward to - # mask1 = torch.gt(out, 0.01) + # This will be no reproducible if we change the model forward to + # mask1 = torch.gt(out, 0.01) _test_helpers.assert_values_are_close(ort_prediction2, pt_prediction2) _test_helpers.assert_values_are_close(ort_mask1, pt_mask1) _test_helpers.assert_values_are_close(ort_mask2, pt_mask2) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + def test_multiple_chained_ortmodules_with_non_differential_output(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 128, 64, 10 pt_model = NeuralNetChainedLayersWithNonDifferentiableOutput(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1480,6 +1594,7 @@ def run_step(layer1, layer2, x, mask1): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) + @pytest.mark.parametrize("loss_with_duplicated_output", [False, True]) def test_duplicated_output(loss_with_duplicated_output): class NeuralNet(torch.nn.Module): @@ -1489,10 +1604,10 @@ def __init__(self): def forward(self, input): out = self.fc1(input) - return out, out # duplicated output + return out, out # duplicated output N, C, H = 8, 4, 128 - device = 'cuda' + device = "cuda" pt_model = NeuralNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1513,8 +1628,9 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_prediction2, pt_prediction2) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-5) + def test_multiple_forward_only_calls(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1526,8 +1642,9 @@ def test_multiple_forward_only_calls(): _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + def test_nesting_forward_backward_calls(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1559,8 +1676,9 @@ def test_nesting_forward_backward_calls(): _test_helpers.assert_values_are_close(ort_x1.grad, pt_x1.grad) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + def test_multiple_overlapping_forward_backward_calls(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -1594,8 +1712,9 @@ def run_step(model, x1, x2): _test_helpers.assert_values_are_close(ort_x2.grad, pt_x2.grad) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + def test_multiple_ortmodules_training(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 128, 10 pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -1623,8 +1742,9 @@ def run_step(model1, model2, x1, x2): _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) + def test_multiple_ortmodules_common_backbone_training(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 64, 128, 64 pt_model0 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -1659,8 +1779,9 @@ def run_step(backbone_layers, task_layers, x): _test_helpers.assert_gradients_match_and_reset_gradient(ort_model0, pt_model0, reset_gradient=True, atol=1e-5) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) + def test_multiple_chained_ortmodules_training(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 128, 500, 128 pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -1682,14 +1803,15 @@ def run_step(layers1, layers2, x): _test_helpers.assert_gradients_match_and_reset_gradient(ort_model1, pt_model1) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) + def test_mixed_nnmodule_ortmodules_training(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 128, 500, 128 pt_model1 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) pt_model2 = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) pt_model3 = NeuralNetMultiplePositionalArguments(D_in, H, D_out).to(device) ort_model1 = ORTModule(copy.deepcopy(pt_model1)) - ort_model2 = copy.deepcopy(pt_model2) # model2 is intentionally left as nn.module + ort_model2 = copy.deepcopy(pt_model2) # model2 is intentionally left as nn.module ort_model3 = ORTModule(copy.deepcopy(pt_model3)) def run_step(model1, model2, model3, x1, x2): @@ -1713,6 +1835,7 @@ def run_step(model1, model2, model3, x1, x2): _test_helpers.assert_gradients_match_and_reset_gradient(ort_model2, pt_model2) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model3, pt_model3) + def test_identity_elimination(): class NeuralNetSimpleIdentity(torch.nn.Module): def __init__(self, input_size, num_classes): @@ -1727,7 +1850,7 @@ def forward(self, x): z = y return z - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSimpleIdentity(D_in, D_out).to(device) model = ORTModule(model) @@ -1737,10 +1860,11 @@ def forward(self, x): # Make sure model runs OK assert output is not None + def test_ortmodule_inputs_with_dynamic_shape(): D_in, H, D_out = 784, 500, 10 - pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to('cuda') + pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to("cuda") ort_model = ORTModule(copy.deepcopy(pt_model)) def run_step(model, x): @@ -1750,23 +1874,23 @@ def run_step(model, x): return p for step in range(10): - N = random.randint(1,100) - x = torch.randn(N, D_in, device='cuda', requires_grad=True) + N = random.randint(1, 100) + x = torch.randn(N, D_in, device="cuda", requires_grad=True) assert x.grad is None pt_p = run_step(pt_model, x) ort_p = run_step(ort_model, x) - _test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-6) # relaxing tolerance, 1e-7 or less would fail + _test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-6) # relaxing tolerance, 1e-7 or less would fail _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) def test_bert_inputs_with_dynamic_shape(): # create pytorch model with dropout disabled - pt_model = _get_bert_for_sequence_classification_model('cuda', - hidden_dropout_prob=0.0, - attention_probs_dropout_prob=0.0) + pt_model = _get_bert_for_sequence_classification_model( + "cuda", hidden_dropout_prob=0.0, attention_probs_dropout_prob=0.0 + ) ort_model = ORTModule(copy.deepcopy(pt_model)) def run_step(model, x, y, z): @@ -1776,18 +1900,20 @@ def run_step(model, x, y, z): return outputs[0] for step in range(10): - x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes('cuda') + x, y, z = _get_bert_for_sequence_classification_sample_data_with_random_shapes("cuda") pt_p = run_step(pt_model, x, y, z) ort_p = run_step(ort_model, x, y, z) - _test_helpers.assert_values_are_close(ort_p, pt_p, atol=1e-02) # TODO: this assert is failing with smaller tolerance, need to investigate!! + _test_helpers.assert_values_are_close( + ort_p, pt_p, atol=1e-02 + ) # TODO: this assert is failing with smaller tolerance, need to investigate!! # _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) #TODO - enable this check after the investigation -@pytest.mark.parametrize("device", ['cuda', 'cpu']) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder(device): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" N, D_in, H, D_out = 32, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -1798,17 +1924,19 @@ def test_changes_input_requires_grad_reinitializes_module_gradient_graph_builder output_x = torch.sum(model(x)) output_x.backward() assert x.grad is None - module_gradient_graph_builder_training = \ - model._torch_module._execution_manager(model._is_training())._graph_builder + module_gradient_graph_builder_training = model._torch_module._execution_manager(model._is_training())._graph_builder output_y = torch.sum(model(y)) output_y.backward() assert y.grad is not None - assert module_gradient_graph_builder_training != \ - model._torch_module._execution_manager(model._is_training())._graph_builder + assert ( + module_gradient_graph_builder_training + != model._torch_module._execution_manager(model._is_training())._graph_builder + ) - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] -@pytest.mark.parametrize("device", ['cuda']) + +@pytest.mark.parametrize("device", ["cuda"]) def test_input_requires_grad_backward_creates_input_grad_as_required0(device): N, D_in, H, D_out = 32, 784, 500, 10 pt_model = NeuralNetMultiplePositionalArgumentsMultiOutputsWithoutDependency(D_in, H, D_out).to(device) @@ -1823,7 +1951,7 @@ def test_input_requires_grad_backward_creates_input_grad_as_required0(device): def run_step0(model, x1, x2): y1, _ = model(x1, x2) s1 = y1.sum() - s1.backward() # y2's gradient will be materialized to full shape. + s1.backward() # y2's gradient will be materialized to full shape. return y1 pt_y1 = run_step0(pt_model, pt_x1, pt_x2) @@ -1833,12 +1961,14 @@ def run_step0(model, x1, x2): _test_helpers.assert_values_are_close(ort_x1.grad, pt_x1.grad) _test_helpers.assert_values_are_close(ort_x2.grad, pt_x2.grad) # backward() is from y1, so grad of fc2.weight and fc2.bias will not be calculated. - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=['fc2.weight', 'fc2.bias'], reset_gradient=True) + _test_helpers.assert_gradients_match_and_reset_gradient( + ort_model, pt_model, none_pt_params=["fc2.weight", "fc2.bias"], reset_gradient=True + ) def run_step1(model, x1, x2): _, y2 = model(x1, x2) s2 = y2.sum() - s2.backward() # y1's gradient will be materialized to full shape. + s2.backward() # y1's gradient will be materialized to full shape. return y2 pt_y2 = run_step1(pt_model, pt_x1, pt_x2) @@ -1848,10 +1978,12 @@ def run_step1(model, x1, x2): _test_helpers.assert_values_are_close(ort_x1.grad, pt_x1.grad) _test_helpers.assert_values_are_close(ort_x2.grad, pt_x2.grad) # backward() is from y2, so grad of fc1.weight and fc1.bias will not be calculated. - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, none_pt_params=['fc1.weight', 'fc1.bias']) + _test_helpers.assert_gradients_match_and_reset_gradient( + ort_model, pt_model, none_pt_params=["fc1.weight", "fc1.bias"] + ) -@pytest.mark.parametrize("device", ['cuda']) +@pytest.mark.parametrize("device", ["cuda"]) def test_model_output_with_inplace_update(device): class NeuralNetWithGradNeedOutput(torch.nn.Module): def __init__(self, input_size, hidden_size): @@ -1886,9 +2018,9 @@ def run_step(model, x1): ort_y1 = run_step(ort_model, ort_x1) assert "modified by an inplace operation" in str(ex_info.value) -@pytest.mark.parametrize("device", ['cuda']) -def test_loss_combines_two_outputs_with_dependency(device): +@pytest.mark.parametrize("device", ["cuda"]) +def test_loss_combines_two_outputs_with_dependency(device): def run_step(model, x1, x2): y1, y2 = model(x1, x2) loss = y1.sum() + y2.sum() @@ -1912,10 +2044,10 @@ def run_step(model, x1, x2): _test_helpers.assert_values_are_close(pt_y2, ort_y2, atol=1e-06) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + @pytest.mark.parametrize("x1_requires_grad", [True, False]) @pytest.mark.parametrize("x2_requires_grad", [True, False]) def test_input_requires_grad_backward_creates_input_grad_as_required1(x1_requires_grad, x2_requires_grad): - def run_step(model, x1, x2): y1, y2 = model(x1, x2) s = y2.sum() @@ -1923,7 +2055,7 @@ def run_step(model, x1, x2): return y1, y2 N, D_in, H, D_out = 32, 784, 500, 10 - device = 'cuda' + device = "cuda" pt_model = NeuralNetMultiplePositionalArgumentsMultiOutputsWithDependency(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) pt_x1 = torch.randn(N, D_in, device=device, requires_grad=x1_requires_grad) @@ -1946,7 +2078,7 @@ def run_step(model, x1, x2): _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) -@pytest.mark.parametrize("device", ['cuda']) +@pytest.mark.parametrize("device", ["cuda"]) def test_model_with_bypass_input(device): class NeuralNetWithBypassInput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -1985,8 +2117,9 @@ def run_step(model, x1, x2): _test_helpers.assert_values_are_close(pt_y2, ort_y2, atol=1e-06) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) + def test_gpu_reserved_memory_with_torch_no_grad(): - device = 'cuda' + device = "cuda" # Create a model and get the memory_reserved when torch.no_grad has been enabled # before and after export @@ -2004,14 +2137,15 @@ def test_gpu_reserved_memory_with_torch_no_grad(): model_without_no_grad = ORTModule(model_without_no_grad) mem_reserved_after_export_without_torch_no_grad = 0 - with patch('torch.no_grad'): + with patch("torch.no_grad"): model_without_no_grad(x, attention_mask=y, labels=z) mem_reserved_after_export_without_torch_no_grad = torch.cuda.memory_reserved(device) assert mem_reserved_after_export_with_torch_no_grad <= mem_reserved_after_export_without_torch_no_grad + @pytest.mark.parametrize("return_type", [dict, OrderedDict, SequenceClassifierOutput]) -@pytest.mark.parametrize("device", ['cpu', 'cuda']) +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_dict_return_value_module(return_type, device): class NeuralNetDictOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2033,7 +2167,7 @@ def forward(self, input1, input2, input3): out1 = self.fc1_2(self.relu1(self.fc1_1(input1))) out2 = self.fc2_2(self.relu2(self.fc2_1(input2))) out3 = self.fc3_2(self.relu3(self.fc3_1(input3))) - return return_type([('loss', out1), ('logits', out2), ('hidden_states', out3)]) + return return_type([("loss", out1), ("logits", out2), ("hidden_states", out3)]) N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetDictOutput(D_in, H, D_out).to(device) @@ -2044,9 +2178,10 @@ def forward(self, input1, input2, input3): output = model(x, y, z) assert isinstance(output, return_type) - assert 'loss' in output and 'logits' in output and 'hidden_states' in output + assert "loss" in output and "logits" in output and "hidden_states" in output -@pytest.mark.parametrize("device", ['cuda', 'cpu']) + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_dict_of_tuple_return_value_module(device): class NeuralNetDictOfTuplesOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2068,7 +2203,7 @@ def forward(self, input1, input2, input3): out1 = self.fc1_2(self.relu1(self.fc1_1(input1))) out2 = self.fc2_2(self.relu2(self.fc2_1(input2))) out3 = self.fc3_2(self.relu3(self.fc3_1(input3))) - return {'loss': (out1, out2, out3)} + return {"loss": (out1, out2, out3)} N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetDictOfTuplesOutput(D_in, H, D_out).to(device) @@ -2078,10 +2213,11 @@ def forward(self, input1, input2, input3): z = torch.randn(N, D_in, device=device) output = model(x, y, z) - assert 'loss' in output - assert len(output['loss']) == 3 + assert "loss" in output + assert len(output["loss"]) == 3 + -@pytest.mark.parametrize("device", ['cuda', 'cpu']) +@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_tuple_of_tuple_return_value_module(device): class NeuralNetTupleOfTuplesOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2118,9 +2254,11 @@ def forward(self, input1, input2, input3): assert len(output[0]) == 2 assert isinstance(output[1], torch.Tensor) -@pytest.mark.parametrize("device", ['cpu', 'cuda']) + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_named_tuple_return_value_module(device): - ReturnValue = namedtuple('NamedTupleReturnValue', 'loss logits hidden_states') + ReturnValue = namedtuple("NamedTupleReturnValue", "loss logits hidden_states") + class NeuralNetNamedTupleOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetNamedTupleOutput, self).__init__() @@ -2155,9 +2293,10 @@ def forward(self, input1, input2, input3): assert isinstance(output, tuple) assert isinstance(output, ReturnValue) -@pytest.mark.parametrize("device", ['cpu', 'cuda']) + +@pytest.mark.parametrize("device", ["cpu", "cuda"]) def test_exception_raised_for_custom_class_return_value_module(device): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device) @@ -2167,6 +2306,7 @@ def test_exception_raised_for_custom_class_return_value_module(device): z = torch.randn(N, D_in, device=device) from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DATA): # Fallback pt_out = pt_model(x, y, z) @@ -2179,12 +2319,13 @@ def test_exception_raised_for_custom_class_return_value_module(device): # ORT backend with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert 'ORTModule does not support the following model output type' in str(runtime_error.value) + assert "ORTModule does not support the following model output type" in str(runtime_error.value) + + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] def test_dynamic_axes_config(): - device = 'cuda' + device = "cuda" # Model 1 N, D_in, H, D_out = 64, 784, 500, 10 @@ -2204,6 +2345,7 @@ def test_dynamic_axes_config(): assert output is not None assert _test_helpers.is_dynamic_axes(model_with_no_grad) + def test_model_with_multiple_devices_cpu_cuda(): class MultipleDeviceModel(torch.nn.Module): def __init__(self): @@ -2220,24 +2362,28 @@ def forward(self, x): x = torch.randn(20, 10) from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): # Fallback ort_model = ORTModule(copy.deepcopy(pt_model)) with pytest.raises(RuntimeError) as runtime_error: ort_model(x) - assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str( + runtime_error.value + ) else: # ORT backend with pytest.raises(_fallback.ORTModuleFallbackException) as e: ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device per model' + assert str(e.value) == "ORTModule supports a single device per model" + def test_model_with_multiple_devices_to_to(): class MultipleDeviceModel(torch.nn.Module): def __init__(self): super().__init__() - self.fc1 = torch.nn.Linear(10, 10).to('cpu') - self.fc2 = torch.nn.Linear(10, 10).to('cuda') + self.fc1 = torch.nn.Linear(10, 10).to("cpu") + self.fc2 = torch.nn.Linear(10, 10).to("cuda") def forward(self, x): x = self.fc1(x) @@ -2247,23 +2393,27 @@ def forward(self, x): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): # Fallback with pytest.raises(RuntimeError) as runtime_error: ort_model = ORTModule(copy.deepcopy(pt_model)) ort_model(x) - assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str( + runtime_error.value + ) else: # ORT backend with pytest.raises(_fallback.ORTModuleFallbackException) as e: ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device per model' + assert str(e.value) == "ORTModule supports a single device per model" + def test_model_with_multiple_devices_to_cpu(): class MultipleDeviceModel(torch.nn.Module): def __init__(self): super().__init__() - self.fc1 = torch.nn.Linear(10, 10).to('cuda') + self.fc1 = torch.nn.Linear(10, 10).to("cuda") self.fc2 = torch.nn.Linear(10, 10).cpu() def forward(self, x): @@ -2274,23 +2424,27 @@ def forward(self, x): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): # Fallback ort_model = ORTModule(copy.deepcopy(pt_model)) with pytest.raises(RuntimeError) as runtime_error: ort_model(x) - assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str( + runtime_error.value + ) else: # ORT backend with pytest.raises(_fallback.ORTModuleFallbackException) as e: ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device per model' + assert str(e.value) == "ORTModule supports a single device per model" + def test_model_with_multiple_devices_to_cuda(): class MultipleDeviceModel(torch.nn.Module): def __init__(self): super().__init__() - self.fc1 = torch.nn.Linear(10, 10).to('cpu') + self.fc1 = torch.nn.Linear(10, 10).to("cpu") self.fc2 = torch.nn.Linear(10, 10).cuda() def forward(self, x): @@ -2301,25 +2455,29 @@ def forward(self, x): pt_model = MultipleDeviceModel() x = torch.randn(20, 10) from onnxruntime.training.ortmodule._fallback import _FallbackPolicy + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): # Fallback ort_model = ORTModule(copy.deepcopy(pt_model)) with pytest.raises(RuntimeError) as runtime_error: ort_model(x) - assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str( + runtime_error.value + ) else: # ORT backend with pytest.raises(_fallback.ORTModuleFallbackException) as e: ort_model = ORTModule(pt_model) - assert str(e.value) == 'ORTModule supports a single device per model' + assert str(e.value) == "ORTModule supports a single device per model" + -@pytest.mark.parametrize("device", ['cuda', 'cuda:0', 'cuda:1', 'cuda:2']) +@pytest.mark.parametrize("device", ["cuda", "cuda:0", "cuda:1", "cuda:2"]) def test_model_with_different_cuda_devices(device): # Trick to run this test in single GPU machines device_id = _utils.get_device_index(device) if device_id >= torch.cuda.device_count(): - warnings.warn('Skipping test_model_with_different_cuda_devices(cuda:{})'.format(device_id)) + warnings.warn("Skipping test_model_with_different_cuda_devices(cuda:{})".format(device_id)) return N, D_in, H, D_out = 64, 784, 500, 10 @@ -2329,6 +2487,7 @@ def test_model_with_different_cuda_devices(device): x = torch.randn(N, D_in, device=device) model(x) + def test_register_custom_ops_pytorch_exporter_tensor_triu(): class SimpleTensorTriuModel(torch.nn.Module): def __init__(self): @@ -2346,7 +2505,8 @@ def forward(self, x): user_input = torch.ones(1, 10, 10) output = model(user_input) - assert list(output.shape) == [1, 10, 10] + assert list(output.shape) == [1, 10, 10] + def test_register_custom_ops_pytorch_exporter_torch_triu(): class SimpleTorchTriuModel(torch.nn.Module): @@ -2365,17 +2525,15 @@ def forward(self, x): user_input = torch.ones(1, 10, 10) output = model(user_input) - assert list(output.shape) == [1, 10, 10] + assert list(output.shape) == [1, 10, 10] + def test_wrap_ortmodule_and_change_device(): # Basic Sequencial model wrapping ORTModule x = torch.linspace(-math.pi, math.pi, 2000) xx = x.unsqueeze(-1).pow(torch.tensor([1, 2, 3])) y = torch.sin(x) - model = torch.nn.Sequential( - ORTModule(torch.nn.Linear(3, 1)), - torch.nn.Flatten(0, 1) - ) + model = torch.nn.Sequential(ORTModule(torch.nn.Linear(3, 1)), torch.nn.Flatten(0, 1)) # Changing device for fun model = model.cpu() @@ -2386,7 +2544,7 @@ def test_wrap_ortmodule_and_change_device(): y = y.cuda() # Quick train - loss_fn = torch.nn.MSELoss(reduction='sum') + loss_fn = torch.nn.MSELoss(reduction="sum") learning_rate = 1e-6 for t in range(2000): y_pred = model(xx) @@ -2400,12 +2558,14 @@ def test_wrap_ortmodule_and_change_device(): # Checking training finished normally assert y_pred is not None and loss is not None + @pytest.mark.parametrize("return_dict", [True, False]) def test_hf_model_output_with_tuples(return_dict): - device = 'cuda' + device = "cuda" - model = _get_bert_for_sequence_classification_model(device, output_attentions=True, - output_hidden_states=True, return_dict=return_dict) + model = _get_bert_for_sequence_classification_model( + device, output_attentions=True, output_hidden_states=True, return_dict=return_dict + ) x, y, z = _get_bert_for_sequence_classification_sample_data(device) model = ORTModule(model) @@ -2413,12 +2573,11 @@ def test_hf_model_output_with_tuples(return_dict): if return_dict: assert isinstance(output, SequenceClassifierOutput) - assert 'loss' in output and 'logits' in output and \ - 'attentions' in output and 'hidden_states' in output - assert isinstance(output['loss'], torch.Tensor) - assert isinstance(output['logits'], torch.Tensor) - assert isinstance(output['attentions'], tuple) - assert isinstance(output['hidden_states'], tuple) + assert "loss" in output and "logits" in output and "attentions" in output and "hidden_states" in output + assert isinstance(output["loss"], torch.Tensor) + assert isinstance(output["logits"], torch.Tensor) + assert isinstance(output["attentions"], tuple) + assert isinstance(output["hidden_states"], tuple) else: assert isinstance(output, tuple) assert isinstance(output[0], torch.Tensor) @@ -2426,7 +2585,8 @@ def test_hf_model_output_with_tuples(return_dict): assert isinstance(output[2], tuple) assert isinstance(output[3], tuple) -@pytest.mark.parametrize("device", ['cuda', 'cpu']) + +@pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_nested_return_value_module(device): class NeuralNetNestedOutput(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2449,14 +2609,7 @@ def forward(self, input1, input2, input3): out1 = self.fc1_2(self.relu1(self.fc1_1(input1))) out2 = self.fc2_2(self.relu2(self.fc2_1(input2))) out3 = self.fc3_2(self.relu(self.relu3(self.fc3_1(input3)))) - return { - 'a': { - 'b': { - 'c': out1 - }, - 'd': (out2, out3) - } - } + return {"a": {"b": {"c": out1}, "d": (out2, out3)}} N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetNestedOutput(D_in, H, D_out).to(device) @@ -2467,20 +2620,18 @@ def forward(self, input1, input2, input3): z = torch.randn(N, D_in, device=device) output = model(x, y, z) - assert 'a' in output and 'b' in output['a'] and 'c' in output['a']['b'] - assert isinstance(output['a']['b']['c'], torch.Tensor) + assert "a" in output and "b" in output["a"] and "c" in output["a"]["b"] + assert isinstance(output["a"]["b"]["c"], torch.Tensor) - assert 'd' in output['a'] - assert isinstance(output['a']['d'], tuple) - assert len(output['a']['d']) == 2 + assert "d" in output["a"] + assert isinstance(output["a"]["d"], tuple) + assert len(output["a"]["d"]) == 2 -@pytest.mark.parametrize("data_device, model_device", ( - ['cuda', 'cpu'], - ['cpu', 'cuda']) -) + +@pytest.mark.parametrize("data_device, model_device", (["cuda", "cpu"], ["cpu", "cuda"])) def test_forward_data_and_model_on_different_devices(data_device, model_device): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(model_device) @@ -2492,18 +2643,25 @@ def test_forward_data_and_model_on_different_devices(data_device, model_device): # Now that the model has been exported, feed in data from device other than the model device x = torch.randn(N, D_in, device=data_device) from onnxruntime.training.ortmodule._fallback import _FallbackPolicy, ORTModuleDeviceException + if _test_helpers.is_all_or_nothing_fallback_enabled(None, _FallbackPolicy.FALLBACK_UNSUPPORTED_DEVICE): # Fallback with pytest.raises(RuntimeError) as runtime_error: ort_model(x) - assert f"Expected all tensors to be on the same device, but found at least two devices" in str(runtime_error.value) + assert f"Expected all tensors to be on the same device, but found at least two devices" in str( + runtime_error.value + ) else: # ORT backend with pytest.raises(ORTModuleDeviceException) as runtime_error: ort_model(x) - assert f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." in str(runtime_error.value) + assert ( + f"Input argument to forward found on device {torch.device(x.device)}, but expected it to be on module device {ort_model._torch_module._execution_manager(ort_model._is_training())._device}." + in str(runtime_error.value) + ) + + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] def test_forward_returns_none_type_as_output(): class NeuralNetNoneTypeOutput(torch.nn.Module): @@ -2516,17 +2674,18 @@ def __init__(self, input_size, num_classes): def forward(self, input1): out1 = self.fc1(input1) out1 = self.relu1(out1) - return {'out': out1, 'none_output': None} + return {"out": out1, "none_output": None} - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetNoneTypeOutput(D_in, D_out).to(device) model = ORTModule(model) x = torch.randn(N, D_in, device=device) output = model(x) - assert output['out'] is not None - assert output['none_output'] is None + assert output["out"] is not None + assert output["none_output"] is None + def test_bool_input_and_output(): class NeuralNetBoolInputOutput(torch.nn.Module): @@ -2540,7 +2699,7 @@ def forward(self, condition, x1, x2): out2 = torch.tensor(out1).to(torch.bool) return out1, out2 - device = 'cuda' + device = "cuda" N, D_in, D_out = 64, 784, 10 model = NeuralNetBoolInputOutput(D_in, D_out).to(device) model = ORTModule(model) @@ -2552,6 +2711,7 @@ def forward(self, condition, x1, x2): assert y1 is not None assert y2 is not None and y2.dtype == torch.bool + def test_uint8_input_and_output(): class NeuralNetUInt8InputOutput(torch.nn.Module): def __init__(self, input_size, num_classes): @@ -2564,7 +2724,7 @@ def forward(self, mask, x1, x2): out2 = torch.tensor(out1).to(torch.uint8) return out1, out2 - device = 'cuda' + device = "cuda" N, D_in, D_out = 64, 784, 10 model = NeuralNetUInt8InputOutput(D_in, D_out).to(device) model = ORTModule(model) @@ -2576,8 +2736,9 @@ def forward(self, mask, x1, x2): assert y1 is not None assert y2 is not None and y2.dtype == torch.uint8 + def test_model_partially_requires_grad(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetPartialNoGradModel(D_in, H, D_out).to(device) model = ORTModule(model) @@ -2589,8 +2750,9 @@ def test_model_partially_requires_grad(): loss = torch.sum(output) loss.backward() + def test_model_wrapped_inside_torch_no_grad(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) model = ORTModule(model) @@ -2600,11 +2762,12 @@ def test_model_wrapped_inside_torch_no_grad(): with torch.no_grad(): output = model(x) + def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetPartialNoGradModel(D_in, H, D_out).to(device) model.fc1.requires_grad_(True) @@ -2635,7 +2798,8 @@ def test_model_initializer_requires_grad_changes_from_one_forward_to_next(): assert torch.equal(weight_grad_2, weight_grad_3) assert torch.equal(bias_grad_2, bias_grad_3) - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + def test_model_with_registered_buffers(): class NeuralNetWithRegisteredBuffer(torch.nn.Module): @@ -2646,7 +2810,7 @@ def __init__(self, input_size, hidden_size, num_classes): self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Linear(hidden_size, num_classes) self.register_buffer("buffer1s", torch.ones(num_classes)) - self.register_buffer("buffer2s", 1+torch.ones(num_classes)) + self.register_buffer("buffer2s", 1 + torch.ones(num_classes)) def forward(self, input1): out = self.fc1(input1) @@ -2655,7 +2819,8 @@ def forward(self, input1): out += self.buffer1s out += self.buffer2s return out - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetWithRegisteredBuffer(D_in, H, D_out).to(device) @@ -2677,8 +2842,8 @@ def __init__(self, input_size, hidden_size, num_classes): self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Linear(hidden_size, num_classes) self.register_buffer("buffer1s", torch.ones(num_classes)) - self.register_buffer("buffer2s", 1+torch.ones(num_classes)) - self.register_buffer("buffer3s", 2+torch.ones(num_classes)) + self.register_buffer("buffer2s", 1 + torch.ones(num_classes)) + self.register_buffer("buffer3s", 2 + torch.ones(num_classes)) def forward(self, input1): out = self.fc1(input1) @@ -2686,7 +2851,8 @@ def forward(self, input1): out = self.fc2(out) out += self.buffer3s return out - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = UnusedBufferNet(D_in, H, D_out).to(device) @@ -2708,7 +2874,7 @@ def __init__(self, input_size, hidden_size, num_classes): self.relu = torch.nn.ReLU() self.fc2 = torch.nn.Linear(hidden_size, num_classes) self.register_parameter("param1", torch.nn.Parameter(torch.ones(num_classes))) - self.register_parameter("param2", torch.nn.Parameter(1+torch.ones(num_classes))) + self.register_parameter("param2", torch.nn.Parameter(1 + torch.ones(num_classes))) def forward(self, input1): out = self.fc1(input1) @@ -2716,9 +2882,10 @@ def forward(self, input1): out = self.fc2(out) out += self.param1 out += self.param2 - out += torch.tensor([3.], device=next(self.parameters()).device) + out += torch.tensor([3.0], device=next(self.parameters()).device) return out - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetWithRegisteredParamsWithConstant(D_in, H, D_out).to(device) @@ -2730,8 +2897,9 @@ def forward(self, input1): output = ort_model(x) assert output is not None + def test_state_dict(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -2758,8 +2926,9 @@ def test_state_dict(): assert param_name in state_dict_ort assert torch.equal(param_value, state_dict_ort[param_name]) + def test_load_state_dict(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -2791,8 +2960,9 @@ def test_load_state_dict(): assert param_name in state_dict_ort assert torch.equal(param_value, state_dict_ort[param_name]) + def test_named_parameters(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -2802,8 +2972,9 @@ def test_named_parameters(): assert len(named_parameters_pt) > 0 assert named_parameters_pt == named_parameters_ort + def test_parameters(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -2814,11 +2985,12 @@ def test_parameters(): assert len(parameters_pt) == len(parameters_ort) assert all(torch.equal(parameters_pt[i], parameters_ort[i]) for i in range(len(parameters_pt))) + def test_named_buffers(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - pt_model.register_buffer('sample_buffer_pt', torch.tensor(torch.randn(N, D_in, device=device))) + pt_model.register_buffer("sample_buffer_pt", torch.tensor(torch.randn(N, D_in, device=device))) ort_model = ORTModule(copy.deepcopy(pt_model)) named_buffers_pt = [name for name, _ in pt_model.named_buffers()] named_buffers_ort = [name for name, _ in ort_model.named_buffers()] @@ -2826,15 +2998,16 @@ def test_named_buffers(): assert len(named_buffers_pt) > 0 assert named_buffers_pt == named_buffers_ort - ort_model.register_buffer('sample_buffer_ort', torch.tensor(torch.randn(N, D_in, device=device))) + ort_model.register_buffer("sample_buffer_ort", torch.tensor(torch.randn(N, D_in, device=device))) named_buffers_ort = [name for name, _ in ort_model.named_buffers()] - assert named_buffers_ort == ['sample_buffer_pt', 'sample_buffer_ort'] + assert named_buffers_ort == ["sample_buffer_pt", "sample_buffer_ort"] + def test_buffers(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - pt_model.register_buffer('sample_buffer_pt', torch.tensor(torch.randn(N, D_in, device=device))) + pt_model.register_buffer("sample_buffer_pt", torch.tensor(torch.randn(N, D_in, device=device))) ort_model = ORTModule(copy.deepcopy(pt_model)) buffers_pt = [buffer for buffer in pt_model.buffers()] buffers_ort = [buffer for buffer in ort_model.buffers()] @@ -2844,11 +3017,12 @@ def test_buffers(): assert all(torch.equal(buffers_pt[i], buffers_ort[i]) for i in range(len(buffers_pt))) x = torch.tensor(torch.randn(N, D_in, device=device)) - ort_model.register_buffer('sample_buffer_ort', x) + ort_model.register_buffer("sample_buffer_ort", x) buffers_ort = [buffer for buffer in ort_model.buffers()] assert len(buffers_ort) == 2 assert torch.equal(buffers_ort[1], x) + def test_eval_with_dropout(): class NeuralNetDropout(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2866,7 +3040,7 @@ def forward(self, input1): out = self.dropout(out) return out - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetDropout(D_in, H, D_out).to(device) @@ -2886,8 +3060,9 @@ def forward(self, input1): # Assert that the output from torch is the same as the one from ORTModule _test_helpers.assert_values_are_close(output, output_pt) + def test_with_torch_no_grad_context(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -2909,6 +3084,7 @@ def test_with_torch_no_grad_context(): _test_helpers.assert_values_are_close(output, output_pt) assert output.grad is None and output_pt.grad is None + def test_unused_layer(): class Net(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2923,7 +3099,7 @@ def forward(self, input1): out = self.relu(out) return out - device = torch.device('cuda') + device = torch.device("cuda") N, D_in, H, D_out = 64, 784, 500, 10 pt_model = Net(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -2933,6 +3109,7 @@ def forward(self, input1): ort_output = ort_model(x) _test_helpers.assert_values_are_close(pt_output, ort_output) + def test_train_eval_with_various_outputs(): class Net(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -2955,7 +3132,7 @@ def train_step(model, x): loss.backward() return out1, out2 - device = torch.device('cuda') + device = torch.device("cuda") N, D_in, H, D_out = 64, 784, 500, 10 pt_model = Net(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -2978,18 +3155,19 @@ def train_step(model, x): ort_out = ort_model(x) _test_helpers.assert_values_are_close(pt_out, ort_out) + def test_forward_dynamic_args(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetPositionalArguments(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) model = ORTModule(model) - args_size1 = [torch.randn(N, D_in, device=device)]*4 - args_size2 = [torch.randn(N, D_in, device=device)]*3 - args_size3 = [torch.randn(N, D_in, device=device)]*5 + args_size1 = [torch.randn(N, D_in, device=device)] * 4 + args_size2 = [torch.randn(N, D_in, device=device)] * 3 + args_size3 = [torch.randn(N, D_in, device=device)] * 5 # Make sure model runs without any exception for i in range(2): @@ -3021,12 +3199,12 @@ def test_forward_dynamic_args(): hash_args_size3 = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_args_size3 != hash_args_size2 - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] def test_forward_dynamic_kwargs(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" one = torch.FloatTensor([1]) model = NeuralNetSimplePositionalAndKeywordArguments() @@ -3050,21 +3228,21 @@ def test_forward_dynamic_kwargs(): # Train with x and y as inputs for _ in range(10): - output = model(one,y=one) + output = model(one, y=one) assert output is not None hash_x_y = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x_y != hash_x # Train with x and z as inputs for _ in range(10): - output = model(one,z=one) + output = model(one, z=one) assert output is not None hash_x_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x_z != hash_x_y # Train with x, y and z as inputs for _ in range(10): - output = model(one,y=one, z=one) + output = model(one, y=one, z=one) assert output is not None hash_x_y_z = hash(repr(model._torch_module._execution_manager(model._is_training())._input_info.schema)) assert hash_x_y_z != hash_x_z @@ -3077,47 +3255,49 @@ def test_forward_dynamic_kwargs(): assert hash_x2 != hash_x_y_z assert hash_x2 == hash_x - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] - - -@pytest.mark.parametrize("forward_statement", - [# Only pos_X, pos_X as positionals - "model(pos_0, pos_1)", - # Only pos_X, pos_X as keywords - "model(pos_0=pos_0, pos_1=pos_1)", - # pos_X + *args, pos_X as positionals - "model(pos_0, pos_1, *args)", - # pos_X + kw_X, pos_X as positionals - "model(pos_0, pos_1, kw_0=kw_0, kw_1=kw_1)", - # pos_X + kw_X, pos_X as keywords - "model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0, kw_1=kw_1)", - # pos_X + kw_X, pos_X as positionals (missing kw_1) - "model(pos_0, pos_1, kw_0=kw_0)", - # pos_X + kw_X, pos_X as keywords (missing kw_1) - "model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0)", - # pos_X + kw_X, pos_X as positionals (missing kw_0) - "model(pos_0, pos_1, kw_1=kw_1)", - # pos_X + kw_X, pos_X as keywords (missing kw_0) - "model(pos_0=pos_0, pos_1=pos_1, kw_1=kw_1)", - # pos_X + kwargs, pos_X as positionals - "model(pos_0, pos_1, **kwargs)", - # pos_X + kwargs, pos_X as keywords - "model(pos_0=pos_0, pos_1=pos_1, **kwargs)", - # pos_X + *args + kw_X, pos_X as positionals - "model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1)", - # pos_X + *args + kw_X, pos_X as positionals (missing kw_0) - "model(pos_0, pos_1, *args, kw_1=kw_1)", - # pos_X + *args + kw_X, pos_X as positionals (missing kw_1) - "model(pos_0, pos_1, *args, kw_0=kw_0)", - # pos_X + *args + kwargs, pos_X as positionals - "model(pos_0, pos_1, *args, **kwargs)", - # pos_X + *args + kw_X + kwargs, pos_X as positionals - "model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1, **kwargs)", - # pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_0) - "model(pos_0, pos_1, *args, kw_1=kw_1, **kwargs)", - # pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_1) - "model(pos_0, pos_1, *args, kw_0=kw_0, **kwargs)", - ]) + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + + +@pytest.mark.parametrize( + "forward_statement", + [ # Only pos_X, pos_X as positionals + "model(pos_0, pos_1)", + # Only pos_X, pos_X as keywords + "model(pos_0=pos_0, pos_1=pos_1)", + # pos_X + *args, pos_X as positionals + "model(pos_0, pos_1, *args)", + # pos_X + kw_X, pos_X as positionals + "model(pos_0, pos_1, kw_0=kw_0, kw_1=kw_1)", + # pos_X + kw_X, pos_X as keywords + "model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0, kw_1=kw_1)", + # pos_X + kw_X, pos_X as positionals (missing kw_1) + "model(pos_0, pos_1, kw_0=kw_0)", + # pos_X + kw_X, pos_X as keywords (missing kw_1) + "model(pos_0=pos_0, pos_1=pos_1, kw_0=kw_0)", + # pos_X + kw_X, pos_X as positionals (missing kw_0) + "model(pos_0, pos_1, kw_1=kw_1)", + # pos_X + kw_X, pos_X as keywords (missing kw_0) + "model(pos_0=pos_0, pos_1=pos_1, kw_1=kw_1)", + # pos_X + kwargs, pos_X as positionals + "model(pos_0, pos_1, **kwargs)", + # pos_X + kwargs, pos_X as keywords + "model(pos_0=pos_0, pos_1=pos_1, **kwargs)", + # pos_X + *args + kw_X, pos_X as positionals + "model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1)", + # pos_X + *args + kw_X, pos_X as positionals (missing kw_0) + "model(pos_0, pos_1, *args, kw_1=kw_1)", + # pos_X + *args + kw_X, pos_X as positionals (missing kw_1) + "model(pos_0, pos_1, *args, kw_0=kw_0)", + # pos_X + *args + kwargs, pos_X as positionals + "model(pos_0, pos_1, *args, **kwargs)", + # pos_X + *args + kw_X + kwargs, pos_X as positionals + "model(pos_0, pos_1, *args, kw_0=kw_0, kw_1=kw_1, **kwargs)", + # pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_0) + "model(pos_0, pos_1, *args, kw_1=kw_1, **kwargs)", + # pos_X + *args + kw_X + kwargs, pos_X as positionals (missing kw_1) + "model(pos_0, pos_1, *args, kw_0=kw_0, **kwargs)", + ], +) def test_forward_call_kwargs_input(forward_statement): class KwargsNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -3136,10 +3316,10 @@ def forward(self, pos_0, pos_1, *args, kw_0=None, kw_1=None, **kwargs): if kw_1 is not None: model_input += kw_1 if kwargs: - if 'kwargs_0' in kwargs: - model_input += kwargs['kwargs_0'] - if 'kwargs_1' in kwargs: - model_input += torch.matmul(kwargs['kwargs_0'], kwargs['kwargs_1']) + if "kwargs_0" in kwargs: + model_input += kwargs["kwargs_0"] + if "kwargs_1" in kwargs: + model_input += torch.matmul(kwargs["kwargs_0"], kwargs["kwargs_1"]) out = self.fc1(model_input) out = self.relu(out) @@ -3147,7 +3327,7 @@ def forward(self, pos_0, pos_1, *args, kw_0=None, kw_1=None, **kwargs): return out # Modeling - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = KwargsNet(input_size=D_in, hidden_size=H, num_classes=D_out).to(device) model = ORTModule(model) @@ -3157,9 +3337,8 @@ def forward(self, pos_0, pos_1, *args, kw_0=None, kw_1=None, **kwargs): pos_1 = torch.randn(N, D_in, device=device) kw_0 = torch.randn(N, D_in, device=device) kw_1 = torch.randn(N, D_in, device=device) - args = [torch.randn(N, D_in, device=device)]*2 - kwargs = {'kwargs_0' : torch.randn(N, D_in, device=device), - 'kwargs_1' : torch.randn(D_in, D_in, device=device)} + args = [torch.randn(N, D_in, device=device)] * 2 + kwargs = {"kwargs_0": torch.randn(N, D_in, device=device), "kwargs_1": torch.randn(D_in, D_in, device=device)} # Training step prediction = eval(forward_statement) @@ -3172,7 +3351,8 @@ def test_repro_iscontiguous(): class SimpleNet(torch.nn.Module): def __init__(self): super(SimpleNet, self).__init__() - self.a = torch.nn.Parameter(torch.FloatTensor([-1., 1.])) + self.a = torch.nn.Parameter(torch.FloatTensor([-1.0, 1.0])) + def forward(self, x): result = torch.mean(self.a) + x return result @@ -3187,12 +3367,12 @@ def forward(self, x): def test_forward_call_default_input(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" class UnusedNet(torch.nn.Module): def __init__(self): super().__init__() - self.zeros = torch.nn.Parameter(torch.zeros(1,1)) + self.zeros = torch.nn.Parameter(torch.zeros(1, 1)) def forward(self, a, b, c, d, *args, kw_0=None, **kwargs): result = a + d + self.zeros.sum() @@ -3201,23 +3381,23 @@ def forward(self, a, b, c, d, *args, kw_0=None, **kwargs): if kw_0: result += kw_0 if kwargs: - assert 'kwargs_1' in kwargs - result += kwargs['kwargs_1'] + assert "kwargs_1" in kwargs + result += kwargs["kwargs_1"] return result # Modeling - device = 'cuda' + device = "cuda" model = UnusedNet().to(device) model = ORTModule(model) # Dummy data one = torch.FloatTensor([1]).to(device) - two = 2*one - three = 3*one - four = 4*one - args = [two]*5 - kw_0 = 6*one - kwargs = {'kwargs_0': 7*one, 'kwargs_1': 8*one} + two = 2 * one + three = 3 * one + four = 4 * one + args = [two] * 5 + kw_0 = 6 * one + kwargs = {"kwargs_0": 7 * one, "kwargs_1": 8 * one} # Make sure model runs without any exception for i in range(2): @@ -3256,7 +3436,8 @@ def forward(self, a, b, c, d, *args, kw_0=None, **kwargs): if model._is_training(): out.sum().backward() - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + def test_forward_call_kwargs_input_unexpected_order(): class OrderlyNet(torch.nn.Module): @@ -3277,7 +3458,7 @@ def forward(self, input1=None, input2=None): out2 = self.fc2(out1) return out1, out2 - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 500, 10 model = OrderlyNet(D_in, H, D_out).to(device) model = ORTModule(model) @@ -3294,7 +3475,7 @@ def forward(self, input1=None, input2=None): model.eval() # Must work because forward() and dict order match - y1, y2 = model(**{'input1': input1, 'input2': input2}) + y1, y2 = model(**{"input1": input1, "input2": input2}) assert y1 is not None assert y2 is not None if model._is_training(): @@ -3302,7 +3483,7 @@ def forward(self, input1=None, input2=None): loss.backward() # Must work even when forward() and dict order mismatch - y1, y2 = model(**{'input2': input2, 'input1': input1}) + y1, y2 = model(**{"input2": input2, "input1": input1}) assert y1 is not None assert y2 is not None if model._is_training(): @@ -3312,12 +3493,12 @@ def forward(self, input1=None, input2=None): def test_forward_call_lots_None(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" class NoneNet(torch.nn.Module): def __init__(self): super().__init__() - self.zeros = torch.nn.Parameter(torch.zeros(1,1)) + self.zeros = torch.nn.Parameter(torch.zeros(1, 1)) def forward(self, a, b, c, d, e, f, y=None, z=None): assert a is not None @@ -3346,25 +3527,25 @@ def run_step(expected, a, b, c, d, e, f, y, z): # the model when `forward(a,b)` is used after `forward(**{'a': a, 'b': b})` # or vice-versa model._torch_module._execution_manager(model._is_training())._onnx_model = None - out = model(a,b,c,d,e,f,y,z) + out = model(a, b, c, d, e, f, y, z) assert out is not None assert out.item() == expected if model._is_training(): loss = out.sum() loss.backward() - device = 'cuda' + device = "cuda" model = NoneNet().to(device) model = ORTModule(model) - a = torch.FloatTensor([1]).to(device)*1 - b = torch.FloatTensor([1]).to(device)*10 - c = torch.FloatTensor([1]).to(device)*100 - d = torch.FloatTensor([1]).to(device)*1000 - e = torch.FloatTensor([1]).to(device)*10000 - f = torch.FloatTensor([1]).to(device)*100000 - y = torch.FloatTensor([1]).to(device)*1000000 - z = torch.FloatTensor([1]).to(device)*10000000 + a = torch.FloatTensor([1]).to(device) * 1 + b = torch.FloatTensor([1]).to(device) * 10 + c = torch.FloatTensor([1]).to(device) * 100 + d = torch.FloatTensor([1]).to(device) * 1000 + e = torch.FloatTensor([1]).to(device) * 10000 + f = torch.FloatTensor([1]).to(device) * 100000 + y = torch.FloatTensor([1]).to(device) * 1000000 + z = torch.FloatTensor([1]).to(device) * 10000000 # Make sure model runs without any exception for i in range(2): @@ -3374,28 +3555,53 @@ def run_step(expected, a, b, c, d, e, f, y, z): else: model.eval() - run_step(a.item() + f.item(), - a, None, None, None, None, f, None, None, ) - run_step(a.item() + f.item(), - **{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': f, 'y': None, 'z': None}) - run_step(a.item() + z.item(), - a, None, None, None, None, None, None, z) - run_step(a.item() + z.item(), - **{'a': a, 'b': None, 'c': None, 'd': None, 'e': None, 'f': None, 'y': None, 'z': z}) - run_step(a.item() + c.item() + y.item(), - a, None, c, None, None, None, y, None) - run_step(a.item() + c.item() + y.item(), - **{'a': a, 'b': None, 'c': c, 'd': None, 'e': None, 'f': None, 'y': y, 'z': None}) - run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(), - a, b, c, d, e, f, y, z) - run_step(a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(), - **{'a': a, 'b': b, 'c': c, 'd': d, 'e': e, 'f': f, 'y': y, 'z': z}) - - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + run_step( + a.item() + f.item(), + a, + None, + None, + None, + None, + f, + None, + None, + ) + run_step( + a.item() + f.item(), **{"a": a, "b": None, "c": None, "d": None, "e": None, "f": f, "y": None, "z": None} + ) + run_step(a.item() + z.item(), a, None, None, None, None, None, None, z) + run_step( + a.item() + z.item(), **{"a": a, "b": None, "c": None, "d": None, "e": None, "f": None, "y": None, "z": z} + ) + run_step(a.item() + c.item() + y.item(), a, None, c, None, None, None, y, None) + run_step( + a.item() + c.item() + y.item(), + **{"a": a, "b": None, "c": c, "d": None, "e": None, "f": None, "y": y, "z": None}, + ) + run_step( + a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(), + a, + b, + c, + d, + e, + f, + y, + z, + ) + run_step( + a.item() + b.item() + c.item() + d.item() + e.item() + f.item() + y.item() + z.item(), + **{"a": a, "b": b, "c": c, "d": d, "e": e, "f": f, "y": y, "z": z}, + ) + + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + @pytest.mark.parametrize("bool_argument", [True, False]) @pytest.mark.parametrize("int_argument", [100, 100000, 100000000, -100, -100000, -100000000]) -@pytest.mark.parametrize("float_argument", [1.23, 11209123.12452, 12093702935.1249863, -1.23, -11209123.12452, -12093702935.1249863]) +@pytest.mark.parametrize( + "float_argument", [1.23, 11209123.12452, 12093702935.1249863, -1.23, -11209123.12452, -12093702935.1249863] +) def test_primitive_inputs(bool_argument, int_argument, float_argument): class PrimitiveTypesInputNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -3421,7 +3627,7 @@ def forward(self, input1, bool_argument, int_argument, float_argument): assert type(int_argument) is int assert type(float_argument) is float - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 500, 10 pt_model = PrimitiveTypesInputNet(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -3431,10 +3637,11 @@ def forward(self, input1, bool_argument, int_argument, float_argument): ort_out = ort_model(input1, bool_argument, int_argument, float_argument) _test_helpers.assert_values_are_close(pt_out, ort_out) + @pytest.mark.parametrize("bool_arguments", [(True, False), (False, True)]) def test_changing_bool_input_re_exports_model(bool_arguments): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" class PrimitiveTypesInputNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -3458,7 +3665,7 @@ def forward(self, input1, bool_argument): assert type(bool_arguments[0]) is bool assert type(bool_arguments[1]) is bool - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 784, 500, 10 pt_model = PrimitiveTypesInputNet(D_in, H, D_out).to(device) ort_model = ORTModule(pt_model) @@ -3472,7 +3679,8 @@ def forward(self, input1, bool_argument): assert exported_model1 != exported_model2 - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + def test_model_with_registered_buffer_and_dropped_parameters(): class ModelWithBufferAndDroppedParameter(torch.nn.Module): @@ -3497,7 +3705,7 @@ def forward(self, bool_argument, input1): out = out + self.buffer return out - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = ModelWithBufferAndDroppedParameter(D_in, H, D_out).to(device) model = ORTModule(model) @@ -3508,12 +3716,17 @@ def forward(self, bool_argument, input1): # Ensure that no exceptions are raised out = model(bool_argument, x) -@pytest.mark.parametrize("model, none_pt_params", - [(UnusedBeginParameterNet(784, 500, 400, 10), ['fc1.weight', 'fc1.bias']), - (UnusedMiddleParameterNet(784, 500, 400, 10), ['fc2.weight', 'fc2.bias']), - (UnusedEndParameterNet(784, 500, 400, 10), ['fc2.weight', 'fc2.bias'])]) + +@pytest.mark.parametrize( + "model, none_pt_params", + [ + (UnusedBeginParameterNet(784, 500, 400, 10), ["fc1.weight", "fc1.bias"]), + (UnusedMiddleParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]), + (UnusedEndParameterNet(784, 500, 400, 10), ["fc2.weight", "fc2.bias"]), + ], +) def test_unused_parameters(model, none_pt_params): - device = 'cuda' + device = "cuda" N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 model = model.to(device) @@ -3531,8 +3744,7 @@ def test_unused_parameters(model, none_pt_params): loss_ort = out_ort.sum() loss_ort.backward() _test_helpers.assert_values_are_close(out_ort, out_pt) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, - none_pt_params=none_pt_params) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, model, none_pt_params=none_pt_params) # Also try in eval mode model.eval() @@ -3546,6 +3758,7 @@ def test_unused_parameters(model, none_pt_params): out_ort = ort_model(y) _test_helpers.assert_values_are_close(out_ort, out_pt) + def test_output_order(): class OutputOrderNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -3564,13 +3777,25 @@ def __init__(self, input_size, hidden_size, num_classes): self.fc11 = torch.nn.Linear(input_size, hidden_size) self.fc12 = torch.nn.Linear(input_size, hidden_size) - def forward(self, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11, input12): - return self.fc1(input1), self.fc2(input2), self.fc3(input3), \ - self.fc4(input4), self.fc5(input5), self.fc6(input6), \ - self.fc7(input7), self.fc8(input8), self.fc9(input9), \ - self.fc10(input10), self.fc11(input11), self.fc12(input12) - - device = 'cuda' + def forward( + self, input1, input2, input3, input4, input5, input6, input7, input8, input9, input10, input11, input12 + ): + return ( + self.fc1(input1), + self.fc2(input2), + self.fc3(input3), + self.fc4(input4), + self.fc5(input5), + self.fc6(input6), + self.fc7(input7), + self.fc8(input8), + self.fc9(input9), + self.fc10(input10), + self.fc11(input11), + self.fc12(input12), + ) + + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = OutputOrderNet(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(model)) @@ -3585,7 +3810,8 @@ def forward(self, input1, input2, input3, input4, input5, input6, input7, input8 for x, y in zip(out_pt, out_ort): _test_helpers.assert_values_are_close(x, y) -@pytest.mark.parametrize("device", ['cuda', 'cpu', None]) + +@pytest.mark.parametrize("device", ["cuda", "cpu", None]) def test_stateless_model_specified_device(device): N, D_in, H, D_out = 32, 784, 500, 10 @@ -3600,6 +3826,7 @@ def test_stateless_model_specified_device(device): _test_helpers.assert_values_are_close(pt_y, ort_y) + def test_stateless_model_unspecified_device(): N, D_in, H, D_out = 32, 784, 500, 10 @@ -3614,12 +3841,17 @@ def test_stateless_model_unspecified_device(): _test_helpers.assert_values_are_close(pt_y, ort_y) -@pytest.mark.parametrize("model", - [(UnusedBeginParameterNet(784, 500, 400, 10)), - (UnusedMiddleParameterNet(784, 500, 400, 10)), - (UnusedEndParameterNet(784, 500, 400, 10))]) + +@pytest.mark.parametrize( + "model", + [ + (UnusedBeginParameterNet(784, 500, 400, 10)), + (UnusedMiddleParameterNet(784, 500, 400, 10)), + (UnusedEndParameterNet(784, 500, 400, 10)), + ], +) def test_unused_parameters_does_not_unnecessarily_reinitialize(model): - device = 'cuda' + device = "cuda" N, D_in, H1, H2, D_out = 64, 784, 500, 400, 10 model = model.to(device) @@ -3629,14 +3861,17 @@ def test_unused_parameters_does_not_unnecessarily_reinitialize(model): x = torch.randn(N, D_in, device=device) _ = ort_model(x) - input_info = _io.parse_inputs_for_onnx_export(training_manager._module_parameters, - training_manager._onnx_models.exported_model, - training_manager._input_info.schema, - x, - {}) + input_info = _io.parse_inputs_for_onnx_export( + training_manager._module_parameters, + training_manager._onnx_models.exported_model, + training_manager._input_info.schema, + x, + {}, + ) assert not training_manager._reinitialize_graph_builder(input_info) + def test_load_state_dict_for_wrapped_ortmodule(): class WrapperModule(torch.nn.Module): def __init__(self, ortmodule): @@ -3646,7 +3881,7 @@ def __init__(self, ortmodule): def forward(self, x): return self._ortmodule(x) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) model = ORTModule(copy.deepcopy(model)) @@ -3665,8 +3900,9 @@ def forward(self, x): assert param_name in state_dict2 assert torch.equal(param_value, state_dict2[param_name]) + def test_hf_save_pretrained(): - device = 'cuda' + device = "cuda" model1 = _get_bert_for_sequence_classification_model(device) model1 = ORTModule(model1) @@ -3688,13 +3924,15 @@ def test_hf_save_pretrained(): # to check if from_pretrained worked. config = AutoConfig.from_pretrained(temporary_dir) model2 = BertForSequenceClassification.from_pretrained( - temporary_dir, config=config, + temporary_dir, + config=config, ).to(device) model2 = ORTModule(model2) for p1, p2 in zip(model1.parameters(), model2.parameters()): assert p1.data.ne(p2.data).sum() == 0 + def test_ortmodule_string_inputs_are_ignored(): pt_model = MyStrNet() @@ -3702,11 +3940,15 @@ def test_ortmodule_string_inputs_are_ignored(): x = torch.randn(1, 2) with pytest.warns(UserWarning) as warning_record: - out = ort_model(x, 'hello') + out = ort_model(x, "hello") assert len(warning_record) == 2 - assert "Received input of type which may be treated as a constant by ORT by default." in warning_record[1].message.args[0] - _test_helpers.assert_values_are_close(out, x+1) + assert ( + "Received input of type which may be treated as a constant by ORT by default." + in warning_record[1].message.args[0] + ) + _test_helpers.assert_values_are_close(out, x + 1) + def test_ortmodule_list_input(): class ListNet(torch.nn.Module): @@ -3719,7 +3961,7 @@ def forward(self, batch): b = batch[1] return self.dummy + a + b - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = ListNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -3728,6 +3970,7 @@ def forward(self, batch): _test_helpers.assert_values_are_close(pt_model(x), ort_model(y)) + def test_ortmodule_list_input_with_unused_values(): class ListNet(torch.nn.Module): def __init__(self): @@ -3739,7 +3982,7 @@ def forward(self, batch): b = batch[1] return self.dummy + b - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = ListNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -3748,6 +3991,7 @@ def forward(self, batch): _test_helpers.assert_values_are_close(pt_model(x), ort_model(y)) + def test_ortmodule_list_input_with_none_values(): class ListNet(torch.nn.Module): def __init__(self): @@ -3759,7 +4003,7 @@ def forward(self, batch): b = batch[1] return self.dummy + a + b - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = ListNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -3768,6 +4012,7 @@ def forward(self, batch): _test_helpers.assert_values_are_close(pt_model(x), ort_model(y)) + def test_ortmodule_nested_list_input(): class ListNet(torch.nn.Module): def __init__(self): @@ -3782,28 +4027,31 @@ def forward(self, batch): e = batch[2][1][0] return self.dummy + a + b + c + d + e - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = ListNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) - x = [torch.randn(N, D_in, device=device), + x = [ + torch.randn(N, D_in, device=device), [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], - [torch.randn(N, D_in, device=device), [torch.randn(N, D_in, device=device)]]] + [torch.randn(N, D_in, device=device), [torch.randn(N, D_in, device=device)]], + ] y = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(y)) -@pytest.mark.parametrize("mode", ['training', 'inference']) + +@pytest.mark.parametrize("mode", ["training", "inference"]) def test_debug_options_save_onnx_models_os_environment(mode): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 # Create a temporary directory for the onnx_models with tempfile.TemporaryDirectory() as temporary_dir: - os.environ['ORTMODULE_SAVE_ONNX_PATH'] = temporary_dir + os.environ["ORTMODULE_SAVE_ONNX_PATH"] = temporary_dir model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_model')) - if mode == 'inference': + ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix="my_model")) + if mode == "inference": ort_model.eval() x = torch.randn(N, D_in, device=device) _ = ort_model(x) @@ -3813,16 +4061,17 @@ def test_debug_options_save_onnx_models_os_environment(mode): assert os.path.exists(os.path.join(temporary_dir, f"my_model_optimized_{mode}.onnx")) assert os.path.exists(os.path.join(temporary_dir, f"my_model_optimized_pre_grad_{mode}.onnx")) assert os.path.exists(os.path.join(temporary_dir, f"my_model_execution_model_{mode}.onnx")) - del os.environ['ORTMODULE_SAVE_ONNX_PATH'] + del os.environ["ORTMODULE_SAVE_ONNX_PATH"] + -@pytest.mark.parametrize("mode", ['training', 'inference']) +@pytest.mark.parametrize("mode", ["training", "inference"]) def test_debug_options_save_onnx_models_cwd(mode): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) - ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix='my_cwd_model')) - if mode == 'inference': + ort_model = ORTModule(model, DebugOptions(save_onnx=True, onnx_prefix="my_cwd_model")) + if mode == "inference": ort_model.eval() x = torch.randn(N, D_in, device=device) _ = ort_model(x) @@ -3838,13 +4087,15 @@ def test_debug_options_save_onnx_models_cwd(mode): os.remove(os.path.join(os.getcwd(), f"my_cwd_model_optimized_pre_grad_{mode}.onnx")) os.remove(os.path.join(os.getcwd(), f"my_cwd_model_execution_model_{mode}.onnx")) + def test_debug_options_save_onnx_models_validate_fail_on_non_writable_dir(): - os.environ['ORTMODULE_SAVE_ONNX_PATH'] = '/non/existent/directory' + os.environ["ORTMODULE_SAVE_ONNX_PATH"] = "/non/existent/directory" with pytest.raises(Exception) as ex_info: - _ = DebugOptions(save_onnx=True, onnx_prefix='my_model') + _ = DebugOptions(save_onnx=True, onnx_prefix="my_model") assert "Directory /non/existent/directory is not writable." in str(ex_info.value) - del os.environ['ORTMODULE_SAVE_ONNX_PATH'] + del os.environ["ORTMODULE_SAVE_ONNX_PATH"] + def test_debug_options_save_onnx_models_validate_fail_on_non_str_prefix(): prefix = 23 @@ -3852,15 +4103,17 @@ def test_debug_options_save_onnx_models_validate_fail_on_non_str_prefix(): _ = DebugOptions(save_onnx=True, onnx_prefix=prefix) assert f"Expected name prefix of type str, got {type(prefix)}." in str(ex_info.value) + def test_debug_options_save_onnx_models_validate_fail_on_no_prefix(): with pytest.raises(Exception) as ex_info: _ = DebugOptions(save_onnx=True) assert f"onnx_prefix must be provided when save_onnx is set." in str(ex_info.value) + def test_debug_options_log_level(): # NOTE: This test will output verbose logging - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model, DebugOptions(log_level=LogLevel.VERBOSE)) @@ -3870,11 +4123,12 @@ def test_debug_options_log_level(): # assert that the logging is done in verbose mode assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.VERBOSE + def test_debug_options_log_level_os_environment(): # NOTE: This test will output info logging - os.environ['ORTMODULE_LOG_LEVEL'] = 'INFO' - device = 'cuda' + os.environ["ORTMODULE_LOG_LEVEL"] = "INFO" + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) @@ -3883,14 +4137,16 @@ def test_debug_options_log_level_os_environment(): # assert that the logging is done in info mode assert ort_model._torch_module._execution_manager(True)._debug_options.logging.log_level == LogLevel.INFO - del os.environ['ORTMODULE_LOG_LEVEL'] + del os.environ["ORTMODULE_LOG_LEVEL"] + def test_debug_options_log_level_validation_fails_on_type_mismatch(): - log_level = 'some_string' + log_level = "some_string" with pytest.raises(Exception) as ex_info: _ = DebugOptions(log_level=log_level) assert f"Expected log_level of type LogLevel, got {type(log_level)}." in str(ex_info.value) + def test_ortmodule_gradient_accumulation_optimization_correctness(): class NeuralNetWithCast(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -3905,7 +4161,8 @@ def forward(self, input1): out = self.relu(out) out = self.fc2(out) return out - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetWithCast(D_in, H, D_out).to(device) @@ -3944,6 +4201,7 @@ def run_optim_step(optimizer): run_optim_step(tgt_optimizer) run_optim_step(opt_optimizer) + def test_ortmodule_dict_input(): class DictNet(torch.nn.Module): def __init__(self): @@ -3951,19 +4209,20 @@ def __init__(self): self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) def forward(self, batch): - b = batch['one_value'] - a = batch['two_value'] + b = batch["one_value"] + a = batch["two_value"] return self.dummy + a + b - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = DictNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) - x = {'one_value': torch.randn(N, D_in, device=device), 'two_value': torch.randn(N, D_in, device=device)} + x = {"one_value": torch.randn(N, D_in, device=device), "two_value": torch.randn(N, D_in, device=device)} x_copy = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy)) + def test_ortmodule_dict_input_with_unused_values(): class DictNet(torch.nn.Module): def __init__(self): @@ -3971,19 +4230,20 @@ def __init__(self): self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) def forward(self, batch): - b = batch['b'] - a = batch['a'] + b = batch["b"] + a = batch["a"] return self.dummy + a - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = DictNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) - x = {'a': torch.randn(N, D_in, device=device), 'b': torch.randn(N, D_in, device=device)} + x = {"a": torch.randn(N, D_in, device=device), "b": torch.randn(N, D_in, device=device)} x_copy = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy)) + def test_ortmodule_dict_input_with_none_values(): class DictNet(torch.nn.Module): def __init__(self): @@ -3991,19 +4251,20 @@ def __init__(self): self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) def forward(self, batch): - b = batch['b'] - a = batch['a'] if batch['a'] else torch.FloatTensor([2.0]).cuda() + b = batch["b"] + a = batch["a"] if batch["a"] else torch.FloatTensor([2.0]).cuda() return self.dummy + a + b - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = DictNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) - x = {'a': None, 'b': torch.randn(N, D_in, device=device)} + x = {"a": None, "b": torch.randn(N, D_in, device=device)} x_copy = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy)) + def test_ortmodule_dict_input_with_nested_values(): class DictNet(torch.nn.Module): def __init__(self): @@ -4011,34 +4272,33 @@ def __init__(self): self.dummy = torch.nn.Parameter(torch.FloatTensor([0])) def forward(self, batch): - a = batch['one_value'] - b = batch['two_value']['three_value'] - c = batch['two_value']['four_value'] - d = batch['five_value']['six_value'] - e = batch['five_value']['seven_value']['eight_value'] + a = batch["one_value"] + b = batch["two_value"]["three_value"] + c = batch["two_value"]["four_value"] + d = batch["five_value"]["six_value"] + e = batch["five_value"]["seven_value"]["eight_value"] return self.dummy + a + b + c + d + e - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = DictNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) x = { - 'one_value': torch.randn(N, D_in, device=device), - 'two_value': { - 'three_value': torch.randn(N, D_in, device=device), - 'four_value': torch.randn(N, D_in, device=device) - }, - 'five_value': { - 'six_value': torch.randn(N, D_in, device=device), - 'seven_value': { - 'eight_value': torch.randn(N, D_in, device=device) - } - } - } + "one_value": torch.randn(N, D_in, device=device), + "two_value": { + "three_value": torch.randn(N, D_in, device=device), + "four_value": torch.randn(N, D_in, device=device), + }, + "five_value": { + "six_value": torch.randn(N, D_in, device=device), + "seven_value": {"eight_value": torch.randn(N, D_in, device=device)}, + }, + } x_copy = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy)) + def test_ortmodule_list_dict_input_with_nested_values(): class ListDictNet(torch.nn.Module): def __init__(self): @@ -4046,74 +4306,65 @@ def __init__(self): self.dummy = torch.nn.Parameter(torch.FloatTensor([3])) def forward(self, batch): - a = batch['one_value'][0] - b = batch['two_value'][0] - c = batch['two_value'][1] - d = batch['three_value'][0] - e = batch['three_value'][1]['four_value'] + a = batch["one_value"][0] + b = batch["two_value"][0] + c = batch["two_value"][1] + d = batch["three_value"][0] + e = batch["three_value"][1]["four_value"] return self.dummy + a + b + c + d + e - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = ListDictNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) x = { - 'one_value': [torch.randn(N, D_in, device=device)], - 'two_value': [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], - 'three_value': [ - torch.randn(N, D_in, device=device), - { - 'four_value': torch.randn(N, D_in, device=device) - } - ] - } + "one_value": [torch.randn(N, D_in, device=device)], + "two_value": [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], + "three_value": [torch.randn(N, D_in, device=device), {"four_value": torch.randn(N, D_in, device=device)}], + } x_copy = copy.deepcopy(x) _test_helpers.assert_values_are_close(pt_model(x), ort_model(x_copy)) + def test_ortmodule_list_dict_input_with_kwargs_and_registered_buffer(): class ListDictKwargsNet(torch.nn.Module): def __init__(self, N, D_in): super(ListDictKwargsNet, self).__init__() - self.register_buffer("buffer", torch.ones(N, D_in, device='cuda')) + self.register_buffer("buffer", torch.ones(N, D_in, device="cuda")) self.dummy = torch.nn.Parameter(torch.FloatTensor([3])) def forward(self, batch, **kwargs): - a = batch['one_value'][0] - b = batch['two_value'][0] - c = batch['two_value'][1] - d = batch['three_value'][0] - e = batch['three_value'][1]['four_value'] + a = batch["one_value"][0] + b = batch["two_value"][0] + c = batch["two_value"][1] + d = batch["three_value"][0] + e = batch["three_value"][1]["four_value"] out = self.buffer + self.dummy + a + b + c + d + e if kwargs: - if 'kwargs_0' in kwargs: - out += kwargs['kwargs_0'] - if 'kwargs_1' in kwargs: - out += torch.matmul(kwargs['kwargs_0'], kwargs['kwargs_1']) + if "kwargs_0" in kwargs: + out += kwargs["kwargs_0"] + if "kwargs_1" in kwargs: + out += torch.matmul(kwargs["kwargs_0"], kwargs["kwargs_1"]) return out - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = ListDictKwargsNet(N, D_in).to(device) - ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix='kwargsanddict')) + ort_model = ORTModule(copy.deepcopy(pt_model), DebugOptions(save_onnx=True, onnx_prefix="kwargsanddict")) x = { - 'one_value': [torch.randn(N, D_in, device=device)], - 'two_value': [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], - 'three_value': [ - torch.randn(N, D_in, device=device), - { - 'four_value': torch.randn(N, D_in, device=device) - } - ] - } + "one_value": [torch.randn(N, D_in, device=device)], + "two_value": [torch.randn(N, D_in, device=device), torch.randn(N, D_in, device=device)], + "three_value": [torch.randn(N, D_in, device=device), {"four_value": torch.randn(N, D_in, device=device)}], + } x_copy = copy.deepcopy(x) - kwargs_input = {'kwargs_0' : torch.randn(N, D_in, device=device), - 'kwargs_1' : torch.randn(D_in, D_in, device=device)} + kwargs_input = {"kwargs_0": torch.randn(N, D_in, device=device), "kwargs_1": torch.randn(D_in, D_in, device=device)} kwargs_input_copy = copy.deepcopy(kwargs_input) _test_helpers.assert_values_are_close(pt_model(x, **kwargs_input), ort_model(x_copy, **kwargs_input_copy)) + def test_ortmodule_user_defined_method(): class UserDefinedMethodsNet(torch.nn.Module): def __init__(self): @@ -4126,7 +4377,7 @@ def forward(self, a): def custom_method_returns_input(self, user_input): return user_input - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = UserDefinedMethodsNet().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4140,6 +4391,7 @@ def custom_method_returns_input(self, user_input): ort_out = ort_model(y) _test_helpers.assert_values_are_close(pt_out, ort_out) + def test_ortmodule_user_getattr_gets_successfully(): class UserDefinedMethodsNet(torch.nn.Module): def __init__(self): @@ -4152,7 +4404,7 @@ def forward(self, a): def custom_method_returns_input(self, user_input): return user_input - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = UserDefinedMethodsNet().to(device) ort_model = ORTModule(pt_model) @@ -4161,7 +4413,8 @@ def custom_method_returns_input(self, user_input): assert ort_model.custom_method_returns_input.__func__ == pt_model.custom_method_returns_input.__func__ assert ort_model.dummy is pt_model.dummy -@pytest.mark.parametrize("attribute", ['True', 'lambda x : x']) + +@pytest.mark.parametrize("attribute", ["True", "lambda x : x"]) def test_ortmodule_setattr_new_attribute(attribute): class UserNet(torch.nn.Module): def __init__(self): @@ -4171,14 +4424,15 @@ def __init__(self): def forward(self, a): return self.dummy + a - device = 'cuda' + device = "cuda" pt_model = UserNet().to(device) ort_model = ORTModule(pt_model) ort_model.a_new_attribute = attribute - assert hasattr(pt_model, 'a_new_attribute') + assert hasattr(pt_model, "a_new_attribute") assert pt_model.a_new_attribute == attribute - assert 'a_new_attribute' not in ort_model.__dict__ + assert "a_new_attribute" not in ort_model.__dict__ + def test_ortmodule_setattr_on_ortmodule_copied_user_model_attribute(): class UserNet(torch.nn.Module): @@ -4195,7 +4449,7 @@ def custom_method(self, a): def my_new_custom_method(self, a, b, c): return a + b + c - device = 'cuda' + device = "cuda" pt_model = UserNet().to(device) ort_model = ORTModule(pt_model) # custom_method is copied by ORTModule from the users model @@ -4204,14 +4458,15 @@ def my_new_custom_method(self, a, b, c): # dummy is defined on pt model ort_model.dummy = torch.nn.Parameter(torch.FloatTensor([12])) - assert hasattr(pt_model, 'dummy') + assert hasattr(pt_model, "dummy") assert torch.eq(pt_model.dummy, torch.nn.Parameter(torch.FloatTensor([12]))) - assert 'dummy' not in ort_model.__dict__ + assert "dummy" not in ort_model.__dict__ - assert hasattr(pt_model, 'custom_method') + assert hasattr(pt_model, "custom_method") assert pt_model.custom_method is not my_new_custom_method assert ort_model.custom_method is my_new_custom_method + def test_ortmodule_setattr_ortmodule_attribute(): class UserNet(torch.nn.Module): def __init__(self): @@ -4221,18 +4476,19 @@ def __init__(self): def forward(self, a): return self.dummy + a - device = 'cuda' + device = "cuda" pt_model = UserNet().to(device) ort_model = ORTModule(pt_model) ort_model._torch_module = True - assert not hasattr(pt_model, '_torch_module') - assert '_torch_module' in ort_model.__dict__ + assert not hasattr(pt_model, "_torch_module") + assert "_torch_module" in ort_model.__dict__ assert ort_model._torch_module == True + def test_ortmodule_setattr_signals_model_changed(): - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" class UserNet(torch.nn.Module): def __init__(self, input_flag): @@ -4246,7 +4502,7 @@ def forward(self, a): else: return a - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = UserNet(True).to(device) ort_model = ORTModule(pt_model) @@ -4266,7 +4522,8 @@ def forward(self, a): assert exported_model1 != exported_model2 - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + def test_ortmodule_attribute_name_collision_warning(): class UserNet(torch.nn.Module): @@ -4281,7 +4538,7 @@ def forward(self, a): def load_state_dict(self): pass - device = 'cuda' + device = "cuda" pt_model = UserNet().to(device) with pytest.warns(UserWarning) as warning_record: ort_model = ORTModule(pt_model) @@ -4290,6 +4547,7 @@ def load_state_dict(self): assert "_torch_module collides with ORTModule's attribute name." in warning_record[0].message.args[0] assert "load_state_dict collides with ORTModule's attribute name." in warning_record[1].message.args[0] + def test_ortmodule_ortmodule_method_attribute_copy(): class UserNetWithSelfCallingForward(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -4308,7 +4566,7 @@ def forward(self, input1): def run_forward(self, *args, **kwargs): return self(*args, **kwargs) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = UserNetWithSelfCallingForward(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4326,32 +4584,39 @@ def run_forward(self, *args, **kwargs): assert torch.equal(out1, out2) _test_helpers.assert_values_are_close(out2, out3) - assert type(out1.grad_fn).__name__ == '_ORTModuleFunctionBackward' - assert type(out2.grad_fn).__name__ == '_ORTModuleFunctionBackward' - assert type(out3.grad_fn).__name__ == 'AddmmBackward0' if LooseVersion( - torch.__version__) >= LooseVersion('1.10.0') else 'AddmmBackward' - -@pytest.mark.parametrize("policy_str, policy",[ - ('SKIP_CHECK_DISABLED', _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED), - ('SKIP_CHECK_DEVICE', _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE), - ('SKIP_CHECK_BUILD_GRADIENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT), - ('SKIP_CHECK_EXECUTION_AGENT', _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT), -]) + assert type(out1.grad_fn).__name__ == "_ORTModuleFunctionBackward" + assert type(out2.grad_fn).__name__ == "_ORTModuleFunctionBackward" + assert ( + type(out3.grad_fn).__name__ == "AddmmBackward0" + if LooseVersion(torch.__version__) >= LooseVersion("1.10.0") + else "AddmmBackward" + ) + + +@pytest.mark.parametrize( + "policy_str, policy", + [ + ("SKIP_CHECK_DISABLED", _graph_execution_manager._SkipCheck.SKIP_CHECK_DISABLED), + ("SKIP_CHECK_DEVICE", _graph_execution_manager._SkipCheck.SKIP_CHECK_DEVICE), + ("SKIP_CHECK_BUILD_GRADIENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_BUILD_GRADIENT), + ("SKIP_CHECK_EXECUTION_AGENT", _graph_execution_manager._SkipCheck.SKIP_CHECK_EXECUTION_AGENT), + ], +) def test_ortmodule_skip_check_load_from_os_env(policy_str, policy): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = policy_str + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = policy_str model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) for training_mode in [False, True]: assert ort_model._torch_module._execution_manager(training_mode)._skip_check == policy - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] + -@pytest.mark.parametrize("is_training,deterministic", - list(itertools.product([True,False],repeat=2))) -def test_ortmodule_determinism_flag(is_training,deterministic): +@pytest.mark.parametrize("is_training,deterministic", list(itertools.product([True, False], repeat=2))) +def test_ortmodule_determinism_flag(is_training, deterministic): torch.use_deterministic_algorithms(deterministic) @@ -4365,6 +4630,7 @@ def test_ortmodule_determinism_flag(is_training,deterministic): _ = model(x) from onnxruntime.training.ortmodule import _are_deterministic_algorithms_enabled + assert _are_deterministic_algorithms_enabled() is torch.are_deterministic_algorithms_enabled() @@ -4372,16 +4638,19 @@ def test_ortmodule_gradient_builder(): class Model(torch.nn.Module): def __init__(self): super(Model, self).__init__() + def forward(self, x): return torch.cos(x) - device = 'cuda' + device = "cuda" - @register_gradient('', 'Cos') + @register_gradient("", "Cos") def Cos_gradient(): - return [('Sin', ['I(0)'], ['Sin_X']), - ('Mul', ['Sin_X', 'GO(0)'], ['Sin_X_Times_dY']), - ('Neg', ['Sin_X_Times_dY'], ['GI(0)'])] + return [ + ("Sin", ["I(0)"], ["Sin_X"]), + ("Mul", ["Sin_X", "GO(0)"], ["Sin_X_Times_dY"]), + ("Neg", ["Sin_X_Times_dY"], ["GI(0)"]), + ] pt_model = Model().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4397,18 +4666,18 @@ def run_step(model, x): pt_prediction = run_step(pt_model, pt_x) ort_prediction = run_step(ort_model, ort_x) _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) - _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad ) + _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad) def test_override_pytorch_exporter_kwargs(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {'custom_opsets': None} + ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"custom_opsets": None} # Make sure model runs without any exception prediction = ort_model(x) @@ -4418,27 +4687,27 @@ def test_override_pytorch_exporter_kwargs(): def test_override_pytorch_exporter_kwargs__invalid(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) ort_model = ORTModule(model) - ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {'verbose': False} + ort_model._torch_module._execution_manager(True)._export_extra_kwargs = {"verbose": False} with pytest.raises(_fallback.ORTModuleONNXModelException) as type_error: _ = ort_model(x) assert "The following PyTorch exporter arguments cannot be specified: '{'verbose'}'." in str(type_error.value) def test_override_pytorch_exporter_kwargs_using_ortmodule_extension__invalid(): - device = 'cuda' + device = "cuda" class ORTModuleExtension(ORTModule): def __init__(self, module, debug_options=None): super().__init__(module, debug_options) for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {'verbose': None} + self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"verbose": None} N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) @@ -4449,15 +4718,16 @@ def __init__(self, module, debug_options=None): _ = ort_model(x) assert "The following PyTorch exporter arguments cannot be specified: '{'verbose'}'." in str(type_error.value) + def test_override_pytorch_exporter_kwargs_using_ortmodule_extension(): - device = 'cuda' + device = "cuda" class ORTModuleExtension(ORTModule): def __init__(self, module, debug_options=None): super().__init__(module, debug_options) # modify GraphExecutionManager internally for training_mode in [False, True]: - self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {'custom_opsets': None} + self._torch_module._execution_manager(training_mode)._export_extra_kwargs = {"custom_opsets": None} N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) @@ -4470,11 +4740,12 @@ def __init__(self, module, debug_options=None): prediction = prediction.sum() prediction.backward() + def test_ortmodule_fused_adam_optimizer_correctness(): torch.manual_seed(8888) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 32, 128, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) @@ -4511,28 +4782,28 @@ def run_optim_step(optimizer): _test_helpers.assert_values_are_close(pt_loss, ort_loss) _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, reset_gradient=False) - if (step+1) % ga_steps == 0: + if (step + 1) % ga_steps == 0: run_optim_step(transformers_adamw_optimizer) run_optim_step(ort_fused_adam_optimizer) for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5) + def test_ortmodule_fused_adam_optimizer_correctness_torch(): torch.manual_seed(8888) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 4, 4, 8, 4 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out).to(device) adamw_optimizer = torch.optim.AdamW(pt_model.parameters(), lr=1e-3) ort_model = ORTModule(copy.deepcopy(pt_model)) - ort_fused_adam_optimizer = FusedAdam(ort_model.parameters(), lr=1e-3, - adam_w_mode=AdamWMode.ADAMW_TORCH, - weight_decay=0.01, - eps=1e-8) + ort_fused_adam_optimizer = FusedAdam( + ort_model.parameters(), lr=1e-3, adam_w_mode=AdamWMode.ADAMW_TORCH, weight_decay=0.01, eps=1e-8 + ) def run_step(model, x): prediction = model(x) @@ -4560,15 +4831,18 @@ def run_optim_step(optimizer): ort_param.grad = copy.deepcopy(pt_param.grad) _test_helpers.assert_values_are_close(pt_loss, ort_loss, atol=1e-4, rtol=1e-5) - _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model, atol=1e-4, rtol=1e-5, reset_gradient=False) + _test_helpers.assert_gradients_match_and_reset_gradient( + ort_model, pt_model, atol=1e-4, rtol=1e-5, reset_gradient=False + ) - if (step+1) % ga_steps == 0: + if (step + 1) % ga_steps == 0: run_optim_step(adamw_optimizer) run_optim_step(ort_fused_adam_optimizer) for pt_param, ort_param in zip(pt_model.parameters(), ort_model.parameters()): _test_helpers.assert_values_are_close(pt_param, ort_param, atol=1e-4, rtol=1e-5) + def test_sigmoid_grad(): class NeuralNetSigmoid(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -4587,7 +4861,8 @@ def run_step(model, x): loss = prediction.sum() loss.backward() return prediction, loss - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 120, 15360, 500, 15360 pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device) @@ -4602,6 +4877,7 @@ def run_step(model, x): _test_helpers.assert_values_are_close(ort_x.grad, pt_x.grad) _test_helpers.assert_values_are_close(ort_loss, pt_loss) + def test_tanh_grad(): class NeuralNetTanh(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): @@ -4620,7 +4896,8 @@ def run_step(model, x): loss = prediction.sum() loss.backward() return prediction, loss - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 120, 1536, 500, 1536 pt_model = NeuralNetTanh(D_in, H, D_out).to(device) @@ -4638,15 +4915,16 @@ def run_step(model, x): def test__defined_from_envvar(): from onnxruntime.training import ortmodule - os.environ['DUMMY_ORTMODULE'] = '15' - assert ortmodule._defined_from_envvar('DUMMY_ORTMODULE', 14) == 15 - os.environ['DUMMY_ORTMODULE'] = '15j' + + os.environ["DUMMY_ORTMODULE"] = "15" + assert ortmodule._defined_from_envvar("DUMMY_ORTMODULE", 14) == 15 + os.environ["DUMMY_ORTMODULE"] = "15j" with warnings.catch_warnings(record=True) as w: - assert ortmodule._defined_from_envvar('DUMMY_ORTMODULE', 14) == 14 + assert ortmodule._defined_from_envvar("DUMMY_ORTMODULE", 14) == 14 assert len(w) == 1 assert issubclass(w[-1].category, UserWarning) assert "Unable to overwrite constant" in str(w[-1].message) - del os.environ['DUMMY_ORTMODULE'] + del os.environ["DUMMY_ORTMODULE"] def test_sigmoid_grad_opset13(): @@ -4667,15 +4945,17 @@ def run_step(model, x): loss = prediction.sum() loss.backward() return prediction, loss - device = 'cuda' + + device = "cuda" N, D_in, H, D_out = 120, 15360, 500, 15360 pt_model = NeuralNetSigmoid(D_in, H, D_out).to(device) from onnxruntime.training import ortmodule + old_opst_cst = ortmodule.ONNX_OPSET_VERSION old_opset = os.getenv("ORTMODULE_ONNX_OPSET_VERSION", None) - os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = '13' + os.environ["ORTMODULE_ONNX_OPSET_VERSION"] = "13" assert ortmodule.ONNX_OPSET_VERSION == 14 ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4687,11 +4967,11 @@ def run_step(model, x): pt_prediction, pt_loss = run_step(pt_model, pt_x) if step == 0: model_onx = ort_model._torch_module._execution_manager._training_manager._onnx_models - for name in ['exported_model', 'optimized_model', 'optimized_pre_grad_model']: + for name in ["exported_model", "optimized_model", "optimized_pre_grad_model"]: onx = getattr(model_onx, name) opv = None for op in onx.opset_import: - if op.domain == '': + if op.domain == "": opv = op.version assert opv == 13 _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) @@ -4705,9 +4985,10 @@ def run_step(model, x): assert ortmodule.ONNX_OPSET_VERSION == 13 ortmodule.ONNX_OPSET_VERSION = old_opst_cst + @pytest.mark.parametrize("opset_version", [12, 13, 14]) def test_opset_version_change(opset_version): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) @@ -4717,7 +4998,8 @@ def test_opset_version_change(opset_version): # Must import a namespace containing ONNX_OPSET_VERSION, not ONNX_OPSET_VERSION directly from onnxruntime.training import ortmodule - ortmodule.ONNX_OPSET_VERSION=opset_version + + ortmodule.ONNX_OPSET_VERSION = opset_version # Make sure model runs without any exception prediction = ort_model(x) @@ -4729,9 +5011,10 @@ def test_opset_version_change(opset_version): exported_model = ort_model._torch_module._execution_manager(ort_model._is_training())._onnx_models.exported_model assert exported_model.opset_import[0].version == opset_version + def test_serialize_ortmodule(): - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = SerializationNet(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4741,8 +5024,7 @@ def test_serialize_ortmodule(): pt_out = pt_model.train_step(x_1) ort_out = ort_model.train_step(x_2) _test_helpers.assert_values_are_close(pt_out, ort_out) - _test_helpers.assert_gradients_match_and_reset_gradient( - ort_model, pt_model) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model, pt_model) pt_out, ort_out = None, None # Serialize ortmodule @@ -4757,8 +5039,7 @@ def test_serialize_ortmodule(): ort_out = ort_model_2.train_step(x_2) assert ort_out is not None _test_helpers.assert_values_are_close(pt_out, ort_out) - _test_helpers.assert_gradients_match_and_reset_gradient( - ort_model_2, pt_model) + _test_helpers.assert_gradients_match_and_reset_gradient(ort_model_2, pt_model) @pytest.mark.parametrize("batch_size, M, N", [(1, 2, 3), (1, 4, 3), (1, 5, 5), (10, 3, 4), (10, 4, 3), (10, 4, 4)]) @@ -4784,7 +5065,7 @@ def run_step(model, x, k): loss.backward() return prediction, loss - device = 'cuda' + device = "cuda" pt_model = NeuralNetTrilu(has_upper, upper).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4798,7 +5079,9 @@ def run_step(model, x, k): _test_helpers.assert_values_are_close(pt_x.grad, ort_x.grad) -@pytest.mark.parametrize("M, N", [(2400, 128), (2400, 256), (2400, 512), (2400, 1024), (2400, 2048), (2400, 4096), (2400, 12800)]) +@pytest.mark.parametrize( + "M, N", [(2400, 128), (2400, 256), (2400, 512), (2400, 1024), (2400, 2048), (2400, 4096), (2400, 12800)] +) def test_softmax(M, N): class NeuralNetSoftmax(torch.nn.Module): def __init__(self): @@ -4814,7 +5097,7 @@ def run_step(model, x): loss.backward() return prediction, loss - device = 'cuda' + device = "cuda" pt_model = NeuralNetSoftmax().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4845,7 +5128,7 @@ def run_step(model, x): loss.backward() return prediction, loss - device = 'cuda' + device = "cuda" pt_model = NeuralNetSoftmax().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -4863,7 +5146,7 @@ def run_step(model, x): def test_random_states_unchanged_for_ortmodule(): import numpy - os.environ['ORTMODULE_FALLBACK_RETRY'] = 'False' + os.environ["ORTMODULE_FALLBACK_RETRY"] = "False" class NeuralNetSlice(torch.nn.Module): def __init__(self): @@ -4872,7 +5155,7 @@ def __init__(self): def forward(self, x): # This slice operation will call sympy.Min() when exporting, which will change Python's random state - return x[:self.dim, :] + return x[: self.dim, :] def random_state_equal(a, b): assert type(a) == type(b) @@ -4897,7 +5180,7 @@ def random_state_equal(a, b): assert random_state_equal(ori_random_states, new_random_states) - del os.environ['ORTMODULE_FALLBACK_RETRY'] + del os.environ["ORTMODULE_FALLBACK_RETRY"] def test_squeeze_custom_symbolic_registry(): @@ -4905,6 +5188,7 @@ class SqueezeModel(torch.nn.Module): def __init__(self): super(SqueezeModel, self).__init__() self.conv = torch.nn.Conv2d(in_channels=3, out_channels=32, kernel_size=14, stride=14, bias=False) + def forward(self, x): x = x.squeeze(1) return self.conv(x) @@ -4915,7 +5199,7 @@ def run_step(model, x): loss.backward() return prediction, loss - device = 'cuda' + device = "cuda" pt_model = SqueezeModel().to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py index a3f118380c6af..388cf36e5312f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd.py @@ -30,9 +30,8 @@ def bias_gelu(bias, y): def bias_gelu_backward(g, bias, y): x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + - 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g class GeLUFunction1(torch.autograd.Function): @staticmethod @@ -50,10 +49,7 @@ class GeLUModel(torch.nn.Module): def __init__(self, output_size): super(GeLUModel, self).__init__() self.relu = GeLUFunction1.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -86,9 +82,8 @@ def bias_gelu(bias, y): def bias_gelu_backward(g, bias, y): x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + - 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g class GeLUFunction2(torch.autograd.Function): @staticmethod @@ -106,10 +101,7 @@ class GeLUModel(torch.nn.Module): def __init__(self, output_size): super(GeLUModel, self).__init__() self.relu = GeLUFunction2.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -148,9 +140,8 @@ def bias_gelu(bias, y): def bias_gelu_backward(g, bias, y): x = bias + y tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + - 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - return ff*g + ff = 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) + return ff * g class GeLUFunction3(torch.autograd.Function): @staticmethod @@ -168,10 +159,7 @@ class GeLUModel(torch.nn.Module): def __init__(self, output_size): super(GeLUModel, self).__init__() self.relu = GeLUFunction3.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -193,6 +181,7 @@ def input_generator(): run_training_test_and_compare(model_builder, input_generator, label_input, run_forward_twice=True) + def test_MegatronF(): # MegatronGFunction is tested in distributed test files. class MegatronFFunction(torch.autograd.Function): @@ -209,10 +198,7 @@ class MegatronFModel(torch.nn.Module): def __init__(self, output_size): super(MegatronFModel, self).__init__() self.copy_ = MegatronFFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -248,7 +234,7 @@ def forward(ctx, input, alpha, beta, gamma): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors alpha = ctx.alpha beta = ctx.beta gamma = ctx.gamma @@ -295,7 +281,7 @@ def forward(ctx, alpha, beta, input, gamma): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors alpha = ctx.alpha beta = ctx.beta gamma = ctx.gamma @@ -330,7 +316,9 @@ def input_generator(): run_training_test_and_compare(model_builder, input_generator, label_input) -@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") +@pytest.mark.skip( + reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)." +) def test_InplaceUpdateInputAsOutputNotRequireGrad(): class InplaceUpdateInputAsOutputNotRequireGradFunction(torch.autograd.Function): @staticmethod @@ -351,10 +339,7 @@ class InplaceUpdateInputAsOutputNotRequireGradModel(torch.nn.Module): def __init__(self, output_size): super(InplaceUpdateInputAsOutputNotRequireGradModel, self).__init__() self.inplace_op = InplaceUpdateInputAsOutputNotRequireGradFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -378,11 +363,12 @@ def input_generator(): label_input = torch.ones([output_size]) # Test when input is in-place updated, but does not require gradient. - run_training_test_and_compare( - model_builder, input_generator, label_input, ignore_grad_compare=True) + run_training_test_and_compare(model_builder, input_generator, label_input, ignore_grad_compare=True) -@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") +@pytest.mark.skip( + reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)." +) def test_InplaceUpdateInputNotAsOutputNotRequireGrad(): class InplaceUpdateInputNotAsOutputNotRequireGradFunction(torch.autograd.Function): @staticmethod @@ -399,10 +385,7 @@ class InplaceUpdateInputNotAsOutputNotRequireGradModel(torch.nn.Module): def __init__(self, output_size): super(InplaceUpdateInputNotAsOutputNotRequireGradModel, self).__init__() self.inplace_op = InplaceUpdateInputNotAsOutputNotRequireGradFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -428,8 +411,7 @@ def input_generator(): # Without mark_ditry, the inner computation graph is extracted into another subgraph, which is a duplicated computation with the PythonOp. # So for the weights that are used twice BUT SHOULD only used once, the gradients are almost 2x than PyTorch's grad, this is the reason we # ignore the gradient compare here. - run_training_test_and_compare( - model_builder, input_generator, label_input, ignore_grad_compare=True) + run_training_test_and_compare(model_builder, input_generator, label_input, ignore_grad_compare=True) def test_InplaceUpdateInputAsOutputNotRequireGradWithMarkDirty(): @@ -450,13 +432,9 @@ def backward(ctx, grad_output): class InplaceUpdateInputAsOutputNotRequireGradWithMarkDirtyModel(torch.nn.Module): def __init__(self, output_size): - super(InplaceUpdateInputAsOutputNotRequireGradWithMarkDirtyModel, - self).__init__() + super(InplaceUpdateInputAsOutputNotRequireGradWithMarkDirtyModel, self).__init__() self.inplace_op = InplaceUpdateInputAsOutputNotRequireGradWithMarkDirtyFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -482,7 +460,9 @@ def input_generator(): run_training_test_and_compare(model_builder, input_generator, label_input) -@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") +@pytest.mark.skip( + reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)." +) def test_InplaceUpdateInputAsOutputRequireGrad(): class InplaceUpdateInputAsOutputRequireGradFunction(torch.autograd.Function): @staticmethod @@ -500,10 +480,7 @@ class InplaceUpdateInputAsOutputRequireGradModel(torch.nn.Module): def __init__(self, output_size): super(InplaceUpdateInputAsOutputRequireGradModel, self).__init__() self.inplace_op = InplaceUpdateInputAsOutputRequireGradFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -532,11 +509,12 @@ def input_generator(): # duplicated computation with the PythonOp. Thus, for the weights that are used twice BUT SHOULD # only used once, the gradients are almost 2x than PyTorch's grad, this is the reason we # ignore the gradient compare here. - run_training_test_and_compare( - model_builder, input_generator, label_input, ignore_grad_compare=True) + run_training_test_and_compare(model_builder, input_generator, label_input, ignore_grad_compare=True) -@pytest.mark.skip(reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...).") +@pytest.mark.skip( + reason="This test is not correct. All tensors modified by in-place operattions should be mark_dirty(...)." +) def test_InplaceUpdateInputNotAsOutputRequireGrad(): class InplaceUpdateInputNotAsOutputRequireGradFunction(torch.autograd.Function): # without mark_ditry, the inner computation graph is extracted into another subgraph, which is a duplicated computation with the PythonOp. @@ -556,10 +534,7 @@ class InplaceUpdateInputNotAsOutputRequireGradModel(torch.nn.Module): def __init__(self, output_size): super(InplaceUpdateInputNotAsOutputRequireGradModel, self).__init__() self.inplace_op = InplaceUpdateInputNotAsOutputRequireGradFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -586,8 +561,8 @@ def input_generator(): # should reuse the input torch tensor @140214095996104, 140212816617984 but actually not." It seems # if we don't have mark_dirty() in auto grad forward, the result is not using the input_, # (maybe a view of it, because data address is same) - run_training_test_and_compare( - model_builder, input_generator, label_input, ignore_grad_compare=True) + run_training_test_and_compare(model_builder, input_generator, label_input, ignore_grad_compare=True) + ########################################################################################## @@ -610,13 +585,9 @@ def backward(ctx, grad_output): class InplaceUpdateInputAsOutputRequireGradWithMarkDirtyModel(torch.nn.Module): def __init__(self, output_size): - super(InplaceUpdateInputAsOutputRequireGradWithMarkDirtyModel, - self).__init__() + super(InplaceUpdateInputAsOutputRequireGradWithMarkDirtyModel, self).__init__() self.inplace_op = InplaceUpdateInputAsOutputRequireGradWithMarkDirtyFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -659,10 +630,7 @@ class EvalTestModel(torch.nn.Module): def __init__(self, output_size): super(EvalTestModel, self).__init__() self.custom_fn = EvalTestFunction.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -687,8 +655,10 @@ def input_generator(): run_evaluate_test_and_compare(model_builder, input_generator, label_input) -@pytest.mark.skipif(torch_version_lower_than("1.10.0"), - reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function') +@pytest.mark.skipif( + torch_version_lower_than("1.10.0"), + reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function", +) def test_TwoOutputFunction(): class TwoOutputFunction1(torch.autograd.Function): @staticmethod @@ -725,10 +695,7 @@ class TwoOutputModel(torch.nn.Module): def __init__(self, output_size): super(TwoOutputModel, self).__init__() self.fun = TwoOutputFunction1.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -776,7 +743,7 @@ def forward(ctx, x, dim, device, use_ort): @staticmethod def backward(ctx, dv): - x, = ctx.saved_tensors + (x,) = ctx.saved_tensors y = x.detach().to(ctx.device) y.requires_grad = True g = None @@ -814,18 +781,20 @@ def get_inner_module_call_result(x, device, use_ort): x = torch.FloatTensor([1.0, -1.0]) # Test indirect ORTModule call from custom function - result_pth = get_inner_module_call_result(x.detach(), 'cuda:0', False) - result_ort = get_inner_module_call_result(x.detach(), 'cuda:0', True) + result_pth = get_inner_module_call_result(x.detach(), "cuda:0", False) + result_ort = get_inner_module_call_result(x.detach(), "cuda:0", True) compare_tensor_list(result_ort, result_pth) # Test indirect ORTModule call from custom function - result_ort = get_inner_module_call_result(x.detach(), 'cpu', True) - result_pth = get_inner_module_call_result(x.detach(), 'cpu', False) + result_ort = get_inner_module_call_result(x.detach(), "cpu", True) + result_pth = get_inner_module_call_result(x.detach(), "cpu", False) compare_tensor_list(result_ort, result_pth) -@pytest.mark.skipif(torch_version_lower_than("1.10.0"), - reason='PyTorch older than 1.10.0 has bugs for exporting multiple output custom function') +@pytest.mark.skipif( + torch_version_lower_than("1.10.0"), + reason="PyTorch older than 1.10.0 has bugs for exporting multiple output custom function", +) def test_Share_Input(): class TwoOutputFunction2(torch.autograd.Function): @staticmethod @@ -847,10 +816,7 @@ class TwoOutputModel(torch.nn.Module): def __init__(self, output_size): super(TwoOutputModel, self).__init__() self.fun = TwoOutputFunction2.apply - self.bias = Parameter(torch.empty( - output_size, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(output_size, device=torch.cuda.current_device(), dtype=torch.float)) with torch.no_grad(): self.bias.uniform_() @@ -877,8 +843,7 @@ def input_generator_with_requires_grad(): # Test multi-input and multi-output custom function. run_training_test_and_compare(model_builder, input_generator, label_input) - run_training_test_and_compare( - model_builder, input_generator_with_requires_grad, label_input) + run_training_test_and_compare(model_builder, input_generator_with_requires_grad, label_input) def test_MultipleStream_InForwardFunction(): @@ -899,7 +864,7 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors return grad_output class MultipleStreamModel(torch.nn.Module): @@ -924,8 +889,9 @@ def input_generator(): label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - run_training_test_and_compare(model_builder, input_generator, label_input, - expected_outputs=[torch.tensor([0.224, 0.272])]) + run_training_test_and_compare( + model_builder, input_generator, label_input, expected_outputs=[torch.tensor([0.224, 0.272])] + ) def test_NonDefaultStream_InForwardFunction1(): @@ -945,7 +911,7 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors return grad_output class MultipleStreamModel(torch.nn.Module): @@ -971,8 +937,9 @@ def input_generator(): label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - run_training_test_and_compare(model_builder, input_generator, label_input, - expected_outputs=[torch.tensor([0.224, 0.272])]) + run_training_test_and_compare( + model_builder, input_generator, label_input, expected_outputs=[torch.tensor([0.224, 0.272])] + ) def test_NonDefaultStream_InForwardFunction2(): @@ -986,7 +953,7 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors return grad_output class MultipleStreamModel(torch.nn.Module): @@ -1017,8 +984,9 @@ def input_generator(): label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - run_training_test_and_compare(model_builder, input_generator, label_input, - expected_outputs=[torch.tensor([0.224, 0.272])]) + run_training_test_and_compare( + model_builder, input_generator, label_input, expected_outputs=[torch.tensor([0.224, 0.272])] + ) def test_NonDefaultStreamInplaceUpdate_InForwardFunction(): @@ -1039,7 +1007,7 @@ def forward(ctx, input): @staticmethod def backward(ctx, grad_output): - input, = ctx.saved_tensors + (input,) = ctx.saved_tensors return grad_output class MultipleStreamModel(torch.nn.Module): @@ -1065,8 +1033,9 @@ def input_generator(): label_input = torch.ones([output_size]) # Test multi-input and multi-output custom function. - run_training_test_and_compare(model_builder, input_generator, label_input, - expected_outputs=[torch.tensor([0.224, 0.272])]) + run_training_test_and_compare( + model_builder, input_generator, label_input, expected_outputs=[torch.tensor([0.224, 0.272])] + ) def test_non_differentiable_autograd_function(): @@ -1094,12 +1063,12 @@ def forward(self, x): return z def run(): - m = Foo().to('cuda') - x = torch.rand((2, 2), dtype=torch.float).to('cuda') + m = Foo().to("cuda") + x = torch.rand((2, 2), dtype=torch.float).to("cuda") # Baseline. y_ref = m(x) - print('Ref:') + print("Ref:") print(y_ref) m = ORTModule(m) @@ -1112,7 +1081,7 @@ def run(): # Training mode. m.train() y_train = m(x) - print('Train:') + print("Train:") assert torch.allclose(y_ref, y_train) run() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py index 3d00fc23716b2..188e053e4711e 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_autograd_dist.py @@ -30,9 +30,9 @@ def reduce(buffer): torch.distributed.all_reduce(buffer) address_for_output_torch_tensor = int(id(buffer)) if address_for_output_torch_tensor != address_for_torch_tensor: - raise ValueError( - "The output torch tensor should reuse the input torch tensor, but actually not.") + raise ValueError("The output torch tensor should reuse the input torch tensor, but actually not.") return buffer + ctx.save_for_backward(arg) ctx.mark_dirty(arg) return reduce(arg) @@ -46,10 +46,7 @@ class ReduceWithMarkDirtyModel(torch.nn.Module): def __init__(self, dim): super(ReduceWithMarkDirtyModel, self).__init__() self.reduce_op_ = ReduceWithMarkDirtyFunction.apply - self.bias = Parameter(torch.empty( - dim, - device=torch.cuda.current_device(), - dtype=torch.float)) + self.bias = Parameter(torch.empty(dim, device=torch.cuda.current_device(), dtype=torch.float)) # Always initialize bias to zero. with torch.no_grad(): @@ -72,7 +69,6 @@ def run_with_pytorch_on_gpu(model, args, rank, device): output.sum().backward() return output, [arg.grad for arg in cuda_args] - def run_with_ort_on_gpu(model, args, rank, device): model.to(device) model = ORTModule(model) @@ -85,14 +81,15 @@ def run_with_ort_on_gpu(model, args, rank, device): return output, [arg.grad for arg in cuda_args] try: - torch.cuda.set_device('cuda:' + str(rank)) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = '29500' - dist.init_process_group(backend='nccl', init_method='tcp://' + os.environ['MASTER_ADDR'] + ':23456', - world_size=size, rank=rank) + torch.cuda.set_device("cuda:" + str(rank)) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "29500" + dist.init_process_group( + backend="nccl", init_method="tcp://" + os.environ["MASTER_ADDR"] + ":23456", world_size=size, rank=rank + ) dim = 32 - device = torch.device('cuda:' + str(rank)) + device = torch.device("cuda:" + str(rank)) x = torch.randn(dim, dtype=torch.float) x.requires_grad = True x_copy = copy.deepcopy(x) @@ -100,13 +97,11 @@ def run_with_ort_on_gpu(model, args, rank, device): torch.cuda.synchronize() - outputs, grads = run_with_pytorch_on_gpu( - m, [x], rank, device) + outputs, grads = run_with_pytorch_on_gpu(m, [x], rank, device) torch.cuda.synchronize() - outputs_ort, grads_ort = run_with_ort_on_gpu( - m, [x_copy], rank, device) + outputs_ort, grads_ort = run_with_ort_on_gpu(m, [x_copy], rank, device) torch.cuda.synchronize() @@ -119,17 +114,18 @@ def run_with_ort_on_gpu(model, args, rank, device): _test_helpers.compare_tensor_list(val_list_a, val_list_b) except Exception as e: print( - f"test_Distributed_ReduceWithMarkDirtyModel fail with rank {rank} with world size {size} with exception: \n{e}.") + f"test_Distributed_ReduceWithMarkDirtyModel fail with rank {rank} with world size {size} with exception: \n{e}." + ) raise e if __name__ == "__main__": size = 2 try: - mp.spawn(test_Distributed_ReduceWithMarkDirtyModel, - nprocs=size, args=(size,)) + mp.spawn(test_Distributed_ReduceWithMarkDirtyModel, nprocs=size, args=(size,)) except: import sys + sys.stdout.flush() sys.stderr.flush() raise diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py index 7524c0f382b5f..8f1d57ff138a8 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier.py @@ -19,6 +19,7 @@ import onnxruntime from onnxruntime.training.ortmodule import ORTModule, DebugOptions + def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # ======================================== # Training @@ -27,7 +28,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128 # Perform one full pass over the training set. - print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch + 1, args.epochs, args.batch_size)) + print("\n======== Epoch {:} / {:} with batch size {:} ========".format(epoch + 1, args.epochs, args.batch_size)) # Measure how long the training epoch takes. t0 = time.time() @@ -72,9 +73,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): # The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification - outputs = model(b_input_ids, - attention_mask=b_input_mask, - labels=b_labels) + outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels) # The call to `model` always returns a tuple, so we need to pull the # loss value out of the tuple. @@ -87,11 +86,14 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): elapsed_time = curr_time - start_time # Report progress. - print(f'Batch {step:4} of {len(train_dataloader):4}. Execution time: {elapsed_time:.4f}. Loss: {loss.item():.4f}') + print( + f"Batch {step:4} of {len(train_dataloader):4}. Execution time: {elapsed_time:.4f}. Loss: {loss.item():.4f}" + ) start_time = curr_time if args.view_graphs: import torchviz + pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters()))) pytorch_backward_graph.view() @@ -124,6 +126,7 @@ def train(model, optimizer, scheduler, train_dataloader, epoch, device, args): print(" Training epoch took: {:.4f}s".format(epoch_time)) return epoch_time + def test(model, validation_dataloader, device, args): # ======================================== # Validation @@ -164,9 +167,7 @@ def test(model, validation_dataloader, device, args): # TODO: original sample had the last argument equal to None, but b_labels is because model was # exported using 3 inputs for training, so validation must follow. # Another approach would be checkpoint the trained model, re-export the model for validation with the checkpoint - outputs = model(b_input_ids, - attention_mask=b_input_mask, - labels=b_labels) + outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels) # Get the "logits" output by the model. The "logits" are the output # values prior to applying an activation function like the softmax. @@ -174,7 +175,7 @@ def test(model, validation_dataloader, device, args): # Move logits and labels to CPU logits = logits.detach().cpu().numpy() - label_ids = b_labels.to('cpu').numpy() + label_ids = b_labels.to("cpu").numpy() # Calculate the accuracy for this batch of test sentences. tmp_eval_accuracy = flat_accuracy(logits, label_ids) @@ -187,34 +188,40 @@ def test(model, validation_dataloader, device, args): # Report the final accuracy for this validation run. epoch_time = time.time() - t0 - accuracy = eval_accuracy/nb_eval_steps + accuracy = eval_accuracy / nb_eval_steps print(" Accuracy: {0:.2f}".format(accuracy)) print(" Validation took: {:.4f}s".format(epoch_time)) return epoch_time, accuracy + def load_dataset(args): # 2. Loading CoLA Dataset def _download_dataset(download_dir): if not os.path.exists(download_dir): # Download the file (if we haven't already) - print('Downloading dataset...') - url = 'https://nyu-mll.github.io/CoLA/cola_public_1.1.zip' - wget.download(url, './cola_public_1.1.zip') + print("Downloading dataset...") + url = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip" + wget.download(url, "./cola_public_1.1.zip") else: - print('Reusing cached dataset') + print("Reusing cached dataset") if not os.path.exists(args.data_dir): - _download_dataset('./cola_public_1.1.zip') + _download_dataset("./cola_public_1.1.zip") # Unzip it - print('Extracting dataset') - with zipfile.ZipFile('./cola_public_1.1.zip', 'r') as zip_ref: - zip_ref.extractall('./') + print("Extracting dataset") + with zipfile.ZipFile("./cola_public_1.1.zip", "r") as zip_ref: + zip_ref.extractall("./") else: - print('Reusing extracted dataset') + print("Reusing extracted dataset") # Load the dataset into a pandas dataframe. - df = pd.read_csv(os.path.join(args.data_dir, "in_domain_train.tsv"), delimiter='\t', header=None, names=['sentence_source', 'label', 'label_notes', 'sentence']) + df = pd.read_csv( + os.path.join(args.data_dir, "in_domain_train.tsv"), + delimiter="\t", + header=None, + names=["sentence_source", "label", "label_notes", "sentence"], + ) # Get the lists of sentences and their labels. sentences = df.sentence.values @@ -223,7 +230,7 @@ def _download_dataset(download_dir): # 3. Tokenization & Input Formatting # Load the BERT tokenizer. - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) # Set the max length of encoded sentence. # 64 is slightly larger than the maximum training sentence length of 47... @@ -238,13 +245,13 @@ def _download_dataset(download_dir): # (3) Append the `[SEP]` token to the end. # (4) Map tokens to their IDs. encoded_sent = tokenizer.encode( - sent, # Sentence to encode. - add_special_tokens = True, # Add '[CLS]' and '[SEP]' - ) + sent, # Sentence to encode. + add_special_tokens=True, # Add '[CLS]' and '[SEP]' + ) # Pad our input tokens with value 0. if len(encoded_sent) < MAX_LEN: - encoded_sent.extend([0]*(MAX_LEN-len(encoded_sent))) + encoded_sent.extend([0] * (MAX_LEN - len(encoded_sent))) # Truncate to MAX_LEN if len(encoded_sent) > MAX_LEN: @@ -269,11 +276,11 @@ def _download_dataset(download_dir): attention_masks.append(att_mask) # Use 90% for training and 10% for validation. - train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, - random_state=2018, test_size=0.1) + train_inputs, validation_inputs, train_labels, validation_labels = train_test_split( + input_ids, labels, random_state=2018, test_size=0.1 + ) # Do the same for the masks. - train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, - random_state=2018, test_size=0.1) + train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1) # Convert all inputs and labels into torch tensors, the required datatype # for our model. @@ -298,65 +305,84 @@ def _download_dataset(download_dir): return train_dataloader, validation_dataloader + # Function to calculate the accuracy of our predictions vs labels def flat_accuracy(preds, labels): pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() return np.sum(pred_flat == labels_flat) / len(labels_flat) + def format_time(elapsed): - '''Takes a time in seconds and returns a string hh:mm:ss''' + """Takes a time in seconds and returns a string hh:mm:ss""" # Round to the nearest second. elapsed_rounded = int(round((elapsed))) # Format as hh:mm:ss return str(datetime.timedelta(seconds=elapsed_rounded)) + def main(): # 1. Basic setup - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--pytorch-only', action='store_true', default=False, - help='disables ONNX Runtime training') - parser.add_argument('--batch-size', type=int, default=32, metavar='N', - help='input batch size for training (default: 32)') - parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', - help='input batch size for testing (default: 64)') - parser.add_argument('--view-graphs', action='store_true', default=False, - help='views forward and backward graphs') - parser.add_argument('--export-onnx-graphs', action='store_true', default=False, - help='export ONNX graphs to current directory') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--epochs', type=int, default=4, metavar='N', - help='number of epochs to train (default: 4)') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--log-interval', type=int, default=40, metavar='N', - help='how many batches to wait before logging training status (default: 40)') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', - help='Log level (default: WARNING)') - parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H', - help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)') - parser.add_argument('--data-dir', type=str, default='./cola_public/raw', - help='Path to the bert data directory') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument("--pytorch-only", action="store_true", default=False, help="disables ONNX Runtime training") + parser.add_argument( + "--batch-size", type=int, default=32, metavar="N", help="input batch size for training (default: 32)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=64, metavar="N", help="input batch size for testing (default: 64)" + ) + parser.add_argument("--view-graphs", action="store_true", default=False, help="views forward and backward graphs") + parser.add_argument( + "--export-onnx-graphs", action="store_true", default=False, help="export ONNX graphs to current directory" + ) + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--epochs", type=int, default=4, metavar="N", help="number of epochs to train (default: 4)") + parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)") + parser.add_argument( + "--log-interval", + type=int, + default=40, + metavar="N", + help="how many batches to wait before logging training status (default: 40)", + ) + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument( + "--log-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="WARNING", + help="Log level (default: WARNING)", + ) + parser.add_argument( + "--num-hidden-layers", + type=int, + default=1, + metavar="H", + help="Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)", + ) + parser.add_argument("--data-dir", type=str, default="./cola_public/raw", help="Path to the bert data directory") args = parser.parse_args() # Device (CPU vs CUDA) if torch.cuda.is_available() and not args.no_cuda: device = torch.device("cuda") - print('There are %d GPU(s) available.' % torch.cuda.device_count()) - print('We will use the GPU:', torch.cuda.get_device_name(0)) + print("There are %d GPU(s) available." % torch.cuda.device_count()) + print("We will use the GPU:", torch.cuda.get_device_name(0)) else: - print('No GPU available, using the CPU instead.') + print("No GPU available, using the CPU instead.") device = torch.device("cpu") # Set log level numeric_level = getattr(logging, args.log_level.upper(), None) if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % args.log_level) + raise ValueError("Invalid log level: %s" % args.log_level) logging.basicConfig(level=numeric_level) # 2. Dataloader @@ -366,20 +392,20 @@ def main(): # Load BertForSequenceClassification, the pretrained BERT model with a single # linear classification layer on top. config = AutoConfig.from_pretrained( - "bert-base-uncased", - num_labels=2, - num_hidden_layers=args.num_hidden_layers, - output_attentions = False, # Whether the model returns attentions weights. - output_hidden_states = False, # Whether the model returns all hidden-states. + "bert-base-uncased", + num_labels=2, + num_hidden_layers=args.num_hidden_layers, + output_attentions=False, # Whether the model returns attentions weights. + output_hidden_states=False, # Whether the model returns all hidden-states. ) model = BertForSequenceClassification.from_pretrained( - "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab. + "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab. config=config, ) if not args.pytorch_only: # Just for future debugging - debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix='BertForSequenceClassification') + debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix="BertForSequenceClassification") model = ORTModule(model, debug_options) @@ -388,19 +414,20 @@ def main(): model.cuda() # Note: AdamW is a class from the huggingface library (as opposed to pytorch) - optimizer = AdamW(model.parameters(), - lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5 - eps = 1e-8 # args.adam_epsilon - default is 1e-8. - ) + optimizer = AdamW( + model.parameters(), + lr=2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5 + eps=1e-8, # args.adam_epsilon - default is 1e-8. + ) # Authors recommend between 2 and 4 epochs # Total number of training steps is number of batches * number of epochs. total_steps = len(train_dataloader) * args.epochs # Create the learning rate scheduler. - scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps = 0, # Default value in run_glue.py - num_training_steps = total_steps) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=0, num_training_steps=total_steps # Default value in run_glue.py + ) # Seed random.seed(args.seed) np.random.seed(args.seed) @@ -420,11 +447,11 @@ def main(): assert validation_accuracy > 0.5 - print('\n======== Global stats ========') + print("\n======== Global stats ========") if not args.pytorch_only: estimated_export = 0 if args.epochs > 1: - estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1) + estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1) print(" Estimated ONNX export took: {:.4f}s".format(estimated_export)) else: print(" Estimated ONNX export took: Estimate available when epochs > 1 only") @@ -432,5 +459,6 @@ def main(): print(" Accumulated training took: {:.4f}s".format(total_training_time)) print(" Accumulated validation took: {:.4f}s".format(total_test_time)) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py index dcbe45ff190dd..42697766c9815 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_bert_classifier_autocast.py @@ -19,6 +19,7 @@ import onnxruntime from onnxruntime.training.ortmodule import ORTModule, DebugOptions + def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, args): # ======================================== # Training @@ -27,7 +28,7 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, # https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128 # Perform one full pass over the training set. - print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch + 1, args.epochs, args.batch_size)) + print("\n======== Epoch {:} / {:} with batch size {:} ========".format(epoch + 1, args.epochs, args.batch_size)) # Measure how long the training epoch takes. t0 = time.time() @@ -72,9 +73,7 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, # The documentation for this `model` function is here: # https://huggingface.co/transformers/v2.2.0/model_doc/bert.html#transformers.BertForSequenceClassification with torch.cuda.amp.autocast(): - outputs = model(b_input_ids, - attention_mask=b_input_mask, - labels=b_labels) + outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels) # The call to `model` always returns a tuple, so we need to pull the # loss value out of the tuple. @@ -87,11 +86,14 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, elapsed_time = curr_time - start_time # Report progress. - print(f'Batch {step:4} of {len(train_dataloader):4}. Execution time: {elapsed_time:.4f}. Loss: {loss.item():.4f}') + print( + f"Batch {step:4} of {len(train_dataloader):4}. Execution time: {elapsed_time:.4f}. Loss: {loss.item():.4f}" + ) start_time = curr_time if args.view_graphs: import torchviz + pytorch_backward_graph = torchviz.make_dot(outputs[0], params=dict(list(model.named_parameters()))) pytorch_backward_graph.view() @@ -127,6 +129,7 @@ def train(model, optimizer, scaler, scheduler, train_dataloader, epoch, device, print(" Training epoch took: {:.4f}s".format(epoch_time)) return epoch_time + def test(model, validation_dataloader, device, args): # ======================================== # Validation @@ -167,9 +170,7 @@ def test(model, validation_dataloader, device, args): # TODO: original sample had the last argument equal to None, but b_labels is because model was # exported using 3 inputs for training, so validation must follow. # Another approach would be checkpoint the trained model, re-export the model for validation with the checkpoint - outputs = model(b_input_ids, - attention_mask=b_input_mask, - labels=b_labels) + outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels) # Get the "logits" output by the model. The "logits" are the output # values prior to applying an activation function like the softmax. @@ -177,7 +178,7 @@ def test(model, validation_dataloader, device, args): # Move logits and labels to CPU logits = logits.detach().cpu().numpy() - label_ids = b_labels.to('cpu').numpy() + label_ids = b_labels.to("cpu").numpy() # Calculate the accuracy for this batch of test sentences. tmp_eval_accuracy = flat_accuracy(logits, label_ids) @@ -190,34 +191,40 @@ def test(model, validation_dataloader, device, args): # Report the final accuracy for this validation run. epoch_time = time.time() - t0 - accuracy = eval_accuracy/nb_eval_steps + accuracy = eval_accuracy / nb_eval_steps print(" Accuracy: {0:.2f}".format(accuracy)) print(" Validation took: {:.4f}s".format(epoch_time)) return epoch_time, accuracy + def load_dataset(args): # 2. Loading CoLA Dataset def _download_dataset(download_dir): if not os.path.exists(download_dir): # Download the file (if we haven't already) - print('Downloading dataset...') - url = 'https://nyu-mll.github.io/CoLA/cola_public_1.1.zip' - wget.download(url, './cola_public_1.1.zip') + print("Downloading dataset...") + url = "https://nyu-mll.github.io/CoLA/cola_public_1.1.zip" + wget.download(url, "./cola_public_1.1.zip") else: - print('Reusing cached dataset') + print("Reusing cached dataset") if not os.path.exists(args.data_dir): - _download_dataset('./cola_public_1.1.zip') + _download_dataset("./cola_public_1.1.zip") # Unzip it - print('Extracting dataset') - with zipfile.ZipFile('./cola_public_1.1.zip', 'r') as zip_ref: - zip_ref.extractall('./') + print("Extracting dataset") + with zipfile.ZipFile("./cola_public_1.1.zip", "r") as zip_ref: + zip_ref.extractall("./") else: - print('Reusing extracted dataset') + print("Reusing extracted dataset") # Load the dataset into a pandas dataframe. - df = pd.read_csv(os.path.join(args.data_dir, "in_domain_train.tsv"), delimiter='\t', header=None, names=['sentence_source', 'label', 'label_notes', 'sentence']) + df = pd.read_csv( + os.path.join(args.data_dir, "in_domain_train.tsv"), + delimiter="\t", + header=None, + names=["sentence_source", "label", "label_notes", "sentence"], + ) # Get the lists of sentences and their labels. sentences = df.sentence.values @@ -226,7 +233,7 @@ def _download_dataset(download_dir): # 3. Tokenization & Input Formatting # Load the BERT tokenizer. - tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True) + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True) # Set the max length of encoded sentence. # 64 is slightly larger than the maximum training sentence length of 47... @@ -241,13 +248,13 @@ def _download_dataset(download_dir): # (3) Append the `[SEP]` token to the end. # (4) Map tokens to their IDs. encoded_sent = tokenizer.encode( - sent, # Sentence to encode. - add_special_tokens = True, # Add '[CLS]' and '[SEP]' - ) + sent, # Sentence to encode. + add_special_tokens=True, # Add '[CLS]' and '[SEP]' + ) # Pad our input tokens with value 0. if len(encoded_sent) < MAX_LEN: - encoded_sent.extend([0]*(MAX_LEN-len(encoded_sent))) + encoded_sent.extend([0] * (MAX_LEN - len(encoded_sent))) # Truncate to MAX_LEN if len(encoded_sent) > MAX_LEN: @@ -272,11 +279,11 @@ def _download_dataset(download_dir): attention_masks.append(att_mask) # Use 90% for training and 10% for validation. - train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, - random_state=2018, test_size=0.1) + train_inputs, validation_inputs, train_labels, validation_labels = train_test_split( + input_ids, labels, random_state=2018, test_size=0.1 + ) # Do the same for the masks. - train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, - random_state=2018, test_size=0.1) + train_masks, validation_masks, _, _ = train_test_split(attention_masks, labels, random_state=2018, test_size=0.1) # Convert all inputs and labels into torch tensors, the required datatype # for our model. @@ -301,69 +308,86 @@ def _download_dataset(download_dir): return train_dataloader, validation_dataloader + # Function to calculate the accuracy of our predictions vs labels def flat_accuracy(preds, labels): pred_flat = np.argmax(preds, axis=1).flatten() labels_flat = labels.flatten() return np.sum(pred_flat == labels_flat) / len(labels_flat) + def format_time(elapsed): - '''Takes a time in seconds and returns a string hh:mm:ss''' + """Takes a time in seconds and returns a string hh:mm:ss""" # Round to the nearest second. elapsed_rounded = int(round((elapsed))) # Format as hh:mm:ss return str(datetime.timedelta(seconds=elapsed_rounded)) + def main(): # 1. Basic setup - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--pytorch-only', action='store_true', default=False, - help='disables ONNX Runtime training') - parser.add_argument('--batch-size', type=int, default=32, metavar='N', - help='input batch size for training (default: 32)') - parser.add_argument('--do-val', action='store_true', default=False, - help='disables validation') - parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', - help='input batch size for testing (default: 64)') - parser.add_argument('--view-graphs', action='store_true', default=False, - help='views forward and backward graphs') - parser.add_argument('--export-onnx-graphs', action='store_true', default=False, - help='export ONNX graphs to current directory') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--epochs', type=int, default=4, metavar='N', - help='number of epochs to train (default: 4)') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--log-interval', type=int, default=40, metavar='N', - help='how many batches to wait before logging training status (default: 40)') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', - help='Log level (default: WARNING)') - parser.add_argument('--num-hidden-layers', type=int, default=1, metavar='H', - help='Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)') - parser.add_argument('--data-dir', type=str, default='./cola_public/raw', - help='Path to the bert data directory') - parser.add_argument('--grad-acc-steps', type=int, default=2, - help='Number of steps for accumulating gradients') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument("--pytorch-only", action="store_true", default=False, help="disables ONNX Runtime training") + parser.add_argument( + "--batch-size", type=int, default=32, metavar="N", help="input batch size for training (default: 32)" + ) + parser.add_argument("--do-val", action="store_true", default=False, help="disables validation") + parser.add_argument( + "--test-batch-size", type=int, default=64, metavar="N", help="input batch size for testing (default: 64)" + ) + parser.add_argument("--view-graphs", action="store_true", default=False, help="views forward and backward graphs") + parser.add_argument( + "--export-onnx-graphs", action="store_true", default=False, help="export ONNX graphs to current directory" + ) + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--epochs", type=int, default=4, metavar="N", help="number of epochs to train (default: 4)") + parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)") + parser.add_argument( + "--log-interval", + type=int, + default=40, + metavar="N", + help="how many batches to wait before logging training status (default: 40)", + ) + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument( + "--log-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="WARNING", + help="Log level (default: WARNING)", + ) + parser.add_argument( + "--num-hidden-layers", + type=int, + default=1, + metavar="H", + help="Number of hidden layers for the BERT model. A vanila BERT has 12 hidden layers (default: 1)", + ) + parser.add_argument("--data-dir", type=str, default="./cola_public/raw", help="Path to the bert data directory") + parser.add_argument("--grad-acc-steps", type=int, default=2, help="Number of steps for accumulating gradients") args = parser.parse_args() # Device (CPU vs CUDA) if torch.cuda.is_available() and not args.no_cuda: device = torch.device("cuda") - print('There are %d GPU(s) available.' % torch.cuda.device_count()) - print('We will use the GPU:', torch.cuda.get_device_name(0)) + print("There are %d GPU(s) available." % torch.cuda.device_count()) + print("We will use the GPU:", torch.cuda.get_device_name(0)) else: - print('No GPU available, using the CPU instead.') + print("No GPU available, using the CPU instead.") device = torch.device("cpu") # Set log level numeric_level = getattr(logging, args.log_level.upper(), None) if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % args.log_level) + raise ValueError("Invalid log level: %s" % args.log_level) logging.basicConfig(level=numeric_level) # 2. Dataloader @@ -373,26 +397,29 @@ def main(): # Load BertForSequenceClassification, the pretrained BERT model with a single # linear classification layer on top. config = AutoConfig.from_pretrained( - "bert-base-uncased", - num_labels=2, - num_hidden_layers=args.num_hidden_layers, - output_attentions = False, # Whether the model returns attentions weights. - output_hidden_states = False, # Whether the model returns all hidden-states. + "bert-base-uncased", + num_labels=2, + num_hidden_layers=args.num_hidden_layers, + output_attentions=False, # Whether the model returns attentions weights. + output_hidden_states=False, # Whether the model returns all hidden-states. ) model = BertForSequenceClassification.from_pretrained( - "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab. + "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab. config=config, ) # Note: AdamW is a class from the huggingface library (as opposed to pytorch) - optimizer = torch.optim.AdamW(model.parameters(), - lr = 2e-2, # args.learning_rate - default is 5e-5, our notebook had 2e-5 - eps = 1e-8 # args.adam_epsilon - default is 1e-8. - ) + optimizer = torch.optim.AdamW( + model.parameters(), + lr=2e-2, # args.learning_rate - default is 5e-5, our notebook had 2e-5 + eps=1e-8, # args.adam_epsilon - default is 1e-8. + ) if not args.pytorch_only: # Just for future debugging - debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix='BertForSequenceClassificationAutoCast') + debug_options = DebugOptions( + save_onnx=args.export_onnx_graphs, onnx_prefix="BertForSequenceClassificationAutoCast" + ) model = ORTModule(model, debug_options) model._torch_module._execution_manager(is_training=True)._enable_grad_acc_optimization = True @@ -401,16 +428,14 @@ def main(): if torch.cuda.is_available() and not args.no_cuda: model.cuda() - - # Authors recommend between 2 and 4 epochs # Total number of training steps is number of batches * number of epochs. total_steps = len(train_dataloader) * args.epochs # Create the learning rate scheduler. - scheduler = get_linear_schedule_with_warmup(optimizer, - num_warmup_steps = 0, # Default value in run_glue.py - num_training_steps = total_steps) + scheduler = get_linear_schedule_with_warmup( + optimizer, num_warmup_steps=0, num_training_steps=total_steps # Default value in run_glue.py + ) scaler = torch.cuda.amp.GradScaler() # Seed @@ -434,11 +459,11 @@ def main(): if args.do_val: assert validation_accuracy > 0.5 - print('\n======== Global stats ========') + print("\n======== Global stats ========") if not args.pytorch_only: estimated_export = 0 if args.epochs > 1: - estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1) + estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1) print(" Estimated ONNX export took: {:.4f}s".format(estimated_export)) else: print(" Estimated ONNX export took: Estimate available when epochs > 1 only") @@ -446,5 +471,6 @@ def main(): print(" Accumulated training took: {:.4f}s".format(total_training_time)) print(" Accumulated validation took: {:.4f}s".format(total_test_time)) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py index d61863867f9fe..21ebdb52037d4 100755 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_pipeline_parallel.py @@ -12,45 +12,36 @@ # USAGE: # pip install deepspeed # deepspeed orttraining_test_ortmodule_deepspeed_pipeline_parallel.py --deepspeed_config=orttraining_test_ortmodule_deepspeed_pipeline_parallel_config.json --pipeline-parallel-size 2 --steps=100 -# expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694 +# expected output : steps: 100 loss: 0.0585 iter time (s): 0.186 samples/sec: 53.694 + class SampleData(torch.utils.data.Dataset): - def __init__(self,x,y): + def __init__(self, x, y): self.x = x self.y = y + def __len__(self): return x.size()[0] + def __getitem__(self, idx): - return self.x[idx],self.y[idx] + return self.x[idx], self.y[idx] + def get_args(): parser = argparse.ArgumentParser() - parser.add_argument('--local_rank', - type=int, - default=-1, - help='local rank passed from distributed launcher') - parser.add_argument('-s', - '--steps', - type=int, - default=100, - help='quit after this many steps') - parser.add_argument('-p', - '--pipeline-parallel-size', - type=int, - default=2, - help='pipeline parallelism') - parser.add_argument('--backend', - type=str, - default='nccl', - help='distributed backend') - parser.add_argument('--seed', type=int, default=0, help='PRNG seed') - parser.add_argument('--fp16',type=bool,default=False,help='fp16 run') - parser.add_argument('--run_without_ort',type=bool,default=False,help='onlydeepspeed run') - + parser.add_argument("--local_rank", type=int, default=-1, help="local rank passed from distributed launcher") + parser.add_argument("-s", "--steps", type=int, default=100, help="quit after this many steps") + parser.add_argument("-p", "--pipeline-parallel-size", type=int, default=2, help="pipeline parallelism") + parser.add_argument("--backend", type=str, default="nccl", help="distributed backend") + parser.add_argument("--seed", type=int, default=0, help="PRNG seed") + parser.add_argument("--fp16", type=bool, default=False, help="fp16 run") + parser.add_argument("--run_without_ort", type=bool, default=False, help="onlydeepspeed run") + parser = deepspeed.add_config_arguments(parser) args = parser.parse_args() return args + n = 10 d_in = 4 d_hidden = 8 @@ -60,9 +51,9 @@ def get_args(): device = torch.device("cuda", args.local_rank) if args.run_without_ort: - print('Running deepspeed pipeline parallel module without ORTModule') + print("Running deepspeed pipeline parallel module without ORTModule") else: - print('Running deepspeed pipeline parallel module with ORTModule') + print("Running deepspeed pipeline parallel module with ORTModule") dist.init_process_group(backend=args.backend) @@ -71,33 +62,40 @@ def get_args(): if args.run_without_ort: model = nn.Sequential( - nn.Linear(d_in, d_hidden), # Stage 1 - nn.ReLU(), # Stage 1 - nn.Linear(d_hidden, d_hidden), # Stage 1 - nn.ReLU(), # Stage 1 - nn.Linear(d_hidden, d_hidden), # Stage 2 - nn.ReLU(), # Stage 2 - nn.Linear(d_hidden, d_out) # Stage 2 + nn.Linear(d_in, d_hidden), # Stage 1 + nn.ReLU(), # Stage 1 + nn.Linear(d_hidden, d_hidden), # Stage 1 + nn.ReLU(), # Stage 1 + nn.Linear(d_hidden, d_hidden), # Stage 2 + nn.ReLU(), # Stage 2 + nn.Linear(d_hidden, d_out), # Stage 2 ) else: model = nn.Sequential( - ORTModule(nn.Linear(d_in, d_hidden).to(device)), # Stage 1 - nn.ReLU().to(device), # ORTModule(nn.ReLU().to(device)), Stage 1, TODO: ORTModule can wrap Relu once stateless model is supported. - ORTModule(nn.Linear(d_hidden, d_hidden).to(device)), # Stage 1 - nn.ReLU().to(device), # ORTModule(nn.ReLU().to(device)), Stage 1, TODO: ORTModule can wrap Relu once stateless model is supported. - ORTModule(nn.Linear(d_hidden, d_hidden).to(device)), # Stage 2 - nn.ReLU().to(device), # ORTModule(nn.ReLU().to(device)), Stage 2, TODO: ORTModule can wrap Relu once stateless model is supported. - ORTModule(nn.Linear(d_hidden, d_out).to(device)) # Stage 2 + ORTModule(nn.Linear(d_in, d_hidden).to(device)), # Stage 1 + nn.ReLU().to( + device + ), # ORTModule(nn.ReLU().to(device)), Stage 1, TODO: ORTModule can wrap Relu once stateless model is supported. + ORTModule(nn.Linear(d_hidden, d_hidden).to(device)), # Stage 1 + nn.ReLU().to( + device + ), # ORTModule(nn.ReLU().to(device)), Stage 1, TODO: ORTModule can wrap Relu once stateless model is supported. + ORTModule(nn.Linear(d_hidden, d_hidden).to(device)), # Stage 2 + nn.ReLU().to( + device + ), # ORTModule(nn.ReLU().to(device)), Stage 2, TODO: ORTModule can wrap Relu once stateless model is supported. + ORTModule(nn.Linear(d_hidden, d_out).to(device)), # Stage 2 ) -model = PipelineModule(layers=model, - loss_fn=torch.nn.CrossEntropyLoss(), - num_stages=args.pipeline_parallel_size, - partition_method='uniform', #'parameters', - activation_checkpoint_interval=0 - ) +model = PipelineModule( + layers=model, + loss_fn=torch.nn.CrossEntropyLoss(), + num_stages=args.pipeline_parallel_size, + partition_method="uniform", #'parameters', + activation_checkpoint_interval=0, +) params = [p for p in model.parameters() if p.requires_grad] @@ -107,16 +105,14 @@ def get_args(): x = x.half() # Output. y = torch.randint(0, d_out, (n,)) -ds = SampleData(x,y) +ds = SampleData(x, y) print("Initialize deepspeed") -model_engine, optimizer, _, _ = deepspeed.initialize(args=args, - model=model, - model_parameters=params, - training_data=ds #(x,y)# - ) +model_engine, optimizer, _, _ = deepspeed.initialize( + args=args, model=model, model_parameters=params, training_data=ds # (x,y)# +) for step in range(args.steps): loss = model_engine.train_batch() if step % 10 == 0: - print("step = ", step, ", loss = ",loss) + print("step = ", step, ", loss = ", loss) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py index fde4397216da9..037558663b428 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_deepspeed_zero_stage_1.py @@ -19,6 +19,7 @@ import deepspeed + class NeuralNet(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() @@ -35,8 +36,11 @@ def forward(self, input1): def train(args, model, device, optimizer, loss_fn, train_loader, epoch): - print('\n======== Epoch {:} / {:} with batch size {:} ========'.format( - epoch+1, args.epochs, model.train_batch_size())) + print( + "\n======== Epoch {:} / {:} with batch size {:} ========".format( + epoch + 1, args.epochs, model.train_batch_size() + ) + ) model.train() # Measure how long the training epoch takes. t0 = time.time() @@ -54,6 +58,7 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch): if args.view_graphs: import torchviz + pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters()))) pytorch_backward_graph.view() @@ -71,9 +76,15 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch): if iteration % args.log_interval == 0: curr_time = time.time() elapsed_time = curr_time - start_time - print('[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}'.format( - iteration * len(data), len(train_loader.dataset), - 100. * iteration / len(train_loader), loss, elapsed_time)) + print( + "[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}".format( + iteration * len(data), + len(train_loader.dataset), + 100.0 * iteration / len(train_loader), + loss, + elapsed_time, + ) + ) start_time = curr_time # Calculate the average loss over the training data. @@ -103,61 +114,75 @@ def test(args, model, device, loss_fn, test_loader): pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - args.test_batch_size, test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + args.test_batch_size, + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) # Report the final accuracy for this validation run. epoch_time = time.time() - t0 - print(" Accuracy: {0:.2f}".format(float(correct)/len(test_loader.dataset))) + print(" Accuracy: {0:.2f}".format(float(correct) / len(test_loader.dataset))) print(" Validation took: {:.4f}s".format(epoch_time)) return epoch_time + def my_loss(x, target, is_train=True): if is_train: return torch.nn.CrossEntropyLoss()(x, target) else: - return torch.nn.CrossEntropyLoss(reduction='sum')(x, target) + return torch.nn.CrossEntropyLoss(reduction="sum")(x, target) + def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--batch-size', type=int, default=32, metavar='N', - help='input batch size for training (default: 32)') - parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', - help='input batch size for testing (default: 64)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--pytorch-only', action='store_true', default=False, - help='disables ONNX Runtime training') - parser.add_argument('--log-interval', type=int, default=300, metavar='N', - help='how many batches to wait before logging training status (default: 300)') - parser.add_argument('--view-graphs', action='store_true', default=False, - help='views forward and backward graphs') - parser.add_argument('--export-onnx-graphs', action='store_true', default=False, - help='export ONNX graphs to current directory') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', - help='Log level (default: WARNING)') - parser.add_argument('--data-dir', type=str, default='./mnist', - help='Path to the mnist data directory') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument( + "--batch-size", type=int, default=32, metavar="N", help="input batch size for training (default: 32)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=64, metavar="N", help="input batch size for testing (default: 64)" + ) + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)") + parser.add_argument("--pytorch-only", action="store_true", default=False, help="disables ONNX Runtime training") + parser.add_argument( + "--log-interval", + type=int, + default=300, + metavar="N", + help="how many batches to wait before logging training status (default: 300)", + ) + parser.add_argument("--view-graphs", action="store_true", default=False, help="views forward and backward graphs") + parser.add_argument( + "--export-onnx-graphs", action="store_true", default=False, help="export ONNX graphs to current directory" + ) + parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") + parser.add_argument( + "--log-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="WARNING", + help="Log level (default: WARNING)", + ) + parser.add_argument("--data-dir", type=str, default="./mnist", help="Path to the mnist data directory") # DeepSpeed-related settings - parser.add_argument('--local_rank', - type=int, - required=True, - help='local rank passed from distributed launcher') + parser.add_argument("--local_rank", type=int, required=True, help="local rank passed from distributed launcher") parser = deepspeed.add_config_arguments(parser) - - args = parser.parse_args() + args = parser.parse_args() # Common setup torch.manual_seed(args.seed) @@ -171,48 +196,59 @@ def main(): ## Data loader - dist.init_process_group(backend='nccl') + dist.init_process_group(backend="nccl") if args.local_rank == 0: # download only once on rank 0 datasets.MNIST(args.data_dir, download=True) dist.barrier() - train_set = datasets.MNIST(args.data_dir, train=True, - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])) + train_set = datasets.MNIST( + args.data_dir, + train=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ) test_loader = None if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( - datasets.MNIST(args.data_dir, train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True) + datasets.MNIST( + args.data_dir, + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) # Model architecture model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device) if not args.pytorch_only: - print('Training MNIST on ORTModule....') + print("Training MNIST on ORTModule....") # Set log level - log_level_mapping = {"DEBUG": LogLevel.VERBOSE, - "INFO": LogLevel.INFO, - "WARNING": LogLevel.WARNING, - "ERROR": LogLevel.ERROR, - "CRITICAL": LogLevel.FATAL} + log_level_mapping = { + "DEBUG": LogLevel.VERBOSE, + "INFO": LogLevel.INFO, + "WARNING": LogLevel.WARNING, + "ERROR": LogLevel.ERROR, + "CRITICAL": LogLevel.FATAL, + } log_level = log_level_mapping.get(args.log_level.upper(), None) if not isinstance(log_level, LogLevel): - raise ValueError('Invalid log level: %s' % args.log_level) - debug_options = DebugOptions(log_level=log_level, save_onnx=args.export_onnx_graphs, onnx_prefix='MNIST') + raise ValueError("Invalid log level: %s" % args.log_level) + debug_options = DebugOptions(log_level=log_level, save_onnx=args.export_onnx_graphs, onnx_prefix="MNIST") model = ORTModule(model, debug_options) else: - print('Training MNIST on vanilla PyTorch....') + print("Training MNIST on vanilla PyTorch....") model, optimizer, train_loader, _ = deepspeed.initialize( - args=args,model=model, + args=args, + model=model, model_parameters=[p for p in model.parameters() if p.requires_grad], - training_data=train_set) - + training_data=train_set, + ) + # Train loop total_training_time, total_test_time, epoch_0_training = 0, 0, 0 for epoch in range(0, args.epochs): @@ -222,11 +258,11 @@ def main(): if args.test_batch_size > 0: total_test_time += test(args, model, device, my_loss, test_loader) - print('\n======== Global stats ========') + print("\n======== Global stats ========") if not args.pytorch_only: estimated_export = 0 if args.epochs > 1: - estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1) + estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1) print(" Estimated ONNX export took: {:.4f}s".format(estimated_export)) else: print(" Estimated ONNX export took: Estimate available when epochs > 1 only") @@ -235,5 +271,5 @@ def main(): print(" Accumulated validation took: {:.4f}s".format(total_test_time)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py index d2bcff7c0c3b8..8eebb767f6b53 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_experimental_json_config.py @@ -1,4 +1,3 @@ - import os import torch from onnxruntime.training import ortmodule @@ -20,16 +19,17 @@ def forward(self, input1): out = self.fc2(out) return out + def test_load_config_from_json_1(): - device = 'cuda' + device = "cuda" model = ortmodule.ORTModule(Net().to(device)) # load from json once. - path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json') + path_to_json = os.path.join(os.getcwd(), "orttraining_test_ortmodule_experimental_json_config_2.json") load_from_json(model, path_to_json) # load from json another time - path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json') + path_to_json = os.path.join(os.getcwd(), "orttraining_test_ortmodule_experimental_json_config_1.json") load_from_json(model, path_to_json) for training_mode in [True, False]: @@ -63,7 +63,7 @@ def test_load_config_from_json_1(): # test debug options assert ort_model_attributes._debug_options.save_onnx_models.save == True - assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_model' + assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == "my_model" assert ort_model_attributes._debug_options.logging.log_level.name == "VERBOSE" # test use memory aware gradient builder. @@ -75,16 +75,17 @@ def test_load_config_from_json_1(): # assert onnx opset version assert ortmodule.ONNX_OPSET_VERSION == 13 + def test_load_config_from_json_2(): - device = 'cuda' + device = "cuda" model = ortmodule.ORTModule(Net().to(device)) # load from json once. - path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_1.json') + path_to_json = os.path.join(os.getcwd(), "orttraining_test_ortmodule_experimental_json_config_1.json") load_from_json(model, path_to_json) # load from json another time - path_to_json = os.path.join(os.getcwd(), 'orttraining_test_ortmodule_experimental_json_config_2.json') + path_to_json = os.path.join(os.getcwd(), "orttraining_test_ortmodule_experimental_json_config_2.json") load_from_json(model, path_to_json) for training_mode in [True, False]: @@ -118,7 +119,7 @@ def test_load_config_from_json_2(): # test debug options assert ort_model_attributes._debug_options.save_onnx_models.save == True - assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == 'my_other_model' + assert ort_model_attributes._debug_options.save_onnx_models.name_prefix == "my_other_model" assert ort_model_attributes._debug_options.logging.log_level.name == "INFO" # test use memory aware gradient builder. diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py index f8c33742a02bb..e1a7dd591ec36 100755 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fairscale_sharded_optimizer.py @@ -12,14 +12,14 @@ from onnxruntime.training.ortmodule import ORTModule, DebugOptions import numpy as np -# Usage : +# Usage : # pip install fairscale -# python3 orttraining_test_ortmodule_fairscale_sharded_optimizer.py --use_sharded_optimizer --use_ortmodule +# python3 orttraining_test_ortmodule_fairscale_sharded_optimizer.py --use_sharded_optimizer --use_ortmodule def dist_init(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" # initialize the process group dist.init_process_group("nccl", rank=rank, world_size=world_size) @@ -39,47 +39,51 @@ def forward(self, input1): out = self.fc2(out) return out -def get_dataloader(args,rank,batch_size): + +def get_dataloader(args, rank, batch_size): # Data loading code train_dataset = torchvision.datasets.MNIST( - root=args.data_dir, - train=True, - transform=transforms.ToTensor(), - download=True - ) - + root=args.data_dir, train=True, transform=transforms.ToTensor(), download=True + ) + train_sampler = torch.utils.data.distributed.DistributedSampler( - train_dataset, - num_replicas=args.world_size, - rank=rank + train_dataset, num_replicas=args.world_size, rank=rank ) - + train_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=batch_size, - shuffle=False, + shuffle=False, num_workers=0, pin_memory=True, - sampler=train_sampler) - + sampler=train_sampler, + ) + test_loader = None if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( - datasets.MNIST(args.data_dir, train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True) + datasets.MNIST( + args.data_dir, + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) + + return train_loader, test_loader - return train_loader,test_loader def my_loss(x, target, is_train=True): if is_train: return torch.nn.CrossEntropyLoss()(x, target) else: - return torch.nn.CrossEntropyLoss(reduction='sum')(x, target) + return torch.nn.CrossEntropyLoss(reduction="sum")(x, target) + def train_step(args, model, device, optimizer, loss_fn, train_loader, epoch): - print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch+1, args.epochs, args.batch_size)) + print("\n======== Epoch {:} / {:} with batch size {:} ========".format(epoch + 1, args.epochs, args.batch_size)) model.train() # Measure how long the training epoch takes. t0 = time.time() @@ -99,6 +103,7 @@ def train_step(args, model, device, optimizer, loss_fn, train_loader, epoch): if args.view_graphs: import torchviz + pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters()))) pytorch_backward_graph.view() @@ -116,9 +121,15 @@ def train_step(args, model, device, optimizer, loss_fn, train_loader, epoch): if iteration % args.log_interval == 0: curr_time = time.time() elapsed_time = curr_time - start_time - print('[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}'.format( - iteration * len(data), len(train_loader.dataset), - 100. * iteration / len(train_loader), loss, elapsed_time)) + print( + "[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}".format( + iteration * len(data), + len(train_loader.dataset), + 100.0 * iteration / len(train_loader), + loss, + elapsed_time, + ) + ) start_time = curr_time # Calculate the average loss over the training data. @@ -129,6 +140,7 @@ def train_step(args, model, device, optimizer, loss_fn, train_loader, epoch): print(" Training epoch took: {:.4f}s".format(epoch_time)) return epoch_time + def test(args, model, device, loss_fn, test_loader): model.eval() t0 = time.time() @@ -146,25 +158,28 @@ def test(args, model, device, loss_fn, test_loader): pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - args.test_batch_size, test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + args.test_batch_size, + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) # Report the final accuracy for this validation run. epoch_time = time.time() - t0 - accuracy = float(correct)/len(test_loader.dataset) + accuracy = float(correct) / len(test_loader.dataset) print(" Accuracy: {0:.2f}".format(accuracy)) print(" Validation took: {:.4f}s".format(epoch_time)) return epoch_time, accuracy -def train( - rank: int, - args, - world_size: int, - epochs: int): + +def train(rank: int, args, world_size: int, epochs: int): # DDP init example - dist_init(rank, world_size) + dist_init(rank, world_size) torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False # Setup @@ -179,24 +194,26 @@ def train( if args.use_ortmodule: print("Converting to ORTModule....") - debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix='NeuralNet') + debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix="NeuralNet") model = ORTModule(model, debug_options) - train_dataloader, test_dataloader = get_dataloader(args,rank,args.batch_size) + train_dataloader, test_dataloader = get_dataloader(args, rank, args.batch_size) loss_fn = my_loss - base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here - base_optimizer_arguments = {} # pass any optimizer specific arguments here, or directly below when instantiating OSS + base_optimizer = torch.optim.SGD # pick any pytorch compliant optimizer here + base_optimizer_arguments = ( + {} + ) # pass any optimizer specific arguments here, or directly below when instantiating OSS if args.use_sharded_optimizer: # Wrap the optimizer in its state sharding brethren - optimizer = OSS(params=model.parameters(), optim=base_optimizer,lr = args.lr ) + optimizer = OSS(params=model.parameters(), optim=base_optimizer, lr=args.lr) # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks model = ShardedDDP(model, optimizer) else: device_ids = None if args.cpu else [rank] model = DDP(model, device_ids=device_ids, find_unused_parameters=False) # type: ignore - + optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) # Any relevant training loop, nothing specific to OSS. For example: model.train() @@ -209,11 +226,11 @@ def train( test_time, validation_accuracy = test(args, model, rank, loss_fn, test_dataloader) total_test_time += test_time - print('\n======== Global stats ========') + print("\n======== Global stats ========") if args.use_ortmodule: estimated_export = 0 if args.epochs > 1: - estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1) + estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1) print(" Estimated ONNX export took: {:.4f}s".format(estimated_export)) else: print(" Estimated ONNX export took: Estimate available when epochs > 1 only") @@ -221,7 +238,6 @@ def train( print(" Accumulated training took: {:.4f}s".format(total_training_time)) print(" Accumulated validation took: {:.4f}s".format(total_test_time)) - dist.destroy_process_group() @@ -233,28 +249,34 @@ def train( parser.add_argument("--world_size", action="store", default=2, type=int) parser.add_argument("--epochs", action="store", default=10, type=int) parser.add_argument("--batch_size", action="store", default=256, type=int) - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--use_sharded_optimizer', action='store_true', default=False, - help='use sharded optim') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--view-graphs', action='store_true', default=False, - help='views forward and backward graphs') - parser.add_argument('--export-onnx-graphs', action='store_true', default=False, - help='export ONNX graphs to current directory') - parser.add_argument('--log-interval', type=int, default=300, metavar='N', - help='how many batches to wait before logging training status (default: 300)') + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument("--use_sharded_optimizer", action="store_true", default=False, help="use sharded optim") + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument("--view-graphs", action="store_true", default=False, help="views forward and backward graphs") + parser.add_argument( + "--export-onnx-graphs", action="store_true", default=False, help="export ONNX graphs to current directory" + ) + parser.add_argument( + "--log-interval", + type=int, + default=300, + metavar="N", + help="how many batches to wait before logging training status (default: 300)", + ) parser.add_argument("--cpu", action="store_true", default=False) - parser.add_argument("--use_ortmodule", action="store_true", default=False, help = "use ortmodule") - parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', - help='input batch size for testing (default: 64)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--data-dir', type=str, default='./mnist', - help='Path to the mnist data directory') + parser.add_argument("--use_ortmodule", action="store_true", default=False, help="use ortmodule") + parser.add_argument( + "--test-batch-size", type=int, default=64, metavar="N", help="input batch size for testing (default: 64)" + ) + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)") + parser.add_argument("--data-dir", type=str, default="./mnist", help="Path to the mnist data directory") args = parser.parse_args() # Supposing that WORLD_SIZE and EPOCHS are somehow defined somewhere diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py index e4d714c1c717b..dba407cca5c1b 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_fallback.py @@ -14,16 +14,19 @@ from onnxruntime.training.ortmodule import ORTModule, _fallback, ORTMODULE_TORCH_CPP_DIR from onnxruntime.training.ortmodule.torch_cpp_extensions import is_installed as is_torch_cpp_extensions_installed import _test_helpers -from _orttraining_ortmodule_models import (NeuralNetSinglePositionalArgument, - NeuralNetCustomClassOutput, - MyCustomClassInputNet, - MyCustomFunctionReluModel) +from _orttraining_ortmodule_models import ( + NeuralNetSinglePositionalArgument, + NeuralNetCustomClassOutput, + MyCustomClassInputNet, + MyCustomFunctionReluModel, +) # PyTorch model definitions for tests -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_forward(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -32,13 +35,13 @@ def test_ortmodule_fallback_forward(is_training, fallback_enabled, matching_poli if fallback_enabled: if matching_policy: - policy = 'FALLBACK_FORCE_TORCH_FORWARD' + policy = "FALLBACK_FORCE_TORCH_FORWARD" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) from dataclasses import dataclass @@ -65,8 +68,10 @@ def forward(self, point): if fallback_enabled: if matching_policy: if i > 0 and persist_fallback: - assert ort_model._torch_module._execution_manager( - is_training=is_training)._fallback_manager._exception is not None + assert ( + ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception + is not None + ) ort_out = ort_model(inputs) pt_out = pt_model(inputs) assert ort_out == pt_out @@ -80,8 +85,9 @@ def forward(self, point): assert "ORTModule does not support the following model data type" in str(type_error.value) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_device__multiple(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -90,24 +96,24 @@ def test_ortmodule_fallback_device__multiple(is_training, fallback_enabled, matc if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_UNSUPPORTED_DATA' + policy = "FALLBACK_UNSUPPORTED_DATA" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) class ManyDevicesNet(torch.nn.Module): def __init__(self): super().__init__() - self.net1 = torch.nn.Linear(10, 10).to('cuda:0') + self.net1 = torch.nn.Linear(10, 10).to("cuda:0") self.relu = torch.nn.ReLU() - self.net2 = torch.nn.Linear(10, 5).to('cpu') + self.net2 = torch.nn.Linear(10, 5).to("cpu") def forward(self, x): - x = self.relu(self.net1(x.to('cuda:0'))) - return self.net2(x.to('cpu')) + x = self.relu(self.net1(x.to("cuda:0"))) + return self.net2(x.to("cpu")) pt_model = ManyDevicesNet() inputs = torch.randn(20, 10) @@ -133,8 +139,9 @@ def forward(self, x): assert "ORTModule supports a single device per model" in str(type_error.value) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -143,16 +150,16 @@ def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matc if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_UNSUPPORTED_DATA' + policy = "FALLBACK_UNSUPPORTED_DATA" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" - data_device = 'cuda' + data_device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out) @@ -172,24 +179,34 @@ def test_ortmodule_fallback_device__mismatch(is_training, fallback_enabled, matc if matching_policy: with pytest.raises(RuntimeError) as e: ort_model(inputs) - assert \ - ("Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" in str(e.value)) \ - or ("Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!" in str(e.value)) + assert ( + "Tensor for argument #1 'self' is on CPU, but expected them to be on GPU (while checking arguments for addmm)" + in str(e.value) + ) or ( + "Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0!" + in str(e.value) + ) else: with pytest.raises(_fallback.ORTModuleDeviceException) as e: ort_model(inputs) - assert (f"Input argument to forward found on device {input_device}, " - f"but expected it to be on module device {ort_model_device}." in str(e.value)) + assert ( + f"Input argument to forward found on device {input_device}, " + f"but expected it to be on module device {ort_model_device}." in str(e.value) + ) else: with pytest.raises(_fallback.ORTModuleDeviceException) as e: ort_model(inputs) - assert (f"Input argument to forward found on device {input_device}, " - f"but expected it to be on module device {ort_model_device}." in str(e.value)) + assert ( + f"Input argument to forward found on device {input_device}, " + f"but expected it to be on module device {ort_model_device}." in str(e.value) + ) - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) + +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -198,15 +215,15 @@ def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_polic if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_DATA' + policy = "FALLBACK_UNSUPPORTED_DATA" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetCustomClassOutput(D_in, H, D_out).to(device) ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -221,8 +238,10 @@ def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_polic if fallback_enabled: if matching_policy: if i > 0 and persist_fallback: - assert ort_model._torch_module._execution_manager( - is_training=is_training)._fallback_manager._exception is not None + assert ( + ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception + is not None + ) ort_out = ort_model(x, y, z) pt_out = pt_model(x, y, z) _test_helpers.assert_values_are_close(ort_out.out1, pt_out.out1, rtol=0, atol=0) @@ -231,15 +250,16 @@ def test_ortmodule_fallback_output(is_training, fallback_enabled, matching_polic else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert 'ORTModule does not support the following model output type' in str(runtime_error.value) + assert "ORTModule does not support the following model output type" in str(runtime_error.value) else: with pytest.raises(_fallback.ORTModuleIOError) as runtime_error: ort_model(x, y, z) - assert 'ORTModule does not support the following model output type' in str(runtime_error.value) + assert "ORTModule does not support the following model output type" in str(runtime_error.value) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -248,13 +268,13 @@ def test_ortmodule_fallback_input(is_training, fallback_enabled, matching_policy if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_DATA' + policy = "FALLBACK_UNSUPPORTED_DATA" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) pt_model = MyCustomClassInputNet() ort_model = ORTModule(copy.deepcopy(pt_model)) @@ -271,27 +291,34 @@ def __init__(self, x): if fallback_enabled: if matching_policy: if i > 0 and persist_fallback: - assert ort_model._torch_module._execution_manager( - is_training=is_training)._fallback_manager._exception is not None + assert ( + ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception + is not None + ) ort_out = ort_model(inputs, CustomClass(1)) pt_out = pt_model(inputs, CustomClass(1)) _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) - assert "ORTModule does not support the following model data"\ - " type .CustomClass'>" in str(ex_info.value) + ) else: with pytest.raises(_fallback.ORTModuleIOError) as ex_info: _ = ort_model(torch.randn(1, 2), CustomClass(1)) - assert "ORTModule does not support the following model data"\ - " type .CustomClass'>" in str(ex_info.value) + ) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_torch_model(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -300,15 +327,15 @@ def test_ortmodule_fallback_torch_model(is_training, fallback_enabled, matching_ if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_TORCH_MODEL' + policy = "FALLBACK_UNSUPPORTED_TORCH_MODEL" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) @@ -336,8 +363,9 @@ def test_ortmodule_fallback_torch_model(is_training, fallback_enabled, matching_ assert "ORTModule is not compatible with torch.nn.DataParallel" in str(ex_info.value) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -346,21 +374,22 @@ def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, m from packaging import version from onnxruntime.training.ortmodule import MINIMUM_RUNTIME_PYTORCH_VERSION_STR - runtime_pytorch_version = version.parse(torch.__version__.split('+')[0]) + + runtime_pytorch_version = version.parse(torch.__version__.split("+")[0]) minimum_runtime_pytorch_version = version.parse(MINIMUM_RUNTIME_PYTORCH_VERSION_STR) if runtime_pytorch_version < minimum_runtime_pytorch_version: if fallback_enabled: if matching_policy: - policy = 'FALLBACK_BAD_INITIALIZATION' + policy = "FALLBACK_BAD_INITIALIZATION" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) @@ -380,42 +409,51 @@ def test_ortmodule_fallback_init__torch_version(is_training, fallback_enabled, m with pytest.raises(_fallback.ORTModuleInitException) as ex_info: ort_model = ORTModule(pt_model) assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str( - ex_info.value) + ex_info.value + ) else: with pytest.raises(_fallback.ORTModuleInitException) as ex_info: # Initialize with fallback policy because Exception will happen during __init__ ort_model = ORTModule(pt_model) - assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str(ex_info.value) + assert "ONNX Runtime ORTModule frontend requires PyTorch version greater or equal to" in str( + ex_info.value + ) else: - warnings.warn('Skipping test_ortmodule_fallback_torch_version.' - f' It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}') - - -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) -def test_ortmodule_fallback_init__missing_cpp_extensions(is_training, fallback_enabled, matching_policy, - persist_fallback): + warnings.warn( + "Skipping test_ortmodule_fallback_torch_version." + f" It requires PyTorch prior to {MINIMUM_RUNTIME_PYTORCH_VERSION_STR}" + ) + + +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) +def test_ortmodule_fallback_init__missing_cpp_extensions( + is_training, fallback_enabled, matching_policy, persist_fallback +): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend # matching_policy: True matches FALLBACK_UNSUPPORTED_TORCH_MODEL policy to ORTModuleDeviceException exception. # Otherwise, an incorrect policy (FALLBACK_UNSUPPORTED_DEVICE) is used to verify that the fallback does not happen if is_torch_cpp_extensions_installed(ORTMODULE_TORCH_CPP_DIR): - warnings.warn('Skipping test_ortmodule_fallback_init__missing_cpp_extensions.' - f' It requires PyTorch CPP extensions to be missing') + warnings.warn( + "Skipping test_ortmodule_fallback_init__missing_cpp_extensions." + f" It requires PyTorch CPP extensions to be missing" + ) else: if fallback_enabled: if matching_policy: - policy = 'FALLBACK_BAD_INITIALIZATION' + policy = "FALLBACK_BAD_INITIALIZATION" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) - device = 'cuda' + device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 x = torch.randn(N, D_in, device=device) @@ -442,10 +480,12 @@ def test_ortmodule_fallback_init__missing_cpp_extensions(is_training, fallback_e assert "ORTModule's extensions were not detected" in str(ex_info.value) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) -def test_ortmodule_fallback_onnx_model__custom_autograd(is_training, fallback_enabled, matching_policy, - persist_fallback): +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) +def test_ortmodule_fallback_onnx_model__custom_autograd( + is_training, fallback_enabled, matching_policy, persist_fallback +): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend # matching_policy: True matches FALLBACK_UNSUPPORTED_ONNX_MODEL policy to ORTModuleDeviceException exception. @@ -453,13 +493,13 @@ def test_ortmodule_fallback_onnx_model__custom_autograd(is_training, fallback_en if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_ONNX_MODEL' + policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) dtype = torch.float device = torch.device("cuda") @@ -479,8 +519,10 @@ def test_ortmodule_fallback_onnx_model__custom_autograd(is_training, fallback_en if fallback_enabled: if matching_policy: if i > 0 and persist_fallback: - assert ort_model._torch_module._execution_manager( - is_training=is_training)._fallback_manager._exception is not None + assert ( + ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception + is not None + ) pt_out = pt_model(x.mm(w1)).mm(w2) ort_out = ort_model(x.mm(w1)).mm(w2) _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0) @@ -495,8 +537,9 @@ def test_ortmodule_fallback_onnx_model__custom_autograd(is_training, fallback_en assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value) -@pytest.mark.parametrize("is_training,fallback_enabled,matching_policy,persist_fallback", - list(itertools.product([True, False], repeat=4))) +@pytest.mark.parametrize( + "is_training,fallback_enabled,matching_policy,persist_fallback", list(itertools.product([True, False], repeat=4)) +) def test_ortmodule_fallback_onnx_model__missing_op(is_training, fallback_enabled, matching_policy, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # fallback_enabled: True PyTorch executes the forward graph instead of ORT backend @@ -505,17 +548,18 @@ def test_ortmodule_fallback_onnx_model__missing_op(is_training, fallback_enabled if fallback_enabled: if matching_policy: - policy = 'FALLBACK_UNSUPPORTED_ONNX_MODEL' + policy = "FALLBACK_UNSUPPORTED_ONNX_MODEL" else: - policy = 'FALLBACK_UNSUPPORTED_DEVICE' + policy = "FALLBACK_UNSUPPORTED_DEVICE" else: - policy = 'FALLBACK_DISABLE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) + policy = "FALLBACK_DISABLE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) class CrossModule(torch.nn.Module): def forward(self, x, y): return torch.cross(x, y) + x = torch.randn(2, 3) y = torch.randn(2, 3) @@ -528,8 +572,10 @@ def forward(self, x, y): if fallback_enabled: if matching_policy: if i > 0 and persist_fallback: - assert ort_model._torch_module._execution_manager( - is_training=is_training)._fallback_manager._exception is not None + assert ( + ort_model._torch_module._execution_manager(is_training=is_training)._fallback_manager._exception + is not None + ) pt_out = pt_model(x, y) ort_out = ort_model(x, y) _test_helpers.assert_values_are_close(ort_out, pt_out, rtol=0, atol=0) @@ -544,17 +590,16 @@ def forward(self, x, y): assert "There was an error while exporting the PyTorch model to ONNX" in str(ex_info.value) -@pytest.mark.parametrize("is_training,persist_fallback", - list(itertools.product([True, False], repeat=2))) +@pytest.mark.parametrize("is_training,persist_fallback", list(itertools.product([True, False], repeat=2))) def test_ortmodule_fallback_warn_message(is_training, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise - policy = 'FALLBACK_UNSUPPORTED_DEVICE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + policy = "FALLBACK_UNSUPPORTED_DEVICE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" - data_device = 'cuda' + data_device = "cuda" N, D_in, H, D_out = 64, 784, 500, 10 pt_model = NeuralNetSinglePositionalArgument(D_in, H, D_out) @@ -573,45 +618,38 @@ def test_ortmodule_fallback_warn_message(is_training, persist_fallback): ort_model(inputs) assert "Fallback to PyTorch due to exception" in str(warning_record[0].message.args[0]) - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] - + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] -@pytest.mark.parametrize("is_training,persist_fallback", - list(itertools.product([True, False], repeat=2))) +@pytest.mark.parametrize("is_training,persist_fallback", list(itertools.product([True, False], repeat=2))) def test_ortmodule_fallback_non_contiguous_tensors(is_training, persist_fallback): # is_training: True for torch.nn.Module training model, eval mode otherwise # Validate fix for issue: https://github.com/pytorch/ort/issues/92 - policy = 'FALLBACK_UNSUPPORTED_DEVICE' - os.environ['ORTMODULE_FALLBACK_POLICY'] = policy - os.environ['ORTMODULE_FALLBACK_RETRY'] = str(not persist_fallback) - os.environ['ORTMODULE_SKIPCHECK_POLICY'] = 'SKIP_CHECK_DISABLED' + policy = "FALLBACK_UNSUPPORTED_DEVICE" + os.environ["ORTMODULE_FALLBACK_POLICY"] = policy + os.environ["ORTMODULE_FALLBACK_RETRY"] = str(not persist_fallback) + os.environ["ORTMODULE_SKIPCHECK_POLICY"] = "SKIP_CHECK_DISABLED" class PositionalEncoding(torch.nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): super().__init__() self.dropout = torch.nn.Dropout(p=dropout) position = torch.arange(max_len).unsqueeze(1) - div_term = (torch.exp(torch.arange(0, d_model, 2) * - (-math.log(10000.0) / d_model))) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) pe = torch.zeros(max_len, 1, d_model) pe[:, 0, 0::2] = torch.sin(position * div_term) pe[:, 0, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - x = x + self.pe[:x.size(0)] - return self.dropout(x) - + x = x + self.pe[: x.size(0)] + return self.dropout(x) class TransformerModel(torch.nn.Module): - - def __init__(self, ntoken, d_model, nhead, d_hid, - nlayers, dropout=0.5): + def __init__(self, ntoken, d_model, nhead, d_hid, nlayers, dropout=0.5): super().__init__() - self.model_type = 'Transformer' + self.model_type = "Transformer" encoder_layers = torch.nn.TransformerEncoderLayer(d_model, nhead, d_hid, dropout) self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers) self.pos_encoder = PositionalEncoding(d_model, dropout) @@ -633,20 +671,17 @@ def forward(self, src, src_mask): output = self.decoder(output) return output - def generate_square_subsequent_mask(sz): - return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1) - + return torch.triu(torch.ones(sz, sz) * float("-inf"), diagonal=1) def get_batch(source, i): seq_len = min(bptt, len(source) - 1 - i) - data = source[i:i+seq_len] - target = source[i+1:i+1+seq_len].reshape(-1) + data = source[i : i + seq_len] + target = source[i + 1 : i + 1 + seq_len].reshape(-1) return data, target - criterion = torch.nn.CrossEntropyLoss() - device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") train_data = np.random.randint(1, 12455, 1000) ends = np.random.randint(2, 20, 100).cumsum() ends = ends[ends < train_data.shape[0] - 2] @@ -687,4 +722,4 @@ def get_batch(source, i): assert n_iter > 0 - del os.environ['ORTMODULE_SKIPCHECK_POLICY'] + del os.environ["ORTMODULE_SKIPCHECK_POLICY"] diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py index 4727dcbb44c54..bb94a6c514977 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_poc.py @@ -24,7 +24,7 @@ def forward(self, input1): def train(args, model, device, optimizer, loss_fn, train_loader, epoch): - print('\n======== Epoch {:} / {:} with batch size {:} ========'.format(epoch+1, args.epochs, args.batch_size)) + print("\n======== Epoch {:} / {:} with batch size {:} ========".format(epoch + 1, args.epochs, args.batch_size)) model.train() # Measure how long the training epoch takes. t0 = time.time() @@ -44,6 +44,7 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch): if args.view_graphs: import torchviz + pytorch_backward_graph = torchviz.make_dot(probability, params=dict(list(model.named_parameters()))) pytorch_backward_graph.view() @@ -61,9 +62,15 @@ def train(args, model, device, optimizer, loss_fn, train_loader, epoch): if iteration % args.log_interval == 0: curr_time = time.time() elapsed_time = curr_time - start_time - print('[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}'.format( - iteration * len(data), len(train_loader.dataset), - 100. * iteration / len(train_loader), loss, elapsed_time)) + print( + "[{:5}/{:5} ({:2.0f}%)]\tLoss: {:.6f}\tExecution time: {:.4f}".format( + iteration * len(data), + len(train_loader.dataset), + 100.0 * iteration / len(train_loader), + loss, + elapsed_time, + ) + ) start_time = curr_time # Calculate the average loss over the training data. @@ -93,56 +100,73 @@ def test(args, model, device, loss_fn, test_loader): pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - args.test_batch_size, test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Batch size: {:}, Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + args.test_batch_size, + test_loss, + correct, + len(test_loader.dataset), + 100.0 * correct / len(test_loader.dataset), + ) + ) # Report the final accuracy for this validation run. epoch_time = time.time() - t0 - accuracy = float(correct)/len(test_loader.dataset) + accuracy = float(correct) / len(test_loader.dataset) print(" Accuracy: {0:.2f}".format(accuracy)) print(" Validation took: {:.4f}s".format(epoch_time)) return epoch_time, accuracy + def my_loss(x, target, is_train=True): if is_train: return torch.nn.CrossEntropyLoss()(x, target) else: - return torch.nn.CrossEntropyLoss(reduction='sum')(x, target) + return torch.nn.CrossEntropyLoss(reduction="sum")(x, target) + def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--batch-size', type=int, default=32, metavar='N', - help='input batch size for training (default: 32)') - parser.add_argument('--test-batch-size', type=int, default=64, metavar='N', - help='input batch size for testing (default: 64)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--pytorch-only', action='store_true', default=False, - help='disables ONNX Runtime training') - parser.add_argument('--log-interval', type=int, default=300, metavar='N', - help='how many batches to wait before logging training status (default: 300)') - parser.add_argument('--view-graphs', action='store_true', default=False, - help='views forward and backward graphs') - parser.add_argument('--export-onnx-graphs', action='store_true', default=False, - help='export ONNX graphs to current directory') - parser.add_argument('--epochs', type=int, default=5, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--log-level', choices=['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'], default='WARNING', - help='Log level (default: WARNING)') - parser.add_argument('--data-dir', type=str, default='./mnist', - help='Path to the mnist data directory') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument( + "--batch-size", type=int, default=32, metavar="N", help="input batch size for training (default: 32)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=64, metavar="N", help="input batch size for testing (default: 64)" + ) + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)") + parser.add_argument("--pytorch-only", action="store_true", default=False, help="disables ONNX Runtime training") + parser.add_argument( + "--log-interval", + type=int, + default=300, + metavar="N", + help="how many batches to wait before logging training status (default: 300)", + ) + parser.add_argument("--view-graphs", action="store_true", default=False, help="views forward and backward graphs") + parser.add_argument( + "--export-onnx-graphs", action="store_true", default=False, help="export ONNX graphs to current directory" + ) + parser.add_argument("--epochs", type=int, default=5, metavar="N", help="number of epochs to train (default: 10)") + parser.add_argument( + "--log-level", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + default="WARNING", + help="Log level (default: WARNING)", + ) + parser.add_argument("--data-dir", type=str, default="./mnist", help="Path to the mnist data directory") args = parser.parse_args() - # Common setup torch.manual_seed(args.seed) onnxruntime.set_seed(args.seed) @@ -153,35 +177,45 @@ def main(): device = "cpu" ## Data loader - train_loader = torch.utils.data.DataLoader(datasets.MNIST(args.data_dir, train=True, download=True, - transform=transforms.Compose([transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.batch_size, - shuffle=True) + train_loader = torch.utils.data.DataLoader( + datasets.MNIST( + args.data_dir, + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.batch_size, + shuffle=True, + ) test_loader = None if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( - datasets.MNIST(args.data_dir, train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True) + datasets.MNIST( + args.data_dir, + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) # Model architecture model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(device) if not args.pytorch_only: - print('Training MNIST on ORTModule....') + print("Training MNIST on ORTModule....") # Just for future debugging - debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix='MNIST') + debug_options = DebugOptions(save_onnx=args.export_onnx_graphs, onnx_prefix="MNIST") model = ORTModule(model, debug_options) # Set log level numeric_level = getattr(logging, args.log_level.upper(), None) if not isinstance(numeric_level, int): - raise ValueError('Invalid log level: %s' % args.log_level) + raise ValueError("Invalid log level: %s" % args.log_level) logging.basicConfig(level=numeric_level) else: - print('Training MNIST on vanilla PyTorch....') + print("Training MNIST on vanilla PyTorch....") optimizer = torch.optim.SGD(model.parameters(), lr=args.lr) # Train loop @@ -196,11 +230,11 @@ def main(): assert validation_accuracy > 0.92 - print('\n======== Global stats ========') + print("\n======== Global stats ========") if not args.pytorch_only: estimated_export = 0 if args.epochs > 1: - estimated_export = epoch_0_training - (total_training_time - epoch_0_training)/(args.epochs-1) + estimated_export = epoch_0_training - (total_training_time - epoch_0_training) / (args.epochs - 1) print(" Estimated ONNX export took: {:.4f}s".format(estimated_export)) else: print(" Estimated ONNX export took: Estimate available when epochs > 1 only") @@ -209,5 +243,5 @@ def main(): print(" Accumulated validation took: {:.4f}s".format(total_test_time)) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py index 73e3afea585e8..93426659991fe 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_pytorch_ddp.py @@ -18,12 +18,13 @@ def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "12355" # initialize the process group dist.init_process_group("gloo", rank=rank, world_size=world_size) + def cleanup(): dist.destroy_process_group() @@ -48,9 +49,9 @@ def demo_basic(rank, world_size, use_ort_module): model = ToyModel().to(rank) if use_ort_module: model = ORTModule(model) - print(f" Rank {rank} uses ORTModule."); + print(f" Rank {rank} uses ORTModule.") else: - print(f" Rank {rank} uses Pytorch's nn.Module."); + print(f" Rank {rank} uses Pytorch's nn.Module.") ddp_model = DDP(model, device_ids=[rank]) @@ -74,23 +75,26 @@ def demo_basic(rank, world_size, use_ort_module): loss_history.append(torch.unsqueeze(loss, 0)) loss_history = torch.cat(loss_history).cpu() - expected_loss_history = torch.FloatTensor([1.4909229278564453, 1.432194471359253, 1.39592707157135, 1.367714762687683, 1.3445055484771729]) + expected_loss_history = torch.FloatTensor( + [1.4909229278564453, 1.432194471359253, 1.39592707157135, 1.367714762687683, 1.3445055484771729] + ) assert torch.allclose(expected_loss_history, loss_history) cleanup() + def demo_checkpoint(rank, world_size, use_ort_module): torch.manual_seed(rank) print(f"Running DDP checkpoint example on rank {rank}.") setup(rank, world_size) if use_ort_module: - print(f" Rank {rank} uses ORTModule."); + print(f" Rank {rank} uses ORTModule.") model = ToyModel().to(rank) model = ORTModule(model) else: - print(f" Rank {rank} uses Pytorch's nn.Module."); + print(f" Rank {rank} uses Pytorch's nn.Module.") model = ToyModel().to(rank) ddp_model = DDP(model, device_ids=[rank]) @@ -109,9 +113,8 @@ def demo_checkpoint(rank, world_size, use_ort_module): # 0 saves it. dist.barrier() # configure map_location properly - map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} - ddp_model.load_state_dict( - torch.load(CHECKPOINT_PATH, map_location=map_location)) + map_location = {"cuda:%d" % 0: "cuda:%d" % rank} + ddp_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=map_location)) optimizer.zero_grad() outputs = ddp_model(torch.randn(20, 10)) @@ -143,17 +146,17 @@ def demo_checkpoint(rank, world_size, use_ort_module): cleanup() + def run_demo(demo_fn, world_size, use_ort_module): - mp.spawn(demo_fn, - args=(world_size, use_ort_module), - nprocs=world_size, - join=True) + mp.spawn(demo_fn, args=(world_size, use_ort_module), nprocs=world_size, join=True) + def parse_args(): parser = argparse.ArgumentParser() - parser.add_argument('--use_ort_module', action='store_true') + parser.add_argument("--use_ort_module", action="store_true") return parser.parse_args() + if __name__ == "__main__": args = parse_args() run_demo(demo_basic, 4, args.use_ort_module) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py index 7c8e1e5a552fb..9f8f273837d85 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_torch_lightning_basic.py @@ -14,20 +14,11 @@ class LitAutoEncoder(pl.LightningModule): - def __init__(self, lr, use_ortmodule=True): super().__init__() self.lr = lr - self.encoder = nn.Sequential( - nn.Linear(28*28, 64), - nn.ReLU(), - nn.Linear(64, 3) - ) - self.decoder = nn.Sequential( - nn.Linear(3, 64), - nn.ReLU(), - nn.Linear(64, 28*28) - ) + self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3)) + self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28)) if use_ortmodule: self.encoder = ORTModule(self.encoder) self.decoder = ORTModule(self.decoder) @@ -48,7 +39,7 @@ def training_step(self, batch, batch_idx): x_hat = self.decoder(z) loss = F.mse_loss(x_hat, x) # Logging to TensorBoard by default - self.log('train_loss', loss) + self.log("train_loss", loss) return loss def configure_optimizers(self): @@ -58,23 +49,23 @@ def configure_optimizers(self): def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--batch-size', type=int, default=32, metavar='N', - help='input batch size for training (default: 32)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=42, metavar='S', - help='random seed (default: 42)') - parser.add_argument('--pytorch-only', action='store_true', default=False, - help='disables ONNX Runtime training') - parser.add_argument('--epochs', type=int, default=5, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--data-dir', type=str, default='./mnist', - help='Path to the mnist data directory') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument( + "--batch-size", type=int, default=32, metavar="N", help="input batch size for training (default: 32)" + ) + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=42, metavar="S", help="random seed (default: 42)") + parser.add_argument("--pytorch-only", action="store_true", default=False, help="disables ONNX Runtime training") + parser.add_argument("--epochs", type=int, default=5, metavar="N", help="number of epochs to train (default: 10)") + parser.add_argument("--data-dir", type=str, default="./mnist", help="Path to the mnist data directory") args = parser.parse_args() @@ -96,15 +87,15 @@ def main(): # Train loop kwargs = {} - if device == 'cuda': - kwargs.update({'gpus': 1}) + if device == "cuda": + kwargs.update({"gpus": 1}) if args.train_steps > 0: - kwargs.update({'max_steps': args.train_steps}) + kwargs.update({"max_steps": args.train_steps}) if args.epochs > 0: - kwargs.update({'max_epochs': args.epochs}) + kwargs.update({"max_epochs": args.epochs}) trainer = pl.Trainer(**kwargs) trainer.fit(autoencoder, train_loader) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py index b62a63e9be203..531085f21ce61 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py @@ -10,13 +10,22 @@ import torch import onnxruntime -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ - ModelDescription as Legacy_ModelDescription,\ - LossScaler as Legacy_LossScaler,\ - ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import _utils, amp, checkpoint, optim, orttrainer, TrainStepInfo,\ - model_desc_validation as md_val,\ - orttrainer_options as orttrainer_options +from onnxruntime.capi.ort_trainer import ( + IODescription as Legacy_IODescription, + ModelDescription as Legacy_ModelDescription, + LossScaler as Legacy_LossScaler, + ORTTrainer as Legacy_ORTTrainer, +) +from onnxruntime.training import ( + _utils, + amp, + checkpoint, + optim, + orttrainer, + TrainStepInfo, + model_desc_validation as md_val, + orttrainer_options as orttrainer_options, +) import _test_commons, _test_helpers @@ -25,17 +34,17 @@ ############################################################################### -def generate_random_input_from_model_desc(desc, seed=1, device = "cuda:0"): - '''Generates a sample input for the BERT model using the model desc''' +def generate_random_input_from_model_desc(desc, seed=1, device="cuda:0"): + """Generates a sample input for the BERT model using the model desc""" torch.manual_seed(seed) onnxruntime.set_seed(seed) dtype = torch.int64 vocab_size = 30528 num_classes = [vocab_size, 2, 2, vocab_size, 2] - dims = {"batch_size":16, "seq_len":1} + dims = {"batch_size": 16, "seq_len": 1} sample_input = [] - for index, input in enumerate(desc['inputs']): + for index, input in enumerate(desc["inputs"]): size = [] for s in input[1]: if isinstance(s, (int)): @@ -45,45 +54,92 @@ def generate_random_input_from_model_desc(desc, seed=1, device = "cuda:0"): sample_input.append(torch.randint(0, num_classes[index], tuple(size), dtype=dtype).to(device)) return sample_input + # EXPERIMENTAL HELPER FUNCTIONS + def bert_model_description(dynamic_shape=True): - '''Creates the model description dictionary with static dimensions''' + """Creates the model description dictionary with static dimensions""" if dynamic_shape: - model_desc = {'inputs': [('input_ids', ['batch_size', 'seq_len']), - ('segment_ids', ['batch_size', 'seq_len'],), - ('input_mask', ['batch_size', 'seq_len'],), - ('masked_lm_labels', ['batch_size', 'seq_len'],), - ('next_sentence_labels', ['batch_size', ],)], - 'outputs': [('loss', [], True)]} + model_desc = { + "inputs": [ + ("input_ids", ["batch_size", "seq_len"]), + ( + "segment_ids", + ["batch_size", "seq_len"], + ), + ( + "input_mask", + ["batch_size", "seq_len"], + ), + ( + "masked_lm_labels", + ["batch_size", "seq_len"], + ), + ( + "next_sentence_labels", + [ + "batch_size", + ], + ), + ], + "outputs": [("loss", [], True)], + } else: batch_size = 16 seq_len = 1 - model_desc = {'inputs': [('input_ids', [batch_size, seq_len]), - ('segment_ids', [batch_size, seq_len],), - ('input_mask', [batch_size, seq_len],), - ('masked_lm_labels', [batch_size, seq_len],), - ('next_sentence_labels', [batch_size, ],)], - 'outputs': [('loss', [], True)]} + model_desc = { + "inputs": [ + ("input_ids", [batch_size, seq_len]), + ( + "segment_ids", + [batch_size, seq_len], + ), + ( + "input_mask", + [batch_size, seq_len], + ), + ( + "masked_lm_labels", + [batch_size, seq_len], + ), + ( + "next_sentence_labels", + [ + batch_size, + ], + ), + ], + "outputs": [("loss", [], True)], + } return model_desc def optimizer_parameters(model): - '''A method to assign different hyper parameters for different model parameter groups''' + """A method to assign different hyper parameters for different model parameter groups""" no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] no_decay_param_group = [] for initializer in model.graph.initializer: if any(key in initializer.name for key in no_decay_keys): no_decay_param_group.append(initializer.name) - params = [{'params': no_decay_param_group, "alpha": 0.9, "beta": 0.999, "lambda_coef": 0.0, "epsilon": 1e-6, "do_bias_correction":False}] + params = [ + { + "params": no_decay_param_group, + "alpha": 0.9, + "beta": 0.999, + "lambda_coef": 0.0, + "epsilon": 1e-6, + "do_bias_correction": False, + } + ] return params def load_bert_onnx_model(): - bert_onnx_model_path = os.path.join('testdata', "bert_toy_postprocessed.onnx") + bert_onnx_model_path = os.path.join("testdata", "bert_toy_postprocessed.onnx") model = onnx.load(bert_onnx_model_path) return model @@ -101,9 +157,11 @@ def update(self, train_step_info): self.loss_scale *= 0.9 return self.loss_scale + # LEGACY HELPER FUNCTIONS -class LegacyCustomLossScaler(): + +class LegacyCustomLossScaler: def __init__(self, loss_scale=float(1 << 16)): self._initial_loss_scale = loss_scale self.loss_scale_ = loss_scale @@ -115,27 +173,41 @@ def update_loss_scale(self, is_all_finite): self.loss_scale_ *= 0.9 -def legacy_model_params(lr, device = torch.device("cuda", 0)): +def legacy_model_params(lr, device=torch.device("cuda", 0)): legacy_model_desc = legacy_bert_model_description() learning_rate_description = legacy_ort_trainer_learning_rate_description() learning_rate = torch.tensor([lr]).to(device) return (legacy_model_desc, learning_rate_description, learning_rate) + def legacy_ort_trainer_learning_rate_description(): - return Legacy_IODescription('Learning_Rate', [1, ], torch.float32) + return Legacy_IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ) def legacy_bert_model_description(): vocab_size = 30528 - input_ids_desc = Legacy_IODescription('input_ids', ['batch', 'max_seq_len_in_batch']) - segment_ids_desc = Legacy_IODescription('segment_ids', ['batch', 'max_seq_len_in_batch']) - input_mask_desc = Legacy_IODescription('input_mask', ['batch', 'max_seq_len_in_batch']) - masked_lm_labels_desc = Legacy_IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch']) - next_sentence_labels_desc = Legacy_IODescription('next_sentence_labels', ['batch', ]) - loss_desc = Legacy_IODescription('loss', []) - - return Legacy_ModelDescription([input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, - next_sentence_labels_desc], [loss_desc]) + input_ids_desc = Legacy_IODescription("input_ids", ["batch", "max_seq_len_in_batch"]) + segment_ids_desc = Legacy_IODescription("segment_ids", ["batch", "max_seq_len_in_batch"]) + input_mask_desc = Legacy_IODescription("input_mask", ["batch", "max_seq_len_in_batch"]) + masked_lm_labels_desc = Legacy_IODescription("masked_lm_labels", ["batch", "max_seq_len_in_batch"]) + next_sentence_labels_desc = Legacy_IODescription( + "next_sentence_labels", + [ + "batch", + ], + ) + loss_desc = Legacy_IODescription("loss", []) + + return Legacy_ModelDescription( + [input_ids_desc, segment_ids_desc, input_mask_desc, masked_lm_labels_desc, next_sentence_labels_desc], + [loss_desc], + ) def legacy_optim_params_a(name): @@ -143,7 +215,7 @@ def legacy_optim_params_a(name): def legacy_optim_params_b(name): - params = ['bert.embeddings.LayerNorm.bias', 'bert.embeddings.LayerNorm.weight'] + params = ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"] if name in params: return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} @@ -151,25 +223,23 @@ def legacy_optim_params_b(name): def legacy_optim_params_c(name): params_group = optimizer_parameters(load_bert_onnx_model()) - if name in params_group[0]['params']: + if name in params_group[0]["params"]: return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6, "do_bias_correction": False} return {"alpha": 0.9, "beta": 0.999, "lambda": 0.01, "epsilon": 1e-6, "do_bias_correction": False} + ############################################################################### # Testing starts here ######################################################### ############################################################################### -@pytest.mark.parametrize("dynamic_shape", [ - (True), - (False) -]) +@pytest.mark.parametrize("dynamic_shape", [(True), (False)]) def testToyBERTModelBasicTraining(dynamic_shape): model_desc = bert_model_description(dynamic_shape) model = load_bert_onnx_model() optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({}) + opts = orttrainer.ORTTrainerOptions({}) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) for i in range(10): @@ -178,14 +248,14 @@ def testToyBERTModelBasicTraining(dynamic_shape): assert output.shape == torch.Size([]) -@pytest.mark.parametrize("expected_losses", [ - ([11.041123, 10.986166, 11.101636, 11.013366, 11.03775 , - 11.041175, 10.957118, 11.069563, 11.040824, 11.16437]) -]) +@pytest.mark.parametrize( + "expected_losses", + [([11.041123, 10.986166, 11.101636, 11.013366, 11.03775, 11.041175, 10.957118, 11.069563, 11.040824, 11.16437])], +) def testToyBERTDeterministicCheck(expected_losses): # Common setup train_steps = 10 - device = 'cuda' + device = "cuda" seed = 1 rtol = 1e-3 torch.manual_seed(seed) @@ -196,14 +266,14 @@ def testToyBERTDeterministicCheck(expected_losses): model = load_bert_onnx_model() params = optimizer_parameters(model) optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) # Train @@ -216,47 +286,148 @@ def testToyBERTDeterministicCheck(expected_losses): _test_helpers.assert_model_outputs(experimental_losses, expected_losses, rtol=rtol) -@pytest.mark.parametrize("initial_lr, lr_scheduler, expected_learning_rates, expected_losses", [ - (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler,\ - [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], - [10.988012313842773, 10.99213981628418, 120.79301452636719, 36.11647033691406, 95.83200073242188,\ - 221.2766571044922, 208.40316772460938, 279.5332946777344, 402.46380615234375, 325.79254150390625]), - (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler,\ - [0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], - [10.988012313842773, 10.99213981628418, 52.69743347167969, 19.741533279418945, 83.88340759277344,\ - 126.39848327636719, 91.53898620605469, 63.62016296386719, 102.21206665039062, 180.1424560546875]), - (1.0, optim.lr_scheduler.CosineWarmupLRScheduler,\ - [0.0, 0.9931806517013612, 0.9397368756032445, 0.8386407858128706, 0.7008477123264848, 0.5412896727361662,\ - 0.37725725642960045, 0.22652592093878665, 0.10542974530180327, 0.02709137914968268], - [10.988012313842773, 10.99213981628418, 120.6441650390625, 32.152557373046875, 89.63705444335938,\ - 138.8782196044922, 117.57748413085938, 148.01927185058594, 229.60403442382812, 110.2930908203125]), - (1.0, optim.lr_scheduler.LinearWarmupLRScheduler,\ - [0.0, 0.9473684210526315, 0.8421052631578947, 0.7368421052631579, 0.631578947368421, 0.5263157894736842,\ - 0.42105263157894735, 0.3157894736842105, 0.21052631578947367, 0.10526315789473684], - [10.988012313842773, 10.99213981628418, 112.89633178710938, 31.114538192749023, 80.94029235839844,\ - 131.34490966796875, 111.4329605102539, 133.74252319335938, 219.37344360351562, 109.67041015625]), - (1.0, optim.lr_scheduler.PolyWarmupLRScheduler,\ - [0.0, 0.9473684263157895, 0.8421052789473684, 0.7368421315789474, 0.6315789842105263, 0.5263158368421054, - 0.42105268947368424, 0.31578954210526317, 0.21052639473684212, 0.10526324736842106], - [10.988012313842773, 10.99213981628418, 112.89633178710938, 31.114538192749023, 80.9402847290039,\ - 131.3447265625, 111.43253326416016, 133.7415008544922, 219.37147521972656, 109.66986083984375]) -]) +@pytest.mark.parametrize( + "initial_lr, lr_scheduler, expected_learning_rates, expected_losses", + [ + ( + 1.0, + optim.lr_scheduler.ConstantWarmupLRScheduler, + [0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + [ + 10.988012313842773, + 10.99213981628418, + 120.79301452636719, + 36.11647033691406, + 95.83200073242188, + 221.2766571044922, + 208.40316772460938, + 279.5332946777344, + 402.46380615234375, + 325.79254150390625, + ], + ), + ( + 0.5, + optim.lr_scheduler.ConstantWarmupLRScheduler, + [0.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5], + [ + 10.988012313842773, + 10.99213981628418, + 52.69743347167969, + 19.741533279418945, + 83.88340759277344, + 126.39848327636719, + 91.53898620605469, + 63.62016296386719, + 102.21206665039062, + 180.1424560546875, + ], + ), + ( + 1.0, + optim.lr_scheduler.CosineWarmupLRScheduler, + [ + 0.0, + 0.9931806517013612, + 0.9397368756032445, + 0.8386407858128706, + 0.7008477123264848, + 0.5412896727361662, + 0.37725725642960045, + 0.22652592093878665, + 0.10542974530180327, + 0.02709137914968268, + ], + [ + 10.988012313842773, + 10.99213981628418, + 120.6441650390625, + 32.152557373046875, + 89.63705444335938, + 138.8782196044922, + 117.57748413085938, + 148.01927185058594, + 229.60403442382812, + 110.2930908203125, + ], + ), + ( + 1.0, + optim.lr_scheduler.LinearWarmupLRScheduler, + [ + 0.0, + 0.9473684210526315, + 0.8421052631578947, + 0.7368421052631579, + 0.631578947368421, + 0.5263157894736842, + 0.42105263157894735, + 0.3157894736842105, + 0.21052631578947367, + 0.10526315789473684, + ], + [ + 10.988012313842773, + 10.99213981628418, + 112.89633178710938, + 31.114538192749023, + 80.94029235839844, + 131.34490966796875, + 111.4329605102539, + 133.74252319335938, + 219.37344360351562, + 109.67041015625, + ], + ), + ( + 1.0, + optim.lr_scheduler.PolyWarmupLRScheduler, + [ + 0.0, + 0.9473684263157895, + 0.8421052789473684, + 0.7368421315789474, + 0.6315789842105263, + 0.5263158368421054, + 0.42105268947368424, + 0.31578954210526317, + 0.21052639473684212, + 0.10526324736842106, + ], + [ + 10.988012313842773, + 10.99213981628418, + 112.89633178710938, + 31.114538192749023, + 80.9402847290039, + 131.3447265625, + 111.43253326416016, + 133.7415008544922, + 219.37147521972656, + 109.66986083984375, + ], + ), + ], +) def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rates, expected_losses): - return # TODO: re-enable after nondeterminism on backend is fixed + return # TODO: re-enable after nondeterminism on backend is fixed # Common setup - device = 'cuda' + device = "cuda" total_steps = 10 seed = 1 warmup = 0.05 cycles = 0.5 - power = 1. + power = 1.0 lr_end = 1e-7 rtol = 1e-3 torch.manual_seed(seed) onnxruntime.set_seed(seed) # Setup LR Schedulers - if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler: + if ( + lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler + or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler + ): lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) @@ -269,15 +440,15 @@ def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rate model_desc = bert_model_description() model = load_bert_onnx_model() optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - 'lr_scheduler' : lr_scheduler - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "lr_scheduler": lr_scheduler, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) # Train @@ -293,18 +464,60 @@ def testToyBERTModelLRScheduler(initial_lr, lr_scheduler, expected_learning_rate _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) -@pytest.mark.parametrize("loss_scaler, expected_losses", [ - (None, [11.041126, 10.986309, 11.101673, 11.013394, 11.037781, - 11.041253, 10.957072, 11.069506, 11.040807, 11.164349]), - (amp.DynamicLossScaler(), [11.041126, 10.986309, 11.101673, 11.013394, - 11.037781, 11.041253, 10.957072, 11.069506, 11.040807, 11.164349]), - (CustomLossScaler(), [11.041126, 10.986309, 11.101645, 11.013412, - 11.037757, 11.041273, 10.957077, 11.069525, 11.040765, 11.164298]) -]) +@pytest.mark.parametrize( + "loss_scaler, expected_losses", + [ + ( + None, + [ + 11.041126, + 10.986309, + 11.101673, + 11.013394, + 11.037781, + 11.041253, + 10.957072, + 11.069506, + 11.040807, + 11.164349, + ], + ), + ( + amp.DynamicLossScaler(), + [ + 11.041126, + 10.986309, + 11.101673, + 11.013394, + 11.037781, + 11.041253, + 10.957072, + 11.069506, + 11.040807, + 11.164349, + ], + ), + ( + CustomLossScaler(), + [ + 11.041126, + 10.986309, + 11.101645, + 11.013412, + 11.037757, + 11.041273, + 10.957077, + 11.069525, + 11.040765, + 11.164298, + ], + ), + ], +) def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): # Common setup total_steps = 10 - device = 'cuda' + device = "cuda" seed = 1 rtol = 1e-3 torch.manual_seed(seed) @@ -314,18 +527,15 @@ def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): model_desc = bert_model_description() model = load_bert_onnx_model() optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - 'mixed_precision': { - 'enabled': True, - 'loss_scaler': loss_scaler + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, } - }) + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) # Train @@ -338,14 +548,56 @@ def testToyBERTModelMixedPrecisionLossScaler(loss_scaler, expected_losses): _test_helpers.assert_model_outputs(losses, expected_losses, rtol=rtol) -@pytest.mark.parametrize("gradient_accumulation_steps, expected_losses", [ - (1, [11.041123, 10.986166, 11.101636, 11.013366, 11.03775, - 11.041175, 10.957118, 11.069563, 11.040824, 11.16437]), - (4, [11.041123, 10.982856, 11.105512, 11.006721, 11.03358, - 11.05058, 10.955864, 11.059035, 11.037753, 11.162649]), - (7, [11.041123, 10.982856, 11.105512, 11.006721, 11.036314, - 11.055109, 10.960751, 11.05809 , 11.038856, 11.159635]) -]) +@pytest.mark.parametrize( + "gradient_accumulation_steps, expected_losses", + [ + ( + 1, + [ + 11.041123, + 10.986166, + 11.101636, + 11.013366, + 11.03775, + 11.041175, + 10.957118, + 11.069563, + 11.040824, + 11.16437, + ], + ), + ( + 4, + [ + 11.041123, + 10.982856, + 11.105512, + 11.006721, + 11.03358, + 11.05058, + 10.955864, + 11.059035, + 11.037753, + 11.162649, + ], + ), + ( + 7, + [ + 11.041123, + 10.982856, + 11.105512, + 11.006721, + 11.036314, + 11.055109, + 10.960751, + 11.05809, + 11.038856, + 11.159635, + ], + ), + ], +) def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_losses): # Common setup total_steps = 10 @@ -359,17 +611,15 @@ def testToyBERTModelGradientAccumulation(gradient_accumulation_steps, expected_l model_desc = bert_model_description() model = load_bert_onnx_model() optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - 'batch' : { - 'gradient_accumulation_steps' : gradient_accumulation_steps - }, - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) # Train @@ -388,7 +638,7 @@ def testToyBertCheckpointBasic(): torch.manual_seed(seed) onnxruntime.set_seed(seed) optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}}) + opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) # Create ORTTrainer and save initial state in a dict model = load_bert_onnx_model() @@ -399,15 +649,15 @@ def testToyBertCheckpointBasic(): ## All initializers must be present in the state_dict ## when the specified model for ORTTRainer is an ONNX model for param in trainer._onnx_model.graph.initializer: - assert param.name in sd['model']['full_precision'] + assert param.name in sd["model"]["full_precision"] ## Modify one of the state values and load into ORTTrainer - sd['model']['full_precision']['bert.encoder.layer.0.attention.output.LayerNorm.weight'] += 10 + sd["model"]["full_precision"]["bert.encoder.layer.0.attention.output.LayerNorm.weight"] += 10 trainer.load_state_dict(sd) ## Save a checkpoint - ckpt_dir = 'testdata' - trainer.save_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp')) + ckpt_dir = "testdata" + trainer.save_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) del trainer del model @@ -415,7 +665,7 @@ def testToyBertCheckpointBasic(): model2 = load_bert_onnx_model() model_desc2 = bert_model_description() trainer2 = orttrainer.ORTTrainer(model2, model_desc2, optim_config, options=opts) - trainer2.load_checkpoint(os.path.join(ckpt_dir, 'bert_toy_save_test.ortcp')) + trainer2.load_checkpoint(os.path.join(ckpt_dir, "bert_toy_save_test.ortcp")) loaded_sd = trainer2.state_dict() # Assert whether original state and the one loaded from checkpoint matches @@ -428,8 +678,12 @@ def testToyBertCheckpointFrozenWeights(): total_steps = 10 torch.manual_seed(seed) onnxruntime.set_seed(seed) - opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}, - 'utils' : {'frozen_weights' : ['bert.encoder.layer.0.attention.self.value.weight']}}) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "utils": {"frozen_weights": ["bert.encoder.layer.0.attention.self.value.weight"]}, + } + ) # Create ORTTrainer and save initial state in a dict model = load_bert_onnx_model() @@ -461,26 +715,34 @@ def testToyBertCheckpointFrozenWeights(): loaded_state_dict = trainer2.state_dict() _test_commons.assert_all_states_close_ort(state_dict, loaded_state_dict) -@pytest.mark.parametrize("optimizer, mixedprecision_enabled", [ - (optim.LambConfig(), False), - (optim.AdamConfig(), False), - (optim.LambConfig(), True), - (optim.AdamConfig(), True), -]) + +@pytest.mark.parametrize( + "optimizer, mixedprecision_enabled", + [ + (optim.LambConfig(), False), + (optim.AdamConfig(), False), + (optim.LambConfig(), True), + (optim.AdamConfig(), True), + ], +) def testToyBertLoadOptimState(optimizer, mixedprecision_enabled): # Common setup rtol = 1e-03 - device = 'cuda' + device = "cuda" seed = 1 torch.manual_seed(seed) onnxruntime.set_seed(seed) optim_config = optimizer - opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}, - 'device' : {'id' : device}, - 'mixed_precision': { - 'enabled': mixedprecision_enabled, - }, - 'distributed' : {'allreduce_post_accumulation' : True}}) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": {"id": device}, + "mixed_precision": { + "enabled": mixedprecision_enabled, + }, + "distributed": {"allreduce_post_accumulation": True}, + } + ) # Create ORTTrainer and save initial state in a dict model = load_bert_onnx_model() @@ -488,29 +750,81 @@ def testToyBertLoadOptimState(optimizer, mixedprecision_enabled): dummy_init_state = _test_commons.generate_dummy_optim_state(model, optimizer) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) trainer.load_state_dict(dummy_init_state) - + # Expected values - input_ids = torch.tensor([[26598],[21379],[19922],[ 5219],[ 5644],[20559],[23777],[25672],[22969],[16824],[16822],[635],[27399],[20647],[18519],[15546]], device=device) - segment_ids = torch.tensor([[0],[1],[0],[1],[0],[0],[1],[0],[0],[1],[1],[0],[0],[1],[1],[1]], device=device) - input_mask = torch.tensor([[0],[0],[0],[0],[1],[1],[1],[0],[1],[1],[0],[0],[0],[1],[0],[0]], device=device) - masked_lm_labels = torch.tensor([[25496],[16184],[11005],[16228],[14884],[21660],[ 8678],[23083],[ 4027],[ 8397],[11921],[ 1333],[26482],[ 1666],[17925],[27978]], device=device) + input_ids = torch.tensor( + [ + [26598], + [21379], + [19922], + [5219], + [5644], + [20559], + [23777], + [25672], + [22969], + [16824], + [16822], + [635], + [27399], + [20647], + [18519], + [15546], + ], + device=device, + ) + segment_ids = torch.tensor( + [[0], [1], [0], [1], [0], [0], [1], [0], [0], [1], [1], [0], [0], [1], [1], [1]], device=device + ) + input_mask = torch.tensor( + [[0], [0], [0], [0], [1], [1], [1], [0], [1], [1], [0], [0], [0], [1], [0], [0]], device=device + ) + masked_lm_labels = torch.tensor( + [ + [25496], + [16184], + [11005], + [16228], + [14884], + [21660], + [8678], + [23083], + [4027], + [8397], + [11921], + [1333], + [26482], + [1666], + [17925], + [27978], + ], + device=device, + ) next_sentence_labels = torch.tensor([0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], device=device) # Actual values _ = trainer.eval_step(input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels) - + actual_state_dict = trainer.state_dict() - del actual_state_dict['model'] + del actual_state_dict["model"] _test_commons.assert_all_states_close_ort(actual_state_dict, dummy_init_state) -@pytest.mark.parametrize("model_params", [ - (['bert.embeddings.LayerNorm.bias']), - (['bert.embeddings.LayerNorm.bias', - 'bert.embeddings.LayerNorm.weight', - 'bert.encoder.layer.0.attention.output.LayerNorm.bias']), -]) + +@pytest.mark.parametrize( + "model_params", + [ + (["bert.embeddings.LayerNorm.bias"]), + ( + [ + "bert.embeddings.LayerNorm.bias", + "bert.embeddings.LayerNorm.weight", + "bert.encoder.layer.0.attention.output.LayerNorm.bias", + ] + ), + ], +) def testORTTrainerFrozenWeights(model_params): - device = 'cuda' + device = "cuda" total_steps = 10 seed = 1 @@ -521,14 +835,12 @@ def testORTTrainerFrozenWeights(model_params): optim_config = optim.LambConfig() # Setup ORTTrainer WITHOUT frozen weights opts_dict = { - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, + "debug": {"deterministic_compute": True}, + "device": { + "id": device, }, } - opts = orttrainer.ORTTrainerOptions(opts_dict) + opts = orttrainer.ORTTrainerOptions(opts_dict) torch.manual_seed(seed) onnxruntime.set_seed(seed) @@ -544,8 +856,8 @@ def testORTTrainerFrozenWeights(model_params): assert all([param in session_state for param in model_params]) # Setup ORTTrainer WITH frozen weights - opts_dict.update({'utils' : {'frozen_weights' : model_params}}) - opts = orttrainer.ORTTrainerOptions(opts_dict) + opts_dict.update({"utils": {"frozen_weights": model_params}}) + opts = orttrainer.ORTTrainerOptions(opts_dict) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) for i in range(total_steps): @@ -557,9 +869,10 @@ def testORTTrainerFrozenWeights(model_params): session_state = trainer._training_session.get_state() assert not any([param in session_state for param in model_params]) + def testToyBERTSaveAsONNX(): - device = 'cuda' - onnx_file_name = '_____temp_toy_bert_onnx_model.onnx' + device = "cuda" + onnx_file_name = "_____temp_toy_bert_onnx_model.onnx" if os.path.exists(onnx_file_name): os.remove(onnx_file_name) assert not os.path.exists(onnx_file_name) @@ -569,14 +882,14 @@ def testToyBERTSaveAsONNX(): model = load_bert_onnx_model() optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) @@ -591,26 +904,33 @@ def testToyBERTSaveAsONNX(): # Create a new trainer from persisted ONNX model and compare with original ONNX model trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config, options=opts) assert trainer_from_onnx._onnx_model is not None - assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)) - for initializer, loaded_initializer in zip(trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer): + assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) + for initializer, loaded_initializer in zip( + trainer._onnx_model.graph.initializer, trainer_from_onnx._onnx_model.graph.initializer + ): assert initializer.name == loaded_initializer.name - assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph)) + assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( + trainer._onnx_model.graph + ) _test_helpers.assert_onnx_weights(trainer, trainer_from_onnx) ############################################################################### # Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ ############################################################################### -@pytest.mark.parametrize("optimizer_config", [ - (optim.AdamConfig), -# (optim.LambConfig), # TODO: re-enable after nondeterminism on backend is fixed - (optim.SGDConfig) -]) +@pytest.mark.parametrize( + "optimizer_config", + [ + (optim.AdamConfig), + # (optim.LambConfig), # TODO: re-enable after nondeterminism on backend is fixed + (optim.SGDConfig), + ], +) def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): # Common setup train_steps = 512 - device = 'cuda' + device = "cuda" seed = 1 torch.manual_seed(seed) onnxruntime.set_seed(seed) @@ -618,14 +938,14 @@ def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): # EXPERIMENTAL API model_desc = bert_model_description() model = load_bert_onnx_model() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + } + ) optim_config = optimizer_config(lr=0.01) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) experimental_losses = [] @@ -638,22 +958,27 @@ def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): onnxruntime.set_seed(seed) if optimizer_config == optim.AdamConfig: - legacy_optimizer = 'AdamOptimizer' + legacy_optimizer = "AdamOptimizer" elif optimizer_config == optim.LambConfig: - legacy_optimizer = 'LambOptimizer' + legacy_optimizer = "LambOptimizer" elif optimizer_config == optim.SGDConfig: - legacy_optimizer = 'SGDOptimizer' + legacy_optimizer = "SGDOptimizer" else: raise RuntimeError("Invalid optimizer_config") device = torch.device(device) model = load_bert_onnx_model() legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(lr=optim_config.lr) - legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, legacy_optimizer, - None, - learning_rate_description, - device, - _use_deterministic_compute=True) + legacy_trainer = Legacy_ORTTrainer( + model, + None, + legacy_model_desc, + legacy_optimizer, + None, + learning_rate_description, + device, + _use_deterministic_compute=True, + ) legacy_losses = [] for i in range(train_steps): sample_input = generate_random_input_from_model_desc(model_desc, i) @@ -664,13 +989,16 @@ def testToyBERTModelLegacyExperimentalBasicTraining(optimizer_config): _test_helpers.assert_model_outputs(experimental_losses, legacy_losses, True) -@pytest.mark.parametrize("initial_lr, lr_scheduler, legacy_lr_scheduler", [ - (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), -]) +@pytest.mark.parametrize( + "initial_lr, lr_scheduler, legacy_lr_scheduler", + [ + (1.0, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), + (0.5, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), + (1.0, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), + (1.0, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), + (1.0, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), + ], +) def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, legacy_lr_scheduler): ############################################################################ # These tests require hard-coded values for 'total_steps' and 'initial_lr' # @@ -678,23 +1006,40 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega # Common setup total_steps = 128 - device = 'cuda' + device = "cuda" seed = 1 warmup = 0.05 cycles = 0.5 - power = 1. + power = 1.0 lr_end = 1e-7 # Setup both Experimental and Legacy LR Schedulers before the experimental loop - if legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler: - legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup) + if ( + legacy_lr_scheduler == _test_commons.legacy_constant_lr_scheduler + or legacy_lr_scheduler == _test_commons.legacy_linear_lr_scheduler + ): + legacy_lr_scheduler = partial( + legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup + ) elif legacy_lr_scheduler == _test_commons.legacy_cosine_lr_scheduler: - legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles) + legacy_lr_scheduler = partial( + legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, cycles=cycles + ) elif legacy_lr_scheduler == _test_commons.legacy_poly_lr_scheduler: - legacy_lr_scheduler = partial(legacy_lr_scheduler, initial_lr=initial_lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) + legacy_lr_scheduler = partial( + legacy_lr_scheduler, + initial_lr=initial_lr, + total_steps=total_steps, + warmup=warmup, + power=power, + lr_end=lr_end, + ) else: raise RuntimeError("Invalid legacy_lr_scheduler") - if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler: + if ( + lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler + or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler + ): lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) @@ -703,22 +1048,21 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega else: raise RuntimeError("Invalid lr_scheduler") - # EXPERIMENTAL API model_desc = bert_model_description() model = load_bert_onnx_model() torch.manual_seed(seed) onnxruntime.set_seed(seed) optim_config = optim.AdamConfig(lr=initial_lr) - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - 'lr_scheduler' : lr_scheduler - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "lr_scheduler": lr_scheduler, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) experimental_losses = [] for i in range(total_steps): @@ -732,12 +1076,17 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega device = torch.device(device) model = load_bert_onnx_model() legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(initial_lr) - legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - get_lr_this_step=legacy_lr_scheduler) + legacy_trainer = Legacy_ORTTrainer( + model, + None, + legacy_model_desc, + "AdamOptimizer", + None, + learning_rate_description, + device, + _use_deterministic_compute=True, + get_lr_this_step=legacy_lr_scheduler, + ) legacy_losses = [] for i in range(total_steps): sample_input = generate_random_input_from_model_desc(model_desc, i) @@ -748,11 +1097,14 @@ def testToyBERTModelLegacyExperimentalLRScheduler(initial_lr, lr_scheduler, lega _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) -@pytest.mark.parametrize("loss_scaler, legacy_loss_scaler", [ - (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)), - (CustomLossScaler(), LegacyCustomLossScaler()) -]) +@pytest.mark.parametrize( + "loss_scaler, legacy_loss_scaler", + [ + (None, Legacy_LossScaler("ort_test_input_loss_scaler", True)), + (amp.DynamicLossScaler(), Legacy_LossScaler("ort_test_input_loss_scaler", True)), + (CustomLossScaler(), LegacyCustomLossScaler()), + ], +) def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, legacy_loss_scaler): # Common setup total_steps = 128 @@ -765,18 +1117,15 @@ def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, lega model_desc = bert_model_description() model = load_bert_onnx_model() optim_config = optim.AdamConfig(lr=0.001) - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - 'mixed_precision': { - 'enabled': True, - 'loss_scaler': loss_scaler + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, } - }) + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) experimental_losses = [] for i in range(total_steps): @@ -789,13 +1138,18 @@ def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, lega device = torch.device(device) model = load_bert_onnx_model() legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler = legacy_loss_scaler) + legacy_trainer = Legacy_ORTTrainer( + model, + None, + legacy_model_desc, + "AdamOptimizer", + None, + learning_rate_description, + device, + _use_deterministic_compute=True, + use_mixed_precision=True, + loss_scaler=legacy_loss_scaler, + ) legacy_losses = [] for i in range(total_steps): sample_input = generate_random_input_from_model_desc(model_desc, i) @@ -806,11 +1160,7 @@ def testToyBERTModelMixedPrecisionLossScalerLegacyExperimental(loss_scaler, lega _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) -@pytest.mark.parametrize("gradient_accumulation_steps", [ - (1), - (4), - (7) -]) +@pytest.mark.parametrize("gradient_accumulation_steps", [(1), (4), (7)]) def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation_steps): # Common setup total_steps = 128 @@ -823,17 +1173,15 @@ def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation model_desc = bert_model_description() model = load_bert_onnx_model() optim_config = optim.AdamConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - 'batch' : { - 'gradient_accumulation_steps' : gradient_accumulation_steps - }, - }) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) experimental_losses = [] for i in range(total_steps): @@ -847,12 +1195,17 @@ def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation device = torch.device(device) model = load_bert_onnx_model() legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(optim_config.lr) - legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer", - None, - learning_rate_description, - device, - _use_deterministic_compute = True, - gradient_accumulation_steps = gradient_accumulation_steps) + legacy_trainer = Legacy_ORTTrainer( + model, + None, + legacy_model_desc, + "AdamOptimizer", + None, + learning_rate_description, + device, + _use_deterministic_compute=True, + gradient_accumulation_steps=gradient_accumulation_steps, + ) legacy_losses = [] for i in range(total_steps): sample_input = generate_random_input_from_model_desc(model_desc, i) @@ -862,15 +1215,30 @@ def testToyBERTModelGradientAccumulationLegacyExperimental(gradient_accumulation # Check results _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) -@pytest.mark.parametrize("params, legacy_optim_map", [ - # Change the hyper parameters for all parameters - ([], legacy_optim_params_a), - # Change the hyperparameters for a subset of hardcoded parameters - ([{'params':['bert.embeddings.LayerNorm.bias', 'bert.embeddings.LayerNorm.weight'], "alpha": 0.9, - "beta": 0.999, "lambda_coef": 0.0, "epsilon": 1e-6, "do_bias_correction":False}], legacy_optim_params_b), - # Change the hyperparameters for a generated set of paramers - (optimizer_parameters(load_bert_onnx_model()), legacy_optim_params_c) -]) + +@pytest.mark.parametrize( + "params, legacy_optim_map", + [ + # Change the hyper parameters for all parameters + ([], legacy_optim_params_a), + # Change the hyperparameters for a subset of hardcoded parameters + ( + [ + { + "params": ["bert.embeddings.LayerNorm.bias", "bert.embeddings.LayerNorm.weight"], + "alpha": 0.9, + "beta": 0.999, + "lambda_coef": 0.0, + "epsilon": 1e-6, + "do_bias_correction": False, + } + ], + legacy_optim_params_b, + ), + # Change the hyperparameters for a generated set of paramers + (optimizer_parameters(load_bert_onnx_model()), legacy_optim_params_c), + ], +) def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim_map): # Common setup total_steps = 128 @@ -883,15 +1251,17 @@ def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim model_desc = bert_model_description() model = load_bert_onnx_model() - optim_config = optim.AdamConfig(params, alpha= 0.9, beta= 0.999, lambda_coef= 0.01, epsilon= 1e-6, do_bias_correction=False) - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device': { - 'id': device, - }, - }) + optim_config = optim.AdamConfig( + params, alpha=0.9, beta=0.999, lambda_coef=0.01, epsilon=1e-6, do_bias_correction=False + ) + opts = orttrainer.ORTTrainerOptions( + { + "debug": {"deterministic_compute": True}, + "device": { + "id": device, + }, + } + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, options=opts) experimental_losses = [] @@ -906,16 +1276,21 @@ def testToyBERTModelLegacyExperimentalCustomOptimParameters(params, legacy_optim model = load_bert_onnx_model() legacy_model_desc, learning_rate_description, learning_rate = legacy_model_params(trainer.optim_config.lr) - legacy_trainer = Legacy_ORTTrainer(model, None, legacy_model_desc, "AdamOptimizer", - legacy_optim_map, - learning_rate_description, - device, - _use_deterministic_compute=True) + legacy_trainer = Legacy_ORTTrainer( + model, + None, + legacy_model_desc, + "AdamOptimizer", + legacy_optim_map, + learning_rate_description, + device, + _use_deterministic_compute=True, + ) legacy_losses = [] for i in range(total_steps): sample_input = generate_random_input_from_model_desc(model_desc, i) legacy_sample_input = [*sample_input, learning_rate] legacy_losses.append(legacy_trainer.train_step(legacy_sample_input).cpu().item()) - # Check results + # Check results _test_helpers.assert_model_outputs(experimental_losses, legacy_losses) diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py index 35e3d18c0457a..99606d923e1d2 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_checkpoint_functions.py @@ -8,32 +8,30 @@ # Helper functions + def _create_trainer(zero_enabled=False): """Cerates a simple ORTTrainer for ORTTrainer functional tests""" - device = 'cuda' + device = "cuda" optim_config = optim.LambConfig(lr=0.1) - opts = { - 'device' : {'id' : device}, - 'debug' : {'deterministic_compute': True} - } + opts = {"device": {"id": device}, "debug": {"deterministic_compute": True}} if zero_enabled: - opts['distributed'] = { - 'world_rank' : 0, - 'world_size' : 1, - 'horizontal_parallel_size' : 1, - 'data_parallel_size' : 1, - 'allreduce_post_accumulation' : True, - 'deepspeed_zero_optimization': - { - 'stage': 1 - } - } + opts["distributed"] = { + "world_rank": 0, + "world_size": 1, + "horizontal_parallel_size": 1, + "data_parallel_size": 1, + "allreduce_post_accumulation": True, + "deepspeed_zero_optimization": {"stage": 1}, + } model, model_desc, loss_fn, batcher_fn, train_data, _, _ = _load_pytorch_transformer_model(device) - trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts)) + trainer = orttrainer.ORTTrainer( + model, model_desc, optim_config, loss_fn=loss_fn, options=orttrainer.ORTTrainerOptions(opts) + ) return trainer + class _training_session_mock(object): """Mock object for the ORTTrainer _training_session member""" @@ -51,6 +49,7 @@ def get_optimizer_state(self): def get_partition_info_map(self): return self.partition_info + def _get_load_state_dict_strict_error_arguments(): """Return a list of tuples that can be used as parameters for test_load_state_dict_errors_when_model_key_missing @@ -60,61 +59,65 @@ def _get_load_state_dict_strict_error_arguments(): """ training_session_state_dict = { - 'model': { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } + "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, + "optimizer": { + "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(5)}, }, - 'optimizer': { - 'a': { - 'Moment_1': np.arange(5), - 'Moment_2': np.arange(7) - }, - 'shared_optimizer_state': { - 'step': np.arange(5) - } - } } # input state dictionaries - precision_key_missing = {'model': {}, 'optimizer': {}} - precision_key_unexpected = {'model': {'full_precision': {}, 'mixed_precision': {}}, 'optimizer': {}} - model_state_key_missing = {'model': {'full_precision': {}}, 'optimizer': {}} - model_state_key_unexpected = {'model': {'full_precision': {'a': 2, 'b': 3, 'c': 4}}, 'optimizer': {}} - optimizer_model_state_key_missing = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': {}} - optimizer_model_state_key_unexpected = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': \ - {'a': {}, 'shared_optimizer_state': {}, 'b': {}}} - optimizer_state_key_missing = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': \ - {'a': {}, 'shared_optimizer_state': {'step': np.arange(5)}}} - optimizer_state_key_unexpected = {'model': {'full_precision': {'a': 2, 'b': 3}}, 'optimizer': \ - {'a': {'Moment_1': np.arange(5), 'Moment_2': np.arange(7)}, 'shared_optimizer_state': {'step': np.arange(5), 'another_step': np.arange(1)}}} + precision_key_missing = {"model": {}, "optimizer": {}} + precision_key_unexpected = {"model": {"full_precision": {}, "mixed_precision": {}}, "optimizer": {}} + model_state_key_missing = {"model": {"full_precision": {}}, "optimizer": {}} + model_state_key_unexpected = {"model": {"full_precision": {"a": 2, "b": 3, "c": 4}}, "optimizer": {}} + optimizer_model_state_key_missing = {"model": {"full_precision": {"a": 2, "b": 3}}, "optimizer": {}} + optimizer_model_state_key_unexpected = { + "model": {"full_precision": {"a": 2, "b": 3}}, + "optimizer": {"a": {}, "shared_optimizer_state": {}, "b": {}}, + } + optimizer_state_key_missing = { + "model": {"full_precision": {"a": 2, "b": 3}}, + "optimizer": {"a": {}, "shared_optimizer_state": {"step": np.arange(5)}}, + } + optimizer_state_key_unexpected = { + "model": {"full_precision": {"a": 2, "b": 3}}, + "optimizer": { + "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(5), "another_step": np.arange(1)}, + }, + } input_arguments = [ - (training_session_state_dict, precision_key_missing, ['full_precision']), - (training_session_state_dict, precision_key_unexpected, ['mixed_precision']), - (training_session_state_dict, model_state_key_missing, ['a', 'b']), - (training_session_state_dict, model_state_key_unexpected, ['c']), - (training_session_state_dict, optimizer_model_state_key_missing, ['a', 'shared_optimizer_state']), - (training_session_state_dict, optimizer_model_state_key_unexpected, ['b']), - (training_session_state_dict, optimizer_state_key_missing, ['Moment_1', 'Moment_2']), - (training_session_state_dict, optimizer_state_key_unexpected, ['another_step']) + (training_session_state_dict, precision_key_missing, ["full_precision"]), + (training_session_state_dict, precision_key_unexpected, ["mixed_precision"]), + (training_session_state_dict, model_state_key_missing, ["a", "b"]), + (training_session_state_dict, model_state_key_unexpected, ["c"]), + (training_session_state_dict, optimizer_model_state_key_missing, ["a", "shared_optimizer_state"]), + (training_session_state_dict, optimizer_model_state_key_unexpected, ["b"]), + (training_session_state_dict, optimizer_state_key_missing, ["Moment_1", "Moment_2"]), + (training_session_state_dict, optimizer_state_key_unexpected, ["another_step"]), ] return input_arguments + # Tests + def test_empty_state_dict_when_training_session_uninitialized(): trainer = _create_trainer() with pytest.warns(UserWarning) as user_warning: state_dict = trainer.state_dict() assert len(state_dict.keys()) == 0 - assert user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " \ - "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." + assert ( + user_warning[0].message.args[0] == "ONNX Runtime training session is not initialized yet. " + "Please run train_step or eval_step at least once before calling ORTTrainer.state_dict()." + ) -@patch('onnx.ModelProto') + +@patch("onnx.ModelProto") def test_training_session_provides_empty_model_states(onnx_model_mock): trainer = _create_trainer() training_session_mock = _training_session_mock({}, {}, {}) @@ -122,69 +125,58 @@ def test_training_session_provides_empty_model_states(onnx_model_mock): trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert len(state_dict['model'].keys()) == 0 + assert len(state_dict["model"].keys()) == 0 + -@patch('onnx.ModelProto') +@patch("onnx.ModelProto") def test_training_session_provides_model_states(onnx_model_mock): trainer = _create_trainer() - model_states = { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } - } + model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} training_session_mock = _training_session_mock(model_states, {}, {}) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert (state_dict['model']['full_precision']['a'] == np.arange(5)).all() - assert (state_dict['model']['full_precision']['b'] == np.arange(7)).all() + assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() + assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() + -@patch('onnx.ModelProto') +@patch("onnx.ModelProto") def test_training_session_provides_model_states_pytorch_format(onnx_model_mock): trainer = _create_trainer() - model_states = { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } - } + model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} training_session_mock = _training_session_mock(model_states, {}, {}) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict(pytorch_format=True) - assert torch.all(torch.eq(state_dict['a'], torch.tensor(np.arange(5)))) - assert torch.all(torch.eq(state_dict['b'], torch.tensor(np.arange(7)))) + assert torch.all(torch.eq(state_dict["a"], torch.tensor(np.arange(5)))) + assert torch.all(torch.eq(state_dict["b"], torch.tensor(np.arange(7)))) + -@patch('onnx.ModelProto') +@patch("onnx.ModelProto") def test_onnx_graph_provides_frozen_model_states(onnx_model_mock): trainer = _create_trainer() - model_states = { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } - } + model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} training_session_mock = _training_session_mock(model_states, {}, {}) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() - trainer.options.utils.frozen_weights = ['a_frozen_weight', 'a_float16_weight'] + trainer.options.utils.frozen_weights = ["a_frozen_weight", "a_float16_weight"] trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), 'a_frozen_weight'), - onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), 'a_non_fronzen_weight'), - onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), 'a_float16_weight') + onnx.numpy_helper.from_array(np.array([1, 2, 3], dtype=np.float32), "a_frozen_weight"), + onnx.numpy_helper.from_array(np.array([4, 5, 6], dtype=np.float32), "a_non_fronzen_weight"), + onnx.numpy_helper.from_array(np.array([7, 8, 9], dtype=np.float16), "a_float16_weight"), ] state_dict = trainer.state_dict() - assert (state_dict['model']['full_precision']['a'] == np.arange(5)).all() - assert (state_dict['model']['full_precision']['b'] == np.arange(7)).all() - assert (state_dict['model']['full_precision']['a_frozen_weight'] == np.array([1, 2, 3], dtype=np.float32)).all() - assert 'a_non_fronzen_weight' not in state_dict['model']['full_precision'] - assert (state_dict['model']['full_precision']['a_float16_weight'] == np.array([7, 8, 9], dtype=np.float32)).all() + assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() + assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() + assert (state_dict["model"]["full_precision"]["a_frozen_weight"] == np.array([1, 2, 3], dtype=np.float32)).all() + assert "a_non_fronzen_weight" not in state_dict["model"]["full_precision"] + assert (state_dict["model"]["full_precision"]["a_float16_weight"] == np.array([7, 8, 9], dtype=np.float32)).all() + -@patch('onnx.ModelProto') +@patch("onnx.ModelProto") def test_training_session_provides_empty_optimizer_states(onnx_model_mock): trainer = _create_trainer() training_session_mock = _training_session_mock({}, {}, {}) @@ -192,55 +184,43 @@ def test_training_session_provides_empty_optimizer_states(onnx_model_mock): trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert len(state_dict['optimizer'].keys()) == 0 + assert len(state_dict["optimizer"].keys()) == 0 -@patch('onnx.ModelProto') + +@patch("onnx.ModelProto") def test_training_session_provides_optimizer_states(onnx_model_mock): trainer = _create_trainer() optimizer_states = { - 'model_weight': { - 'Moment_1': np.arange(5), - 'Moment_2': np.arange(7) - }, - 'shared_optimizer_state': { - 'step': np.arange(1) - } + "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(1)}, } training_session_mock = _training_session_mock({}, optimizer_states, {}) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert (state_dict['optimizer']['model_weight']['Moment_1'] == np.arange(5)).all() - assert (state_dict['optimizer']['model_weight']['Moment_2'] == np.arange(7)).all() - assert (state_dict['optimizer']['shared_optimizer_state']['step'] == np.arange(1)).all() + assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() + assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() + assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() -@patch('onnx.ModelProto') + +@patch("onnx.ModelProto") def test_training_session_provides_optimizer_states_pytorch_format(onnx_model_mock): trainer = _create_trainer() - model_states = { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } - } + model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} optimizer_states = { - 'model_weight': { - 'Moment_1': np.arange(5), - 'Moment_2': np.arange(7) - }, - 'shared_optimizer_state': { - 'step': np.arange(1) - } + "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(1)}, } training_session_mock = _training_session_mock(model_states, optimizer_states, {}) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict(pytorch_format=True) - assert 'optimizer' not in state_dict + assert "optimizer" not in state_dict + -@patch('onnx.ModelProto') +@patch("onnx.ModelProto") def test_training_session_provides_empty_partition_info_map(onnx_model_mock): trainer = _create_trainer(zero_enabled=True) training_session_mock = _training_session_mock({}, {}, {}) @@ -248,137 +228,92 @@ def test_training_session_provides_empty_partition_info_map(onnx_model_mock): trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert len(state_dict['partition_info'].keys()) == 0 + assert len(state_dict["partition_info"].keys()) == 0 -@patch('onnx.ModelProto') + +@patch("onnx.ModelProto") def test_training_session_provides_partition_info_map(onnx_model_mock): trainer = _create_trainer(zero_enabled=True) - partition_info = { - 'a': { - 'original_dim': [1, 2, 3] - } - } + partition_info = {"a": {"original_dim": [1, 2, 3]}} training_session_mock = _training_session_mock({}, {}, partition_info) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert state_dict['partition_info']['a']['original_dim'] == [1, 2, 3] + assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] -@patch('onnx.ModelProto') + +@patch("onnx.ModelProto") def test_training_session_provides_all_states(onnx_model_mock): trainer = _create_trainer(zero_enabled=True) - model_states = { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } - } + model_states = {"full_precision": {"a": np.arange(5), "b": np.arange(7)}} optimizer_states = { - 'model_weight': { - 'Moment_1': np.arange(5), - 'Moment_2': np.arange(7) - }, - 'shared_optimizer_state': { - 'step': np.arange(1) - } - } - partition_info = { - 'a': { - 'original_dim': [1, 2, 3] - } + "model_weight": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(1)}, } + partition_info = {"a": {"original_dim": [1, 2, 3]}} training_session_mock = _training_session_mock(model_states, optimizer_states, partition_info) trainer._training_session = training_session_mock trainer._onnx_model = onnx_model_mock() state_dict = trainer.state_dict() - assert (state_dict['model']['full_precision']['a'] == np.arange(5)).all() - assert (state_dict['model']['full_precision']['b'] == np.arange(7)).all() - assert (state_dict['optimizer']['model_weight']['Moment_1'] == np.arange(5)).all() - assert (state_dict['optimizer']['model_weight']['Moment_2'] == np.arange(7)).all() - assert (state_dict['optimizer']['shared_optimizer_state']['step'] == np.arange(1)).all() - assert state_dict['partition_info']['a']['original_dim'] == [1, 2, 3] + assert (state_dict["model"]["full_precision"]["a"] == np.arange(5)).all() + assert (state_dict["model"]["full_precision"]["b"] == np.arange(7)).all() + assert (state_dict["optimizer"]["model_weight"]["Moment_1"] == np.arange(5)).all() + assert (state_dict["optimizer"]["model_weight"]["Moment_2"] == np.arange(7)).all() + assert (state_dict["optimizer"]["shared_optimizer_state"]["step"] == np.arange(1)).all() + assert state_dict["partition_info"]["a"]["original_dim"] == [1, 2, 3] + def test_load_state_dict_holds_when_training_session_not_initialized(): trainer = _create_trainer() state_dict = { - 'model': { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } + "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, + "optimizer": { + "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(5)}, }, - 'optimizer': { - 'a': { - 'Moment_1': np.arange(5), - 'Moment_2': np.arange(7) - }, - 'shared_optimizer_state': { - 'step': np.arange(5) - } - } } assert not trainer._load_state_dict state_dict = trainer.load_state_dict(state_dict) assert trainer._load_state_dict -@pytest.mark.parametrize("state_dict, input_state_dict, error_key", [ - ({ - 'model':{}, - 'optimizer':{} - }, - { - 'model':{}, - 'optimizer':{}, - 'trainer_options': { - 'optimizer_name': 'LambOptimizer' - } - }, - 'train_step_info'), - ({ - 'optimizer':{}, - 'train_step_info': { - 'optimization_step': 0, - 'step': 0 - } - }, - { - 'optimizer':{}, - 'trainer_options': { - 'optimizer_name': 'LambOptimizer' - }, - 'train_step_info': { - 'optimization_step': 0, - 'step': 0 - } - }, - 'model'), - ({ - 'model':{}, - 'train_step_info': { - 'optimization_step': 0, - 'step': 0 - } - }, - { - 'model':{}, - 'trainer_options': { - 'optimizer_name': 'LambOptimizer' - }, - 'train_step_info': { - 'optimization_step': 0, - 'step': 0 - } - }, - 'optimizer')]) + +@pytest.mark.parametrize( + "state_dict, input_state_dict, error_key", + [ + ( + {"model": {}, "optimizer": {}}, + {"model": {}, "optimizer": {}, "trainer_options": {"optimizer_name": "LambOptimizer"}}, + "train_step_info", + ), + ( + {"optimizer": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, + { + "optimizer": {}, + "trainer_options": {"optimizer_name": "LambOptimizer"}, + "train_step_info": {"optimization_step": 0, "step": 0}, + }, + "model", + ), + ( + {"model": {}, "train_step_info": {"optimization_step": 0, "step": 0}}, + { + "model": {}, + "trainer_options": {"optimizer_name": "LambOptimizer"}, + "train_step_info": {"optimization_step": 0, "step": 0}, + }, + "optimizer", + ), + ], +) def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, input_state_dict, error_key): trainer = _create_trainer() trainer._training_session = _training_session_mock({}, {}, {}) trainer.state_dict = Mock(return_value=state_dict) trainer._update_onnx_model_initializers = Mock() trainer._init_session = Mock() - with patch('onnx.ModelProto') as onnx_model_mock: + with patch("onnx.ModelProto") as onnx_model_mock: trainer._onnx_model = onnx_model_mock() trainer._onnx_model.graph.initializer = [] with pytest.warns(UserWarning) as user_warning: @@ -386,6 +321,7 @@ def test_load_state_dict_warns_when_model_optimizer_key_missing(state_dict, inpu assert user_warning[0].message.args[0] == "Missing key: {} in state_dict".format(error_key) + @pytest.mark.parametrize("state_dict, input_state_dict, error_keys", _get_load_state_dict_strict_error_arguments()) def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state_dict, error_keys): trainer = _create_trainer() @@ -396,53 +332,32 @@ def test_load_state_dict_errors_when_state_dict_mismatch(state_dict, input_state assert any(key in str(runtime_error.value) for key in error_keys) -@patch('onnx.ModelProto') + +@patch("onnx.ModelProto") def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_mock): trainer = _create_trainer() training_session_state_dict = { - 'model': { - 'full_precision': { - 'a': np.arange(5), - 'b': np.arange(7) - } + "model": {"full_precision": {"a": np.arange(5), "b": np.arange(7)}}, + "optimizer": { + "a": {"Moment_1": np.arange(5), "Moment_2": np.arange(7)}, + "shared_optimizer_state": {"step": np.arange(1)}, }, - 'optimizer': { - 'a': { - 'Moment_1': np.arange(5), - 'Moment_2': np.arange(7) - }, - 'shared_optimizer_state': { - 'step': np.arange(1) - } - } } input_state_dict = { - 'model': { - 'full_precision': { - 'a': np.array([1, 2]), - 'b': np.array([3, 4]) - } - }, - 'optimizer': { - 'a': { - 'Moment_1': np.array([5, 6]), - 'Moment_2': np.array([7, 8]) - }, - 'shared_optimizer_state': { - 'step': np.array([9]) - } + "model": {"full_precision": {"a": np.array([1, 2]), "b": np.array([3, 4])}}, + "optimizer": { + "a": {"Moment_1": np.array([5, 6]), "Moment_2": np.array([7, 8])}, + "shared_optimizer_state": {"step": np.array([9])}, }, - 'trainer_options': { - 'optimizer_name': 'LambOptimizer' - } + "trainer_options": {"optimizer_name": "LambOptimizer"}, } trainer._training_session = _training_session_mock({}, {}, {}) trainer.state_dict = Mock(return_value=training_session_state_dict) trainer._onnx_model = onnx_model_mock() trainer._onnx_model.graph.initializer = [ - onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), 'a'), - onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), 'b') + onnx.numpy_helper.from_array(np.arange(20, dtype=np.float32), "a"), + onnx.numpy_helper.from_array(np.arange(25, dtype=np.float32), "b"), ] trainer._update_onnx_model_initializers = Mock() trainer._init_session = Mock() @@ -452,388 +367,354 @@ def test_load_state_dict_loads_the_states_and_inits_training_session(onnx_model_ loaded_initializers, _ = trainer._update_onnx_model_initializers.call_args state_dict_to_load, _ = trainer._init_session.call_args - assert 'a' in loaded_initializers[0] - assert (loaded_initializers[0]['a'] == np.array([1, 2])).all() - assert 'b' in loaded_initializers[0] - assert (loaded_initializers[0]['b'] == np.array([3, 4])).all() + assert "a" in loaded_initializers[0] + assert (loaded_initializers[0]["a"] == np.array([1, 2])).all() + assert "b" in loaded_initializers[0] + assert (loaded_initializers[0]["b"] == np.array([3, 4])).all() + + assert (state_dict_to_load[0]["a"]["Moment_1"] == np.array([5, 6])).all() + assert (state_dict_to_load[0]["a"]["Moment_2"] == np.array([7, 8])).all() + assert (state_dict_to_load[0]["shared_optimizer_state"]["step"] == np.array([9])).all() - assert (state_dict_to_load[0]['a']['Moment_1'] == np.array([5, 6])).all() - assert (state_dict_to_load[0]['a']['Moment_2'] == np.array([7, 8])).all() - assert (state_dict_to_load[0]['shared_optimizer_state']['step'] == np.array([9])).all() -@patch('onnxruntime.training._checkpoint_storage.save') +@patch("onnxruntime.training._checkpoint_storage.save") def test_save_checkpoint_calls_checkpoint_storage_save(save_mock): trainer = _create_trainer() - state_dict = { - 'model': {}, - 'optimizer': {} - } + state_dict = {"model": {}, "optimizer": {}} trainer.state_dict = Mock(return_value=state_dict) - trainer.save_checkpoint('abc') + trainer.save_checkpoint("abc") save_args, _ = save_mock.call_args - assert 'model' in save_args[0] - assert not bool(save_args[0]['model']) - assert 'optimizer' in save_args[0] - assert not bool(save_args[0]['optimizer']) - assert save_args[1] == 'abc' + assert "model" in save_args[0] + assert not bool(save_args[0]["model"]) + assert "optimizer" in save_args[0] + assert not bool(save_args[0]["optimizer"]) + assert save_args[1] == "abc" + -@patch('onnxruntime.training._checkpoint_storage.save') +@patch("onnxruntime.training._checkpoint_storage.save") def test_save_checkpoint_exclude_optimizer_states(save_mock): trainer = _create_trainer() - state_dict = { - 'model': {}, - 'optimizer': {} - } + state_dict = {"model": {}, "optimizer": {}} trainer.state_dict = Mock(return_value=state_dict) - trainer.save_checkpoint('abc', include_optimizer_states=False) + trainer.save_checkpoint("abc", include_optimizer_states=False) save_args, _ = save_mock.call_args - assert 'model' in save_args[0] - assert not bool(save_args[0]['model']) - assert 'optimizer' not in save_args[0] - assert save_args[1] == 'abc' + assert "model" in save_args[0] + assert not bool(save_args[0]["model"]) + assert "optimizer" not in save_args[0] + assert save_args[1] == "abc" + -@patch('onnxruntime.training._checkpoint_storage.save') +@patch("onnxruntime.training._checkpoint_storage.save") def test_save_checkpoint_user_dict(save_mock): trainer = _create_trainer() - state_dict = { - 'model': {}, - 'optimizer': {} - } + state_dict = {"model": {}, "optimizer": {}} trainer.state_dict = Mock(return_value=state_dict) - trainer.save_checkpoint('abc', user_dict={'abc': np.arange(4)}) + trainer.save_checkpoint("abc", user_dict={"abc": np.arange(4)}) save_args, _ = save_mock.call_args - assert 'user_dict' in save_args[0] - assert save_args[0]['user_dict'] == _checkpoint_storage.to_serialized_hex({'abc': np.arange(4)}) + assert "user_dict" in save_args[0] + assert save_args[0]["user_dict"] == _checkpoint_storage.to_serialized_hex({"abc": np.arange(4)}) + -@patch('onnxruntime.training._checkpoint_storage.load') -@patch('onnxruntime.training.checkpoint.aggregate_checkpoints') +@patch("onnxruntime.training._checkpoint_storage.load") +@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") def test_load_checkpoint(aggregate_checkpoints_mock, load_mock): trainer = _create_trainer() trainer_options = { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(1), - 'zero_stage': np.int64(0) + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), } state_dict = { - 'model': {}, - 'optimizer': {}, - 'trainer_options': { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(1), - 'zero_stage': np.int64(0) - } + "model": {}, + "optimizer": {}, + "trainer_options": { + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), + }, } trainer.load_state_dict = Mock() load_mock.side_effect = [trainer_options, state_dict] - trainer.load_checkpoint('abc') + trainer.load_checkpoint("abc") args_list = load_mock.call_args_list load_args, load_kwargs = args_list[0] - assert load_args[0] == 'abc' - assert load_kwargs['key'] == 'trainer_options' + assert load_args[0] == "abc" + assert load_kwargs["key"] == "trainer_options" load_args, load_kwargs = args_list[1] - assert load_args[0] == 'abc' - assert 'key' not in load_kwargs + assert load_args[0] == "abc" + assert "key" not in load_kwargs assert not aggregate_checkpoints_mock.called -@patch('onnxruntime.training._checkpoint_storage.load') -@patch('onnxruntime.training.checkpoint.aggregate_checkpoints') -@pytest.mark.parametrize("trainer_options", [ - { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(4), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(4), - 'zero_stage': np.int64(1) - }, - { - 'mixed_precision': np.bool_(True), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(1), - 'zero_stage': np.int64(1) - }, - { - 'mixed_precision': np.bool_(True), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(1), - 'zero_stage': np.int64(1) - } -]) + +@patch("onnxruntime.training._checkpoint_storage.load") +@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") +@pytest.mark.parametrize( + "trainer_options", + [ + { + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(4), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(4), + "zero_stage": np.int64(1), + }, + { + "mixed_precision": np.bool_(True), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(1), + }, + { + "mixed_precision": np.bool_(True), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(1), + }, + ], +) def test_load_checkpoint_aggregation_required_zero_enabled(aggregate_checkpoints_mock, load_mock, trainer_options): trainer = _create_trainer() trainer.load_state_dict = Mock() load_mock.side_effect = [trainer_options] - trainer.load_checkpoint('abc') + trainer.load_checkpoint("abc") args_list = load_mock.call_args_list load_args, load_kwargs = args_list[0] - assert load_args[0] == 'abc' - assert load_kwargs['key'] == 'trainer_options' + assert load_args[0] == "abc" + assert load_kwargs["key"] == "trainer_options" assert aggregate_checkpoints_mock.called call_args, _ = aggregate_checkpoints_mock.call_args - assert call_args[0] == tuple(['abc']) + assert call_args[0] == tuple(["abc"]) + -@patch('onnxruntime.training._checkpoint_storage.load') -@patch('onnxruntime.training.checkpoint.aggregate_checkpoints') +@patch("onnxruntime.training._checkpoint_storage.load") +@patch("onnxruntime.training.checkpoint.aggregate_checkpoints") def test_load_checkpoint_user_dict(aggregate_checkpoints_mock, load_mock): trainer = _create_trainer() trainer_options = { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size': np.int64(1), - 'data_parallel_size': np.int64(1), - 'zero_stage': np.int64(0) + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), } state_dict = { - 'model': {}, - 'optimizer': {}, - 'trainer_options': { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size': np.int64(1), - 'data_parallel_size': np.int64(1), - 'zero_stage': np.int64(0) + "model": {}, + "optimizer": {}, + "trainer_options": { + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), }, - 'user_dict': _checkpoint_storage.to_serialized_hex({'array': torch.tensor(np.arange(5))}) + "user_dict": _checkpoint_storage.to_serialized_hex({"array": torch.tensor(np.arange(5))}), } trainer.load_state_dict = Mock() load_mock.side_effect = [trainer_options, state_dict] - user_dict = trainer.load_checkpoint('abc') + user_dict = trainer.load_checkpoint("abc") - assert torch.all(torch.eq(user_dict['array'], torch.tensor(np.arange(5)))) + assert torch.all(torch.eq(user_dict["array"], torch.tensor(np.arange(5)))) -@patch('onnxruntime.training._checkpoint_storage.load') + +@patch("onnxruntime.training._checkpoint_storage.load") def test_checkpoint_aggregation(load_mock): trainer_options1 = { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(2), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(2), - 'zero_stage': np.int64(1), - 'optimizer_name': b'Adam' + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(2), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(2), + "zero_stage": np.int64(1), + "optimizer_name": b"Adam", } trainer_options2 = { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(1), - 'world_size': np.int64(2), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(2), - 'zero_stage': np.int64(1), - 'optimizer_name': b'Adam' + "mixed_precision": np.bool_(False), + "world_rank": np.int64(1), + "world_size": np.int64(2), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(2), + "zero_stage": np.int64(1), + "optimizer_name": b"Adam", } state_dict1 = { - 'model': { - 'full_precision': { - 'optimizer_sharded': np.array([1, 2, 3]), - 'non_sharded': np.array([11, 22, 33]) - } - }, - 'optimizer': { - 'optimizer_sharded': { - 'Moment_1': np.array([9, 8, 7]), - 'Moment_2': np.array([99, 88, 77]), - 'Step': np.array([5]) + "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, + "optimizer": { + "optimizer_sharded": { + "Moment_1": np.array([9, 8, 7]), + "Moment_2": np.array([99, 88, 77]), + "Step": np.array([5]), + }, + "non_sharded": { + "Moment_1": np.array([666, 555, 444]), + "Moment_2": np.array([6666, 5555, 4444]), + "Step": np.array([55]), }, - 'non_sharded': { - 'Moment_1': np.array([666, 555, 444]), - 'Moment_2': np.array([6666, 5555, 4444]), - 'Step': np.array([55]) - } }, - 'trainer_options': { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(1), - 'zero_stage': np.int64(0), - 'optimizer_name': b'Adam' + "trainer_options": { + "mixed_precision": np.bool_(False), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), + "optimizer_name": b"Adam", }, - 'partition_info': { - 'optimizer_sharded': {'original_dim': np.array([2, 3])} - } + "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, } state_dict2 = { - 'model': { - 'full_precision': { - 'optimizer_sharded': np.array([1, 2, 3]), - 'non_sharded': np.array([11, 22, 33]) - } - }, - 'optimizer': { - 'optimizer_sharded': { - 'Moment_1': np.array([6, 5, 4]), - 'Moment_2': np.array([66, 55, 44]), - 'Step': np.array([5]) + "model": {"full_precision": {"optimizer_sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, + "optimizer": { + "optimizer_sharded": { + "Moment_1": np.array([6, 5, 4]), + "Moment_2": np.array([66, 55, 44]), + "Step": np.array([5]), + }, + "non_sharded": { + "Moment_1": np.array([666, 555, 444]), + "Moment_2": np.array([6666, 5555, 4444]), + "Step": np.array([55]), }, - 'non_sharded': { - 'Moment_1': np.array([666, 555, 444]), - 'Moment_2': np.array([6666, 5555, 4444]), - 'Step': np.array([55]) - } }, - 'trainer_options': { - 'mixed_precision': np.bool_(False), - 'world_rank': np.int64(1), - 'world_size': np.int64(1), - 'horizontal_parallel_size' : np.int64(1), - 'data_parallel_size' : np.int64(1), - 'zero_stage': np.int64(0), - 'optimizer_name': b'Adam' + "trainer_options": { + "mixed_precision": np.bool_(False), + "world_rank": np.int64(1), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), + "optimizer_name": b"Adam", }, - 'partition_info': { - 'optimizer_sharded': {'original_dim': np.array([2, 3])} - } + "partition_info": {"optimizer_sharded": {"original_dim": np.array([2, 3])}}, } load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False) - - assert (state_dict['model']['full_precision']['optimizer_sharded'] == np.array([1, 2, 3])).all() - assert (state_dict['model']['full_precision']['non_sharded'] == np.array([11, 22, 33])).all() - assert (state_dict['optimizer']['optimizer_sharded']['Moment_1'] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict['optimizer']['optimizer_sharded']['Moment_2'] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict['optimizer']['optimizer_sharded']['Step'] == np.array([5])).all() - assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array([666, 555, 444])).all() - assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all() - assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all() - - assert state_dict['trainer_options']['mixed_precision'] == False - assert state_dict['trainer_options']['world_rank'] == 0 - assert state_dict['trainer_options']['world_size'] == 1 - assert state_dict['trainer_options']['horizontal_parallel_size'] == 1 - assert state_dict['trainer_options']['data_parallel_size'] == 1 - assert state_dict['trainer_options']['zero_stage'] == 0 - assert state_dict['trainer_options']['optimizer_name'] == b'Adam' - -@patch('onnxruntime.training._checkpoint_storage.load') + state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) + + assert (state_dict["model"]["full_precision"]["optimizer_sharded"] == np.array([1, 2, 3])).all() + assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() + assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() + assert (state_dict["optimizer"]["optimizer_sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() + assert (state_dict["optimizer"]["optimizer_sharded"]["Step"] == np.array([5])).all() + assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() + assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() + assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() + + assert state_dict["trainer_options"]["mixed_precision"] == False + assert state_dict["trainer_options"]["world_rank"] == 0 + assert state_dict["trainer_options"]["world_size"] == 1 + assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 + assert state_dict["trainer_options"]["data_parallel_size"] == 1 + assert state_dict["trainer_options"]["zero_stage"] == 0 + assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" + + +@patch("onnxruntime.training._checkpoint_storage.load") def test_checkpoint_aggregation_mixed_precision(load_mock): trainer_options1 = { - 'mixed_precision': np.bool_(True), - 'world_rank': np.int64(0), - 'world_size': np.int64(2), - 'horizontal_parallel_size': np.int64(1), - 'data_parallel_size': np.int64(2), - 'zero_stage': np.int64(1), - 'optimizer_name': b'Adam' + "mixed_precision": np.bool_(True), + "world_rank": np.int64(0), + "world_size": np.int64(2), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(2), + "zero_stage": np.int64(1), + "optimizer_name": b"Adam", } trainer_options2 = { - 'mixed_precision': np.bool_(True), - 'world_rank': np.int64(1), - 'world_size': np.int64(2), - 'horizontal_parallel_size': np.int64(1), - 'data_parallel_size': np.int64(2), - 'zero_stage': np.int64(1), - 'optimizer_name': b'Adam' + "mixed_precision": np.bool_(True), + "world_rank": np.int64(1), + "world_size": np.int64(2), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(2), + "zero_stage": np.int64(1), + "optimizer_name": b"Adam", } state_dict1 = { - 'model': { - 'full_precision': { - 'sharded': np.array([1, 2, 3]), - 'non_sharded': np.array([11, 22, 33]) - } - }, - 'optimizer': { - 'sharded': { - 'Moment_1': np.array([9, 8, 7]), - 'Moment_2': np.array([99, 88, 77]), - 'Step': np.array([5]) + "model": {"full_precision": {"sharded": np.array([1, 2, 3]), "non_sharded": np.array([11, 22, 33])}}, + "optimizer": { + "sharded": {"Moment_1": np.array([9, 8, 7]), "Moment_2": np.array([99, 88, 77]), "Step": np.array([5])}, + "non_sharded": { + "Moment_1": np.array([666, 555, 444]), + "Moment_2": np.array([6666, 5555, 4444]), + "Step": np.array([55]), }, - 'non_sharded': { - 'Moment_1': np.array([666, 555, 444]), - 'Moment_2': np.array([6666, 5555, 4444]), - 'Step': np.array([55]) - } }, - 'trainer_options': { - 'mixed_precision': np.bool_(True), - 'world_rank': np.int64(0), - 'world_size': np.int64(1), - 'horizontal_parallel_size': np.int64(1), - 'data_parallel_size': np.int64(1), - 'zero_stage': np.int64(0), - 'optimizer_name': b'Adam' + "trainer_options": { + "mixed_precision": np.bool_(True), + "world_rank": np.int64(0), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), + "optimizer_name": b"Adam", }, - 'partition_info': { - 'sharded': {'original_dim': np.array([2, 3])} - } + "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, } state_dict2 = { - 'model': { - 'full_precision': { - 'sharded': np.array([4, 5, 6]), - 'non_sharded': np.array([11, 22, 33]) - } - }, - 'optimizer': { - 'sharded': { - 'Moment_1': np.array([6, 5, 4]), - 'Moment_2': np.array([66, 55, 44]), - 'Step': np.array([5]) + "model": {"full_precision": {"sharded": np.array([4, 5, 6]), "non_sharded": np.array([11, 22, 33])}}, + "optimizer": { + "sharded": {"Moment_1": np.array([6, 5, 4]), "Moment_2": np.array([66, 55, 44]), "Step": np.array([5])}, + "non_sharded": { + "Moment_1": np.array([666, 555, 444]), + "Moment_2": np.array([6666, 5555, 4444]), + "Step": np.array([55]), }, - 'non_sharded': { - 'Moment_1': np.array([666, 555, 444]), - 'Moment_2': np.array([6666, 5555, 4444]), - 'Step': np.array([55]) - } }, - 'trainer_options': { - 'mixed_precision': np.bool_(True), - 'world_rank': np.int64(1), - 'world_size': np.int64(1), - 'horizontal_parallel_size': np.int64(1), - 'data_parallel_size': np.int64(1), - 'zero_stage': np.int64(0), - 'optimizer_name': b'Adam' + "trainer_options": { + "mixed_precision": np.bool_(True), + "world_rank": np.int64(1), + "world_size": np.int64(1), + "horizontal_parallel_size": np.int64(1), + "data_parallel_size": np.int64(1), + "zero_stage": np.int64(0), + "optimizer_name": b"Adam", }, - 'partition_info': { - 'sharded': {'original_dim': np.array([2, 3])} - } + "partition_info": {"sharded": {"original_dim": np.array([2, 3])}}, } load_mock.side_effect = [trainer_options1, trainer_options2, trainer_options1, state_dict1, state_dict2] - state_dict = checkpoint.aggregate_checkpoints(['abc', 'def'], pytorch_format=False) - - assert (state_dict['model']['full_precision']['sharded'] == np.array([[1, 2, 3], [4, 5, 6]])).all() - assert (state_dict['model']['full_precision']['non_sharded'] == np.array([11, 22, 33])).all() - assert (state_dict['optimizer']['sharded']['Moment_1'] == np.array([[9, 8, 7], [6, 5, 4]])).all() - assert (state_dict['optimizer']['sharded']['Moment_2'] == np.array([[99, 88, 77], [66, 55, 44]])).all() - assert (state_dict['optimizer']['sharded']['Step'] == np.array([5])).all() - assert (state_dict['optimizer']['non_sharded']['Moment_1'] == np.array([666, 555, 444])).all() - assert (state_dict['optimizer']['non_sharded']['Moment_2'] == np.array([6666, 5555, 4444])).all() - assert (state_dict['optimizer']['non_sharded']['Step'] == np.array([55])).all() - - assert state_dict['trainer_options']['mixed_precision'] == True - assert state_dict['trainer_options']['world_rank'] == 0 - assert state_dict['trainer_options']['world_size'] == 1 - assert state_dict['trainer_options']['horizontal_parallel_size'] == 1 - assert state_dict['trainer_options']['data_parallel_size'] == 1 - assert state_dict['trainer_options']['zero_stage'] == 0 - assert state_dict['trainer_options']['optimizer_name'] == b'Adam' + state_dict = checkpoint.aggregate_checkpoints(["abc", "def"], pytorch_format=False) + + assert (state_dict["model"]["full_precision"]["sharded"] == np.array([[1, 2, 3], [4, 5, 6]])).all() + assert (state_dict["model"]["full_precision"]["non_sharded"] == np.array([11, 22, 33])).all() + assert (state_dict["optimizer"]["sharded"]["Moment_1"] == np.array([[9, 8, 7], [6, 5, 4]])).all() + assert (state_dict["optimizer"]["sharded"]["Moment_2"] == np.array([[99, 88, 77], [66, 55, 44]])).all() + assert (state_dict["optimizer"]["sharded"]["Step"] == np.array([5])).all() + assert (state_dict["optimizer"]["non_sharded"]["Moment_1"] == np.array([666, 555, 444])).all() + assert (state_dict["optimizer"]["non_sharded"]["Moment_2"] == np.array([6666, 5555, 4444])).all() + assert (state_dict["optimizer"]["non_sharded"]["Step"] == np.array([55])).all() + + assert state_dict["trainer_options"]["mixed_precision"] == True + assert state_dict["trainer_options"]["world_rank"] == 0 + assert state_dict["trainer_options"]["world_size"] == 1 + assert state_dict["trainer_options"]["horizontal_parallel_size"] == 1 + assert state_dict["trainer_options"]["data_parallel_size"] == 1 + assert state_dict["trainer_options"]["zero_stage"] == 0 + assert state_dict["trainer_options"]["optimizer_name"] == b"Adam" diff --git a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py index 26f3ae91c0ed0..5dd9e1368420d 100644 --- a/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py +++ b/orttraining/orttraining/test/python/orttraining_test_orttrainer_frontend.py @@ -11,14 +11,23 @@ import torch.nn.functional as F from onnxruntime import set_seed -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ - ModelDescription as Legacy_ModelDescription,\ - LossScaler as Legacy_LossScaler,\ - ORTTrainer as Legacy_ORTTrainer -from onnxruntime.training import _utils, amp, checkpoint, optim, orttrainer, TrainStepInfo,\ - model_desc_validation as md_val,\ - orttrainer_options as orttrainer_options -import _test_commons,_test_helpers +from onnxruntime.capi.ort_trainer import ( + IODescription as Legacy_IODescription, + ModelDescription as Legacy_ModelDescription, + LossScaler as Legacy_LossScaler, + ORTTrainer as Legacy_ORTTrainer, +) +from onnxruntime.training import ( + _utils, + amp, + checkpoint, + optim, + orttrainer, + TrainStepInfo, + model_desc_validation as md_val, + orttrainer_options as orttrainer_options, +) +import _test_commons, _test_helpers from onnxruntime import SessionOptions from onnxruntime.training import PropagateCastOpsStrategy @@ -26,128 +35,141 @@ # Testing starts here ######################################################### ############################################################################### -pytorch_110 = StrictVersion('.'.join(torch.__version__.split('.')[:2])) >= StrictVersion('1.10.0') +pytorch_110 = StrictVersion(".".join(torch.__version__.split(".")[:2])) >= StrictVersion("1.10.0") def get_model_opset(model_onnx): for op in model_onnx.opset_import: - if op.domain == '': + if op.domain == "": return op.version return None -@pytest.mark.parametrize("test_input", [ - ({}), - ({'batch': {}, - 'device': {}, - 'distributed': {}, - 'mixed_precision': {}, - 'utils': {}, - '_internal_use': {}}) -]) +@pytest.mark.parametrize( + "test_input", + [({}), ({"batch": {}, "device": {}, "distributed": {}, "mixed_precision": {}, "utils": {}, "_internal_use": {}})], +) def testORTTrainerOptionsDefaultValues(test_input): - ''' Test different ways of using default values for incomplete input''' + """Test different ways of using default values for incomplete input""" expected_values = { - 'batch': { - 'gradient_accumulation_steps': 1 - }, - 'device': { - 'id': 'cuda', - 'mem_limit': 0 - }, - 'distributed': { - 'world_rank': 0, - 'world_size': 1, - 'local_rank': 0, - 'data_parallel_size': 1, - 'horizontal_parallel_size': 1, - 'pipeline_parallel' : { - 'pipeline_parallel_size': 1, - 'num_pipeline_micro_batches':1, - 'pipeline_cut_info_string': '', - 'sliced_schema' : {}, - 'sliced_axes' : {}, - 'sliced_tensor_names': [] + "batch": {"gradient_accumulation_steps": 1}, + "device": {"id": "cuda", "mem_limit": 0}, + "distributed": { + "world_rank": 0, + "world_size": 1, + "local_rank": 0, + "data_parallel_size": 1, + "horizontal_parallel_size": 1, + "pipeline_parallel": { + "pipeline_parallel_size": 1, + "num_pipeline_micro_batches": 1, + "pipeline_cut_info_string": "", + "sliced_schema": {}, + "sliced_axes": {}, + "sliced_tensor_names": [], }, - 'allreduce_post_accumulation': False, - 'data_parallel_size': 1, - 'horizontal_parallel_size':1, - 'deepspeed_zero_optimization': { - 'stage' : 0, + "allreduce_post_accumulation": False, + "data_parallel_size": 1, + "horizontal_parallel_size": 1, + "deepspeed_zero_optimization": { + "stage": 0, }, - 'enable_adasum': False, + "enable_adasum": False, }, - 'lr_scheduler': None, - 'mixed_precision': { - 'enabled': False, - 'loss_scaler': None + "lr_scheduler": None, + "mixed_precision": {"enabled": False, "loss_scaler": None}, + "graph_transformer": { + "attn_dropout_recompute": False, + "gelu_recompute": False, + "transformer_layer_recompute": False, + "number_recompute_layers": 0, + "allow_layer_norm_mod_precision": False, + "propagate_cast_ops_config": {"strategy": PropagateCastOpsStrategy.FLOOD_FILL, "level": 1, "allow": []}, }, - 'graph_transformer': { - 'attn_dropout_recompute': False, - 'gelu_recompute': False, - 'transformer_layer_recompute': False, - 'number_recompute_layers': 0, - 'allow_layer_norm_mod_precision': False, - 'propagate_cast_ops_config': { - 'strategy': PropagateCastOpsStrategy.FLOOD_FILL, - 'level': 1, - 'allow': [] - } + "utils": { + "frozen_weights": [], + "grad_norm_clip": True, + "memory_efficient_gradient": False, + "run_symbolic_shape_infer": False, }, - 'utils': { - 'frozen_weights': [], - 'grad_norm_clip': True, - 'memory_efficient_gradient': False, - 'run_symbolic_shape_infer': False - }, - 'debug': { - 'deterministic_compute': False, - 'check_model_export': False, - 'graph_save_paths' : { - 'model_after_graph_transforms_path': '', - 'model_with_gradient_graph_path': '', - 'model_with_training_graph_path': '', - 'model_with_training_graph_after_optimization_path': '' - } + "debug": { + "deterministic_compute": False, + "check_model_export": False, + "graph_save_paths": { + "model_after_graph_transforms_path": "", + "model_with_gradient_graph_path": "", + "model_with_training_graph_path": "", + "model_with_training_graph_after_optimization_path": "", + }, }, - '_internal_use': { - 'enable_internal_postprocess': True, - 'extra_postprocess': None, - 'onnx_opset_version' : 14, - 'enable_onnx_contrib_ops': True, + "_internal_use": { + "enable_internal_postprocess": True, + "extra_postprocess": None, + "onnx_opset_version": 14, + "enable_onnx_contrib_ops": True, }, - 'provider_options':{}, - 'session_options': None, + "provider_options": {}, + "session_options": None, } actual_values = orttrainer_options.ORTTrainerOptions(test_input) assert actual_values._validated_opts == expected_values -@pytest.mark.parametrize("input,error_msg", [ - ({'mixed_precision': {'enabled': 1}},\ - "Invalid options: {'mixed_precision': [{'enabled': ['must be of boolean type']}]}") -]) +@pytest.mark.parametrize( + "input,error_msg", + [ + ( + {"mixed_precision": {"enabled": 1}}, + "Invalid options: {'mixed_precision': [{'enabled': ['must be of boolean type']}]}", + ) + ], +) def testORTTrainerOptionsInvalidMixedPrecisionEnabledSchema(input, error_msg): - '''Test an invalid input based on schema validation error message''' + """Test an invalid input based on schema validation error message""" with pytest.raises(ValueError) as e: orttrainer_options.ORTTrainerOptions(input) assert str(e.value) == error_msg -@pytest.mark.parametrize("input_dict,input_dtype,output_dtype", [ - ({'inputs': [('in0', [])], - 'outputs': [('out0', []), ('out1', [])]},(torch.int,),(torch.float,torch.int32,)), - ({'inputs': [('in0', ['batch', 2, 3])], - 'outputs': [('out0', [], True)]}, (torch.int8,), (torch.int16,)), - ({'inputs': [('in0', []), ('in1', [1]), ('in2', [1, 2]), ('in3', [1000, 'dyn_ax1']), ('in4', ['dyn_ax1', 'dyn_ax2', 'dyn_ax3'])], - 'outputs': [('out0', [], True), ('out1', [1], False), ('out2', [1, 'dyn_ax1', 3])]}, - (torch.float,torch.uint8,torch.bool,torch.double,torch.half,), (torch.float,torch.float,torch.int64)) -]) +@pytest.mark.parametrize( + "input_dict,input_dtype,output_dtype", + [ + ( + {"inputs": [("in0", [])], "outputs": [("out0", []), ("out1", [])]}, + (torch.int,), + ( + torch.float, + torch.int32, + ), + ), + ({"inputs": [("in0", ["batch", 2, 3])], "outputs": [("out0", [], True)]}, (torch.int8,), (torch.int16,)), + ( + { + "inputs": [ + ("in0", []), + ("in1", [1]), + ("in2", [1, 2]), + ("in3", [1000, "dyn_ax1"]), + ("in4", ["dyn_ax1", "dyn_ax2", "dyn_ax3"]), + ], + "outputs": [("out0", [], True), ("out1", [1], False), ("out2", [1, "dyn_ax1", 3])], + }, + ( + torch.float, + torch.uint8, + torch.bool, + torch.double, + torch.half, + ), + (torch.float, torch.float, torch.int64), + ), + ], +) def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): - r''' Test different ways of using default values for incomplete input''' + r"""Test different ways of using default values for incomplete input""" model_description = md_val._ORTTrainerModelDesc(input_dict) @@ -160,14 +182,14 @@ def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): for idx, i_desc in enumerate(model_description.inputs): assert isinstance(i_desc, model_description._InputDescription) assert len(i_desc) == 2 - assert input_dict['inputs'][idx][0] == i_desc.name - assert input_dict['inputs'][idx][1] == i_desc.shape + assert input_dict["inputs"][idx][0] == i_desc.name + assert input_dict["inputs"][idx][1] == i_desc.shape for idx, o_desc in enumerate(model_description.outputs): assert isinstance(o_desc, model_description._OutputDescription) assert len(o_desc) == 3 - assert input_dict['outputs'][idx][0] == o_desc.name - assert input_dict['outputs'][idx][1] == o_desc.shape - is_loss = input_dict['outputs'][idx][2] if len(input_dict['outputs'][idx]) == 3 else False + assert input_dict["outputs"][idx][0] == o_desc.name + assert input_dict["outputs"][idx][1] == o_desc.shape + is_loss = input_dict["outputs"][idx][2] if len(input_dict["outputs"][idx]) == 3 else False assert is_loss == o_desc.is_loss # Set all_finite name and check its description @@ -197,34 +219,44 @@ def testORTTrainerModelDescValidSchemas(input_dict, input_dtype, output_dtype): assert output_dtype[idx] == o_desc.dtype -@pytest.mark.parametrize("input_dict,error_msg", [ - ({'inputs': [(True, [])], - 'outputs': [(True, [])]}, - "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " - "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}"), - ({'inputs': [('in1', None)], - 'outputs': [('out1', None)]}, - "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], " - "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}"), - ({'inputs': [('in1', [])], - 'outputs': [('out1', [], None)]}, - "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}"), - ({'inputs': [('in1', [True])], - 'outputs': [('out1', [True])]}, - "Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], " - "'outputs': [{0: ['each shape must be either a string or integer']}]}"), - ({'inputs': [('in1', [])], - 'outputs': [('out1', [], True), ('out2', [], True)]}, - "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}"), - ({'inputz': [('in1', [])], - 'outputs': [('out1', [], True)]}, - "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}"), - ({'inputs': [('in1', [])], - 'outputz': [('out1', [], True)]}, - "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}"), -]) +@pytest.mark.parametrize( + "input_dict,error_msg", + [ + ( + {"inputs": [(True, [])], "outputs": [(True, [])]}, + "Invalid model_desc: {'inputs': [{0: ['the first element of the tuple (aka name) must be a string']}], " + "'outputs': [{0: ['the first element of the tuple (aka name) must be a string']}]}", + ), + ( + {"inputs": [("in1", None)], "outputs": [("out1", None)]}, + "Invalid model_desc: {'inputs': [{0: ['the second element of the tuple (aka shape) must be a list']}], " + "'outputs': [{0: ['the second element of the tuple (aka shape) must be a list']}]}", + ), + ( + {"inputs": [("in1", [])], "outputs": [("out1", [], None)]}, + "Invalid model_desc: {'outputs': [{0: ['the third element of the tuple (aka is_loss) must be a boolean']}]}", + ), + ( + {"inputs": [("in1", [True])], "outputs": [("out1", [True])]}, + "Invalid model_desc: {'inputs': [{0: ['each shape must be either a string or integer']}], " + "'outputs': [{0: ['each shape must be either a string or integer']}]}", + ), + ( + {"inputs": [("in1", [])], "outputs": [("out1", [], True), ("out2", [], True)]}, + "Invalid model_desc: {'outputs': [{1: ['only one is_loss can bet set to True']}]}", + ), + ( + {"inputz": [("in1", [])], "outputs": [("out1", [], True)]}, + "Invalid model_desc: {'inputs': ['required field'], 'inputz': ['unknown field']}", + ), + ( + {"inputs": [("in1", [])], "outputz": [("out1", [], True)]}, + "Invalid model_desc: {'outputs': ['required field'], 'outputz': ['unknown field']}", + ), + ], +) def testORTTrainerModelDescInvalidSchemas(input_dict, error_msg): - r''' Test different ways of using default values for incomplete input''' + r"""Test different ways of using default values for incomplete input""" with pytest.raises(ValueError) as e: md_val._ORTTrainerModelDesc(input_dict) assert str(e.value) == error_msg @@ -236,13 +268,10 @@ def testDynamicLossScaler(): # Initial state train_step_info = orttrainer.TrainStepInfo(optim.LambConfig()) - assert_allclose(default_scaler.loss_scale, float(1 << 16), - rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(default_scaler.loss_scale, float(1 << 16), rtol=rtol, err_msg="loss scale mismatch") assert default_scaler.up_scale_window == 2000 - assert_allclose(default_scaler.min_loss_scale, 1.0, - rtol=rtol, err_msg="min loss scale mismatch") - assert_allclose(default_scaler.max_loss_scale, float( - 1 << 24), rtol=rtol, err_msg="max loss scale mismatch") + assert_allclose(default_scaler.min_loss_scale, 1.0, rtol=rtol, err_msg="min loss scale mismatch") + assert_allclose(default_scaler.max_loss_scale, float(1 << 24), rtol=rtol, err_msg="max loss scale mismatch") # Performing 9*2000 updates to cover all branches of LossScaler.update(train_step_info.all_finite=True) loss_scale = float(1 << 16) @@ -252,80 +281,66 @@ def testDynamicLossScaler(): for i in range(1, 2000): new_loss_scale = default_scaler.update(train_step_info) assert default_scaler._stable_steps_count == i - assert_allclose(new_loss_scale, loss_scale, - rtol=rtol, err_msg=f"loss scale mismatch at update {i}") + assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg=f"loss scale mismatch at update {i}") # 2000th update without overflow doubles the loss and zero stable steps until max_loss_scale is reached new_loss_scale = default_scaler.update(train_step_info) if cycles <= 8: loss_scale *= 2 assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, - rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") # After 8 cycles, loss scale should be float(1 << 16)*(2**8) - assert_allclose(new_loss_scale, float(1 << 16) - * (2**8), rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(new_loss_scale, float(1 << 16) * (2**8), rtol=rtol, err_msg="loss scale mismatch") # After 9 cycles, loss scale reaches max_loss_scale and it is not doubled from that point on - loss_scale = float(1 << 16)*(2**8) + loss_scale = float(1 << 16) * (2**8) for count in range(1, 2050): new_loss_scale = default_scaler.update(train_step_info) assert default_scaler._stable_steps_count == (count % 2000) - assert_allclose(new_loss_scale, loss_scale, - rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") # Setting train_step_info.all_finite = False to test down scaling train_step_info.all_finite = False # Performing 24 updates to half the loss scale each time - loss_scale = float(1 << 16)*(2**8) + loss_scale = float(1 << 16) * (2**8) for count in range(1, 25): new_loss_scale = default_scaler.update(train_step_info) loss_scale /= 2 assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, - rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") # After 24 updates with gradient overflow, loss scale is 1.0 - assert_allclose(new_loss_scale, 1., - rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(new_loss_scale, 1.0, rtol=rtol, err_msg="loss scale mismatch") # After 25 updates, min_loss_scale is reached and loss scale is not halfed from that point on for count in range(1, 5): new_loss_scale = default_scaler.update(train_step_info) assert default_scaler._stable_steps_count == 0 - assert_allclose(new_loss_scale, loss_scale, - rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(new_loss_scale, loss_scale, rtol=rtol, err_msg="loss scale mismatch") def testDynamicLossScalerCustomValues(): rtol = 1e-7 - scaler = amp.loss_scaler.DynamicLossScaler(automatic_update=False, - loss_scale=3, - up_scale_window=7, - min_loss_scale=5, - max_loss_scale=10) + scaler = amp.loss_scaler.DynamicLossScaler( + automatic_update=False, loss_scale=3, up_scale_window=7, min_loss_scale=5, max_loss_scale=10 + ) assert scaler.automatic_update == False - assert_allclose(scaler.loss_scale, 3, rtol=rtol, - err_msg="loss scale mismatch") - assert_allclose(scaler.min_loss_scale, 5, rtol=rtol, - err_msg="min loss scale mismatch") - assert_allclose(scaler.max_loss_scale, 10, rtol=rtol, - err_msg="max loss scale mismatch") + assert_allclose(scaler.loss_scale, 3, rtol=rtol, err_msg="loss scale mismatch") + assert_allclose(scaler.min_loss_scale, 5, rtol=rtol, err_msg="min loss scale mismatch") + assert_allclose(scaler.max_loss_scale, 10, rtol=rtol, err_msg="max loss scale mismatch") assert scaler.up_scale_window == 7 def testTrainStepInfo(): - '''Test valid initializations of TrainStepInfo''' + """Test valid initializations of TrainStepInfo""" optimizer_config = optim.LambConfig() - fetches=['out1','out2'] - step_info = orttrainer.TrainStepInfo(optimizer_config=optimizer_config, - all_finite=False, - fetches=fetches, - optimization_step=123, - step=456) + fetches = ["out1", "out2"] + step_info = orttrainer.TrainStepInfo( + optimizer_config=optimizer_config, all_finite=False, fetches=fetches, optimization_step=123, step=456 + ) assert step_info.optimizer_config == optimizer_config assert step_info.all_finite == False assert step_info.fetches == fetches @@ -340,12 +355,15 @@ def testTrainStepInfo(): assert step_info.step == 0 -@pytest.mark.parametrize("invalid_input", [ - (-1), - ('Hello'), -]) +@pytest.mark.parametrize( + "invalid_input", + [ + (-1), + ("Hello"), + ], +) def testTrainStepInfoInvalidInput(invalid_input): - '''Test invalid initialization of TrainStepInfo''' + """Test invalid initialization of TrainStepInfo""" optimizer_config = optim.LambConfig() with pytest.raises(AssertionError): orttrainer.TrainStepInfo(optimizer_config=invalid_input) @@ -363,64 +381,67 @@ def testTrainStepInfoInvalidInput(invalid_input): orttrainer.TrainStepInfo(optimizer_config, step=invalid_input) -@pytest.mark.parametrize("optim_name,lr,alpha,default_alpha", [ - ('AdamOptimizer', .1, .2, None), - ('LambOptimizer', .2, .3, None), - ('SGDOptimizer', .3, .4, None), - ('SGDOptimizer', .3, .4, .5) -]) +@pytest.mark.parametrize( + "optim_name,lr,alpha,default_alpha", + [ + ("AdamOptimizer", 0.1, 0.2, None), + ("LambOptimizer", 0.2, 0.3, None), + ("SGDOptimizer", 0.3, 0.4, None), + ("SGDOptimizer", 0.3, 0.4, 0.5), + ], +) def testOptimizerConfig(optim_name, lr, alpha, default_alpha): - '''Test initialization of _OptimizerConfig''' - defaults = {'lr': lr, 'alpha': alpha} - params = [{'params': ['fc1.weight', 'fc2.weight']}] + """Test initialization of _OptimizerConfig""" + defaults = {"lr": lr, "alpha": alpha} + params = [{"params": ["fc1.weight", "fc2.weight"]}] if default_alpha is not None: - params[0].update({'alpha': default_alpha}) + params[0].update({"alpha": default_alpha}) else: - params[0].update({'alpha': alpha}) - cfg = optim.config._OptimizerConfig( - name=optim_name, params=params, defaults=defaults) + params[0].update({"alpha": alpha}) + cfg = optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) assert cfg.name == optim_name rtol = 1e-07 - assert_allclose(defaults['lr'], - cfg.lr, rtol=rtol, err_msg="lr mismatch") + assert_allclose(defaults["lr"], cfg.lr, rtol=rtol, err_msg="lr mismatch") # 1:1 mapping between defaults and params's hyper parameters for param in params: for k, _ in param.items(): - if k != 'params': + if k != "params": assert k in cfg.defaults, "hyper parameter {k} not present in one of the parameter params" for k, _ in cfg.defaults.items(): for param in cfg.params: assert k in param, "hyper parameter {k} not present in one of the parameter params" -@pytest.mark.parametrize("optim_name,defaults,params", [ - ('AdamOptimizer', {'lr': -1}, []), # invalid lr - ('FooOptimizer', {'lr': 0.001}, []), # invalid name - ('SGDOptimizer', [], []), # invalid type(defaults) - (optim.AdamConfig, {'lr': 0.003}, []), # invalid type(name) - ('AdamOptimizer', {'lr': None}, []), # missing 'lr' hyper parameter - ('SGDOptimizer', {'lr': 0.004}, {}), # invalid type(params) - # invalid type(params[i]) - ('AdamOptimizer', {'lr': 0.005, 'alpha': 2}, [[]]), - # missing 'params' at 'params' - ('AdamOptimizer', {'lr': 0.005, 'alpha': 2}, [{'alpha': 1}]), - # missing 'alpha' at 'defaults' - ('AdamOptimizer', {'lr': 0.005}, [{'params': 'param1', 'alpha': 1}]), -]) +@pytest.mark.parametrize( + "optim_name,defaults,params", + [ + ("AdamOptimizer", {"lr": -1}, []), # invalid lr + ("FooOptimizer", {"lr": 0.001}, []), # invalid name + ("SGDOptimizer", [], []), # invalid type(defaults) + (optim.AdamConfig, {"lr": 0.003}, []), # invalid type(name) + ("AdamOptimizer", {"lr": None}, []), # missing 'lr' hyper parameter + ("SGDOptimizer", {"lr": 0.004}, {}), # invalid type(params) + # invalid type(params[i]) + ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [[]]), + # missing 'params' at 'params' + ("AdamOptimizer", {"lr": 0.005, "alpha": 2}, [{"alpha": 1}]), + # missing 'alpha' at 'defaults' + ("AdamOptimizer", {"lr": 0.005}, [{"params": "param1", "alpha": 1}]), + ], +) def testOptimizerConfigInvalidInputs(optim_name, defaults, params): - '''Test invalid initialization of _OptimizerConfig''' + """Test invalid initialization of _OptimizerConfig""" with pytest.raises(AssertionError): - optim.config._OptimizerConfig( - name=optim_name, params=params, defaults=defaults) + optim.config._OptimizerConfig(name=optim_name, params=params, defaults=defaults) def testOptimizerConfigSGD(): - '''Test initialization of SGD''' + """Test initialization of SGD""" cfg = optim.SGDConfig() - assert cfg.name == 'SGDOptimizer' + assert cfg.name == "SGDOptimizer" rtol = 1e-07 assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") @@ -430,23 +451,22 @@ def testOptimizerConfigSGD(): # SGD does not support params with pytest.raises(AssertionError) as e: - params = [{'params': ['layer1.weight'], 'lr': 0.1}] + params = [{"params": ["layer1.weight"], "lr": 0.1}] optim.SGDConfig(params=params, lr=0.002) assert_allclose(0.002, cfg.lr, rtol=rtol, err_msg="lr mismatch") assert str(e.value) == "'params' must be an empty list for SGD optimizer" def testOptimizerConfigAdam(): - '''Test initialization of Adam''' + """Test initialization of Adam""" cfg = optim.AdamConfig() - assert cfg.name == 'AdamOptimizer' + assert cfg.name == "AdamOptimizer" rtol = 1e-7 assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, - err_msg="lambda_coef mismatch") + assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") assert_allclose(1e-8, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") assert cfg.do_bias_correction == True, "lambda_coef mismatch" @@ -454,54 +474,46 @@ def testOptimizerConfigAdam(): def testOptimizerConfigLamb(): - '''Test initialization of Lamb''' + """Test initialization of Lamb""" cfg = optim.LambConfig() - assert cfg.name == 'LambOptimizer' + assert cfg.name == "LambOptimizer" rtol = 1e-7 assert_allclose(0.001, cfg.lr, rtol=rtol, err_msg="lr mismatch") assert_allclose(0.9, cfg.alpha, rtol=rtol, err_msg="alpha mismatch") assert_allclose(0.999, cfg.beta, rtol=rtol, err_msg="beta mismatch") - assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, - err_msg="lambda_coef mismatch") - assert cfg.ratio_min == float('-inf'), "ratio_min mismatch" - assert cfg.ratio_max == float('inf'), "ratio_max mismatch" + assert_allclose(0.0, cfg.lambda_coef, rtol=rtol, err_msg="lambda_coef mismatch") + assert cfg.ratio_min == float("-inf"), "ratio_min mismatch" + assert cfg.ratio_max == float("inf"), "ratio_max mismatch" assert_allclose(1e-6, cfg.epsilon, rtol=rtol, err_msg="epsilon mismatch") assert_allclose(1.0, cfg.max_norm_clip, rtol=rtol, err_msg="max_norm_clip mismatch") assert cfg.do_bias_correction == False, "do_bias_correction mismatch" -@pytest.mark.parametrize("optim_name", [ - ('Adam'), - ('Lamb') -]) +@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) def testOptimizerConfigParams(optim_name): rtol = 1e-7 - params = [{'params': ['layer1.weight'], 'alpha': 0.1}] - if optim_name == 'Adam': + params = [{"params": ["layer1.weight"], "alpha": 0.1}] + if optim_name == "Adam": cfg = optim.AdamConfig(params=params, alpha=0.2) - elif optim_name == 'Lamb': + elif optim_name == "Lamb": cfg = optim.LambConfig(params=params, alpha=0.2) else: - raise ValueError('invalid input') + raise ValueError("invalid input") assert len(cfg.params) == 1, "params should have length 1" - assert_allclose(cfg.params[0]['alpha'], 0.1, - rtol=rtol, err_msg="invalid lr on params[0]") + assert_allclose(cfg.params[0]["alpha"], 0.1, rtol=rtol, err_msg="invalid lr on params[0]") -@pytest.mark.parametrize("optim_name", [ - ('Adam'), - ('Lamb') -]) +@pytest.mark.parametrize("optim_name", [("Adam"), ("Lamb")]) def testOptimizerConfigInvalidParams(optim_name): # lr is not supported within params with pytest.raises(AssertionError) as e: - params = [{'params': ['layer1.weight'], 'lr': 0.1}] - if optim_name == 'Adam': + params = [{"params": ["layer1.weight"], "lr": 0.1}] + if optim_name == "Adam": optim.AdamConfig(params=params, lr=0.2) - elif optim_name == 'Lamb': + elif optim_name == "Lamb": optim.LambConfig(params=params, lr=0.2) else: - raise ValueError('invalid input') + raise ValueError("invalid input") assert str(e.value) == "'lr' is not supported inside params" @@ -509,26 +521,50 @@ def testLinearLRSchedulerCreation(): total_steps = 10 warmup = 0.05 - lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps, - warmup) + lr_scheduler = optim.lr_scheduler.LinearWarmupLRScheduler(total_steps, warmup) # Initial state assert lr_scheduler.total_steps == total_steps assert lr_scheduler.warmup == warmup -@pytest.mark.parametrize("lr_scheduler,expected_values", [ - (optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]), - (optim.lr_scheduler.CosineWarmupLRScheduler, - [0.0, 0.9763960957919413, 0.9059835861602854, 0.7956724530494887, 0.6563036824392345,\ - 0.5015739416158049, 0.34668951940611276, 0.2068719061737831, 0.09586187986225325, 0.0245691111902418]), - (optim.lr_scheduler.LinearWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]), - (optim.lr_scheduler.PolyWarmupLRScheduler, - [0.0, 0.9509018036072144, 0.9008016032064128, 0.8507014028056112, 0.8006012024048097,\ - 0.750501002004008, 0.7004008016032064, 0.6503006012024048, 0.6002004008016032, 0.5501002004008015]) -]) +@pytest.mark.parametrize( + "lr_scheduler,expected_values", + [ + (optim.lr_scheduler.ConstantWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0]), + ( + optim.lr_scheduler.CosineWarmupLRScheduler, + [ + 0.0, + 0.9763960957919413, + 0.9059835861602854, + 0.7956724530494887, + 0.6563036824392345, + 0.5015739416158049, + 0.34668951940611276, + 0.2068719061737831, + 0.09586187986225325, + 0.0245691111902418, + ], + ), + (optim.lr_scheduler.LinearWarmupLRScheduler, [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2]), + ( + optim.lr_scheduler.PolyWarmupLRScheduler, + [ + 0.0, + 0.9509018036072144, + 0.9008016032064128, + 0.8507014028056112, + 0.8006012024048097, + 0.750501002004008, + 0.7004008016032064, + 0.6503006012024048, + 0.6002004008016032, + 0.5501002004008015, + ], + ), + ], +) def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): # Test tolerance rtol = 1e-03 @@ -548,60 +584,91 @@ def testLRSchedulerUpdateImpl(lr_scheduler, expected_values): lr_scheduler._step(train_step_info) lr_list = lr_scheduler.get_last_lr() assert len(lr_list) == 1 - assert_allclose(lr_list[0], - expected_values[optimization_step], rtol=rtol, err_msg="lr mismatch") + assert_allclose(lr_list[0], expected_values[optimization_step], rtol=rtol, err_msg="lr mismatch") + def testInstantiateORTTrainerOptions(): session_options = SessionOptions() session_options.enable_mem_pattern = False - provider_options = {'EP1': {'key':'val'}} - opts = {'session_options' : session_options, - 'provider_options' : provider_options} + provider_options = {"EP1": {"key": "val"}} + opts = {"session_options": session_options, "provider_options": provider_options} opts = orttrainer.ORTTrainerOptions(opts) - assert(opts.session_options.enable_mem_pattern is False) - assert(opts._validated_opts['provider_options']['EP1']['key'] == 'val') - -@pytest.mark.parametrize("step_fn, lr_scheduler, expected_lr_values, device", [ - ('train_step', None, None, 'cuda'), - ('eval_step', None, None, 'cpu'), - ('train_step', optim.lr_scheduler.ConstantWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0], 'cpu'), - ('train_step', optim.lr_scheduler.CosineWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.9045084971874737, 0.6545084971874737, 0.34549150281252633, 0.09549150281252633], - 'cuda'), - ('train_step', optim.lr_scheduler.LinearWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2], 'cpu'), - ('train_step', optim.lr_scheduler.PolyWarmupLRScheduler, - [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.80000002, 0.60000004, 0.40000006000000005, 0.20000007999999997], 'cuda') -]) + assert opts.session_options.enable_mem_pattern is False + assert opts._validated_opts["provider_options"]["EP1"]["key"] == "val" + + +@pytest.mark.parametrize( + "step_fn, lr_scheduler, expected_lr_values, device", + [ + ("train_step", None, None, "cuda"), + ("eval_step", None, None, "cpu"), + ( + "train_step", + optim.lr_scheduler.ConstantWarmupLRScheduler, + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.0, 1.0, 1.0, 1.0], + "cpu", + ), + ( + "train_step", + optim.lr_scheduler.CosineWarmupLRScheduler, + [ + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + 0.9045084971874737, + 0.6545084971874737, + 0.34549150281252633, + 0.09549150281252633, + ], + "cuda", + ), + ( + "train_step", + optim.lr_scheduler.LinearWarmupLRScheduler, + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.8, 0.6, 0.4, 0.2], + "cpu", + ), + ( + "train_step", + optim.lr_scheduler.PolyWarmupLRScheduler, + [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 0.80000002, 0.60000004, 0.40000006000000005, 0.20000007999999997], + "cuda", + ), + ], +) def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device): total_steps = 1 - initial_lr = 1. + initial_lr = 1.0 rtol = 1e-3 # PyTorch Transformer model as example - opts = {'device' : {'id' : device}} + opts = {"device": {"id": device}} if lr_scheduler: total_steps = 10 - opts.update({'lr_scheduler' : lr_scheduler(total_steps=total_steps, warmup=0.5)}) + opts.update({"lr_scheduler": lr_scheduler(total_steps=total_steps, warmup=0.5)}) opts = orttrainer.ORTTrainerOptions(opts) optim_config = optim.LambConfig(lr=initial_lr) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model(device) + model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( + device + ) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) # Run a train or evaluation step - if step_fn == 'eval_step': + if step_fn == "eval_step": data, targets = batcher_fn(val_data, 0) - elif step_fn == 'train_step': + elif step_fn == "train_step": data, targets = batcher_fn(train_data, 0) else: - raise ValueError('Invalid step_fn') + raise ValueError("Invalid step_fn") # Export model to ONNX - if step_fn == 'eval_step': + if step_fn == "eval_step": step_fn = trainer.eval_step output = trainer.eval_step(data, targets) - elif step_fn == 'train_step': + elif step_fn == "train_step": step_fn = trainer.train_step for i in range(total_steps): output = trainer.train_step(data, targets) @@ -609,7 +676,7 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device) lr_list = trainer.options.lr_scheduler.get_last_lr() assert_allclose(lr_list[0], expected_lr_values[i], rtol=rtol, err_msg="lr mismatch") else: - raise ValueError('Invalid step_fn') + raise ValueError("Invalid step_fn") assert trainer._onnx_model is not None # Check output shape after train/eval step @@ -629,7 +696,8 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device) for dim_idx, dim in enumerate(trainer._onnx_model.graph.input[i].type.tensor_type.shape.dim): assert input_dim[dim_idx] == dim.dim_value assert input_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.input[i].type.tensor_type.elem_type) + trainer._onnx_model.graph.input[i].type.tensor_type.elem_type + ) opset = get_model_opset(trainer._onnx_model) @@ -644,10 +712,11 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device) if opset != 14: assert output_dim[dim_idx] == dim.dim_value assert output_type == _utils.dtype_onnx_to_torch( - trainer._onnx_model.graph.output[i].type.tensor_type.elem_type) + trainer._onnx_model.graph.output[i].type.tensor_type.elem_type + ) # Save current model as ONNX as a file - file_name = os.path.join('_____temp_onnx_model.onnx') + file_name = os.path.join("_____temp_onnx_model.onnx") trainer.save_as_onnx(file_name) assert os.path.exists(file_name) with open(file_name, "rb") as f: @@ -659,28 +728,21 @@ def testInstantiateORTTrainer(step_fn, lr_scheduler, expected_lr_values, device) trainer_from_onnx = orttrainer.ORTTrainer(reload_onnx_model, model_desc, optim_config) step_fn(data, targets) assert trainer_from_onnx._onnx_model is not None - assert (id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model)) - assert (trainer_from_onnx._onnx_model == trainer._onnx_model) - assert (trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph) - assert (onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph(trainer._onnx_model.graph)) + assert id(trainer_from_onnx._onnx_model) != id(trainer._onnx_model) + assert trainer_from_onnx._onnx_model == trainer._onnx_model + assert trainer_from_onnx._onnx_model.graph == trainer._onnx_model.graph + assert onnx.helper.printable_graph(trainer_from_onnx._onnx_model.graph) == onnx.helper.printable_graph( + trainer._onnx_model.graph + ) -@pytest.mark.parametrize("seed, device", [ - (0, 'cpu'), - (24, 'cuda') -]) +@pytest.mark.parametrize("seed, device", [(0, "cpu"), (24, "cuda")]) def testORTDeterministicCompute(seed, device): # Common setup optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'debug' : { - 'deterministic_compute': True - }, - 'device' : { - 'id' : device, - 'mem_limit' : 10*1024*1024 - } - }) + opts = orttrainer.ORTTrainerOptions( + {"debug": {"deterministic_compute": True}, "device": {"id": device, "mem_limit": 10 * 1024 * 1024}} + ) # Setup for the first ORTTRainer run torch.manual_seed(seed) @@ -704,12 +766,15 @@ def testORTDeterministicCompute(seed, device): _test_helpers.assert_onnx_weights(first_trainer, second_trainer) -@pytest.mark.parametrize("seed,device,expected_loss,fetches", [ - (321, 'cuda', [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False), - (321, 'cuda', [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True), -]) +@pytest.mark.parametrize( + "seed,device,expected_loss,fetches", + [ + (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], False), + (321, "cuda", [10.5774, 10.4403, 10.4175, 10.2886, 10.2760], True), + ], +) def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers + return # TODO: re-enable after nondeterminism on backend is fixed. update numbers rtol = 1e-3 total_steps = len(expected_loss) @@ -718,12 +783,16 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches) # Setup ORTTrainer loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'mixed_precision' : { - 'enabled' : True, - 'loss_scaler' : loss_scaler}, - 'debug' : {'deterministic_compute' : True}}) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model(device) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, + "debug": {"deterministic_compute": True}, + } + ) + model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( + device + ) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -732,7 +801,7 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches) for i in range(total_steps): data, targets = batcher_fn(train_data, i) if fetches: - trainer._train_step_info.fetches=['loss'] + trainer._train_step_info.fetches = ["loss"] loss = trainer.train_step(data, targets) else: loss, _ = trainer.train_step(data, targets) @@ -741,9 +810,9 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches) # Eval once just to test fetches in action val_data, val_targets = batcher_fn(val_data, 0) if fetches: - trainer._train_step_info.fetches=['loss'] + trainer._train_step_info.fetches = ["loss"] loss = trainer.eval_step(val_data, val_targets) - trainer._train_step_info.fetches=[] + trainer._train_step_info.fetches = [] loss, _ = trainer.eval_step(val_data, val_targets) # Compare loss to ground truth computed from current ORTTrainer API @@ -753,45 +822,57 @@ def testORTTrainerMixedPrecisionLossScaler(seed, device, expected_loss, fetches) def _recompute_data(): device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine - expected_loss = {12: [10.5598, 10.4591, 10.3477, 10.2726, 10.1945], - 14: [10.54088, 10.498755, 10.386827, 10.338747, 10.262459]} + if device_capability_major == 7: # V100 for Dev machine + expected_loss = { + 12: [10.5598, 10.4591, 10.3477, 10.2726, 10.1945], + 14: [10.54088, 10.498755, 10.386827, 10.338747, 10.262459], + } return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer + (False, False, False, 0, expected_loss), # no recompute + (True, False, False, 0, expected_loss), # attn_dropout recompute + (False, True, False, 0, expected_loss), # gelu recompute + (False, False, True, 0, expected_loss), # transformer_layer recompute + (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer ] elif device_capability_major == 5: # M60 for CI machines - expected_loss = {12: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], - 14: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113]} + expected_loss = { + 12: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], + 14: [10.5445, 10.4389, 10.3480, 10.2627, 10.2113], + } return [ - (False, False, False, 0, expected_loss), # no recompute - (True, False, False, 0, expected_loss), # attn_dropout recompute - (False, True, False, 0, expected_loss), # gelu recompute - (False, False, True, 0, expected_loss), # transformer_layer recompute - (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer + (False, False, False, 0, expected_loss), # no recompute + (True, False, False, 0, expected_loss), # attn_dropout recompute + (False, True, False, 0, expected_loss), # gelu recompute + (False, False, True, 0, expected_loss), # transformer_layer recompute + (False, False, True, 1, expected_loss), # transformer_layer recompute with 1 layer ] + + @pytest.mark.parametrize("attn_dropout, gelu, transformer_layer, number_layers, expected_loss", _recompute_data()) def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers, expected_loss): seed = 321 - device = 'cuda' + device = "cuda" rtol = 1e-3 total_steps = len(expected_loss[12]) torch.manual_seed(seed) set_seed(seed) # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'graph_transformer' : { - 'attn_dropout_recompute': attn_dropout, - 'gelu_recompute': gelu, - 'transformer_layer_recompute': transformer_layer, - 'number_recompute_layers': number_layers - }, - 'debug' : {'deterministic_compute' : True}}) - model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model(device) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "graph_transformer": { + "attn_dropout_recompute": attn_dropout, + "gelu_recompute": gelu, + "transformer_layer_recompute": transformer_layer, + "number_recompute_layers": number_layers, + }, + "debug": {"deterministic_compute": True}, + } + ) + model, model_desc, my_loss, batcher_fn, train_data, val_data, _ = _test_commons._load_pytorch_transformer_model( + device + ) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -808,26 +889,105 @@ def testORTTrainerRecompute(attn_dropout, gelu, transformer_layer, number_layers _test_helpers.assert_model_outputs(expected_loss[opset], actual_loss, True, rtol=rtol) -@pytest.mark.parametrize("seed,device,gradient_accumulation_steps,total_steps,expected_loss", [ - (0, 'cuda', 1, 12, [10.5368022919, 10.4146203995, 10.3635568619, 10.2650547028, 10.2284049988, 10.1304626465,\ - 10.0853414536, 9.9987659454, 9.9472427368, 9.8832416534, 9.8223171234, 9.8222122192]), - (42, 'cuda', 3, 12, [10.6455879211, 10.6247081757, 10.6361322403, 10.5187482834, 10.5345087051, 10.5487670898,\ - 10.4833698273, 10.4600019455, 10.4535751343, 10.3774127960, 10.4144191742, 10.3757553101]), - (123, 'cuda', 7, 12, [10.5353469849, 10.5261383057, 10.5240392685, 10.5013713837, 10.5678377151, 10.5452117920,\ - 10.5184345245, 10.4271221161, 10.4458627701, 10.4864749908, 10.4416503906, 10.4467563629]), - (321, 'cuda', 12, 12, [10.5773944855, 10.5428829193, 10.5974750519, 10.5416746140, 10.6009902954, 10.5684127808,\ - 10.5759754181, 10.5636739731, 10.5613927841, 10.5825119019, 10.6031589508, 10.6199369431]), -]) +@pytest.mark.parametrize( + "seed,device,gradient_accumulation_steps,total_steps,expected_loss", + [ + ( + 0, + "cuda", + 1, + 12, + [ + 10.5368022919, + 10.4146203995, + 10.3635568619, + 10.2650547028, + 10.2284049988, + 10.1304626465, + 10.0853414536, + 9.9987659454, + 9.9472427368, + 9.8832416534, + 9.8223171234, + 9.8222122192, + ], + ), + ( + 42, + "cuda", + 3, + 12, + [ + 10.6455879211, + 10.6247081757, + 10.6361322403, + 10.5187482834, + 10.5345087051, + 10.5487670898, + 10.4833698273, + 10.4600019455, + 10.4535751343, + 10.3774127960, + 10.4144191742, + 10.3757553101, + ], + ), + ( + 123, + "cuda", + 7, + 12, + [ + 10.5353469849, + 10.5261383057, + 10.5240392685, + 10.5013713837, + 10.5678377151, + 10.5452117920, + 10.5184345245, + 10.4271221161, + 10.4458627701, + 10.4864749908, + 10.4416503906, + 10.4467563629, + ], + ), + ( + 321, + "cuda", + 12, + 12, + [ + 10.5773944855, + 10.5428829193, + 10.5974750519, + 10.5416746140, + 10.6009902954, + 10.5684127808, + 10.5759754181, + 10.5636739731, + 10.5613927841, + 10.5825119019, + 10.6031589508, + 10.6199369431, + ], + ), + ], +) def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps, expected_loss): - return # TODO: re-enable after nondeterminism on backend is fixed. update numbers + return # TODO: re-enable after nondeterminism on backend is fixed. update numbers rtol = 1e-3 torch.manual_seed(seed) set_seed(seed) # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps}, - 'debug' : {'deterministic_compute' : True}}) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + "debug": {"deterministic_compute": True}, + } + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -843,18 +1003,22 @@ def testORTTrainerGradientAccumulation(seed, device, gradient_accumulation_steps _test_helpers.assert_model_outputs(expected_loss, actual_loss, rtol=rtol) -@pytest.mark.parametrize("dynamic_axes", [ - (True), - (False), -]) +@pytest.mark.parametrize( + "dynamic_axes", + [ + (True), + (False), + ], +) def testORTTrainerDynamicShape(dynamic_axes): # Common setup - device = 'cuda' + device = "cuda" # Setup ORTTrainer options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn,\ - train_data, _, _ = _test_commons._load_pytorch_transformer_model(device, dynamic_axes=dynamic_axes) + model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model( + device, dynamic_axes=dynamic_axes + ) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -864,25 +1028,27 @@ def testORTTrainerDynamicShape(dynamic_axes): data, targets = batcher_fn(train_data, i) if dynamic_axes: # Forcing batches with different sizes to exercise dynamic shapes - data = data[:-(i+1)] - targets = targets[:-(i+1)*data.size(1)] + data = data[: -(i + 1)] + targets = targets[: -(i + 1) * data.size(1)] _, _ = trainer.train_step(data, targets) assert trainer._onnx_model is not None -@pytest.mark.parametrize('enable_onnx_contrib_ops', [ - (True), - (False), -]) +@pytest.mark.parametrize( + "enable_onnx_contrib_ops", + [ + (True), + (False), + ], +) def testORTTrainerInternalUseContribOps(enable_onnx_contrib_ops): # Common setup - device = 'cuda' + device = "cuda" # Setup ORTTrainer options = orttrainer.ORTTrainerOptions({"_internal_use": {"enable_onnx_contrib_ops": enable_onnx_contrib_ops}}) - model, model_desc, my_loss, batcher_fn,\ - train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) + model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -895,22 +1061,28 @@ def testORTTrainerInternalUseContribOps(enable_onnx_contrib_ops): _, _ = trainer.train_step(data, targets) -@pytest.mark.parametrize("model_params", [ - (['decoder.weight', - 'transformer_encoder.layers.0.linear1.bias', - 'transformer_encoder.layers.0.linear2.weight', - 'transformer_encoder.layers.1.self_attn.out_proj.weight', - 'transformer_encoder.layers.1.self_attn.out_proj.bias']), -]) +@pytest.mark.parametrize( + "model_params", + [ + ( + [ + "decoder.weight", + "transformer_encoder.layers.0.linear1.bias", + "transformer_encoder.layers.0.linear2.weight", + "transformer_encoder.layers.1.self_attn.out_proj.weight", + "transformer_encoder.layers.1.self_attn.out_proj.bias", + ] + ), + ], +) def testORTTrainerFrozenWeights(model_params): # Common setup - device = 'cuda' + device = "cuda" total_steps = 10 # Setup ORTTrainer WITHOUT frozen weights options = orttrainer.ORTTrainerOptions({}) - model, model_desc, my_loss, batcher_fn,\ - train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) + model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) for i in range(total_steps): @@ -922,9 +1094,8 @@ def testORTTrainerFrozenWeights(model_params): session_state = trainer._training_session.get_state() assert all([param in session_state for param in model_params]) - # Setup ORTTrainer WITH frozen weights - options = orttrainer.ORTTrainerOptions({'utils' : {'frozen_weights' : model_params}}) + options = orttrainer.ORTTrainerOptions({"utils": {"frozen_weights": model_params}}) model, _, _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) for i in range(total_steps): @@ -937,30 +1108,44 @@ def testORTTrainerFrozenWeights(model_params): assert not all([param in session_state for param in model_params]) -@pytest.mark.parametrize("loss_scaler, optimizer_config, gradient_accumulation_steps", [ - (None, optim.AdamConfig(), 1), - (None, optim.LambConfig(), 1), - (None, optim.SGDConfig(), 1), - (amp.DynamicLossScaler(), optim.AdamConfig(), 1), - (amp.DynamicLossScaler(), optim.LambConfig(), 5), - #(amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 -]) +@pytest.mark.parametrize( + "loss_scaler, optimizer_config, gradient_accumulation_steps", + [ + (None, optim.AdamConfig(), 1), + (None, optim.LambConfig(), 1), + (None, optim.SGDConfig(), 1), + (amp.DynamicLossScaler(), optim.AdamConfig(), 1), + (amp.DynamicLossScaler(), optim.LambConfig(), 5), + # (amp.DynamicLossScaler(), optim.SGDConfig(), 1), # SGD doesnt support fp16 + ], +) def testORTTrainerStateDictWrapModelLossFn(loss_scaler, optimizer_config, gradient_accumulation_steps): # Common setup seed = 1 + class LinearModel(torch.nn.Module): def __init__(self): super().__init__() self.linear = torch.nn.Linear(2, 4) + def forward(self, y=None, x=None): if y is not None: return self.linear(x) + y else: return self.linear(x) + torch.ones(2, 4) - model_desc = {'inputs' : [('x', [2, 2]), - ('label', [2, ])], - 'outputs' : [('loss', [], True), - ('output', [2, 4])]} + + model_desc = { + "inputs": [ + ("x", [2, 2]), + ( + "label", + [ + 2, + ], + ), + ], + "outputs": [("loss", [], True), ("output", [2, 4])], + } # Dummy data data1 = torch.randn(2, 2) @@ -969,18 +1154,22 @@ def forward(self, y=None, x=None): label2 = torch.tensor([0, 1], dtype=torch.int64) # Setup training based on test parameters - opts = {'debug' : {'deterministic_compute': True}, - 'batch' : { 'gradient_accumulation_steps' : gradient_accumulation_steps}} + opts = { + "debug": {"deterministic_compute": True}, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + } if loss_scaler: - opts['mixed_precision'] = { 'enabled': True, 'loss_scaler': loss_scaler} - opts = orttrainer.ORTTrainerOptions(opts) + opts["mixed_precision"] = {"enabled": True, "loss_scaler": loss_scaler} + opts = orttrainer.ORTTrainerOptions(opts) # Training session 1 torch.manual_seed(seed) set_seed(seed) pt_model = LinearModel() + def loss_fn(x, label): return F.nll_loss(F.log_softmax(x, dim=1), label) + trainer = orttrainer.ORTTrainer(pt_model, model_desc, optimizer_config, loss_fn=loss_fn, options=opts) # Check state_dict keys before train. Must be empty @@ -990,7 +1179,7 @@ def loss_fn(x, label): # Train once and check initial state trainer.train_step(x=data1, label=label1) state_dict = trainer.state_dict() - assert all([weight in state_dict['model']['full_precision'].keys() for weight in ['linear.bias', 'linear.weight']]) + assert all([weight in state_dict["model"]["full_precision"].keys() for weight in ["linear.bias", "linear.weight"]]) # Initialize training session 2 from state of Training 1 torch.manual_seed(seed) @@ -1012,7 +1201,9 @@ def loss_fn(x, label): def testORTTrainerNonPickableModel(): # Common setup import threading + seed = 1 + class UnpickableModel(torch.nn.Module): def __init__(self): super().__init__() @@ -1026,38 +1217,47 @@ def forward(self, y=None, x=None): else: return self.linear(x) + torch.ones(2, 4) - model_desc = {'inputs' : [('x', [2, 2]), - ('label', [2, ])], - 'outputs' : [('loss', [], True), - ('output', [2, 4])]} + model_desc = { + "inputs": [ + ("x", [2, 2]), + ( + "label", + [ + 2, + ], + ), + ], + "outputs": [("loss", [], True), ("output", [2, 4])], + } # Dummy data data = torch.randn(2, 2) label = torch.tensor([0, 1], dtype=torch.int64) # Setup training based on test parameters - opts = orttrainer.ORTTrainerOptions({'debug' : {'deterministic_compute': True}}) + opts = orttrainer.ORTTrainerOptions({"debug": {"deterministic_compute": True}}) # Training session torch.manual_seed(seed) set_seed(seed) pt_model = UnpickableModel() + def loss_fn(x, label): return F.nll_loss(F.log_softmax(x, dim=1), label) + optim_config = optim.AdamConfig() trainer = orttrainer.ORTTrainer(pt_model, model_desc, optim_config, loss_fn=loss_fn, options=opts) # Train must succeed despite warning _, _ = trainer.train_step(data, label) + ############################################################################### # Temporary tests comparing Legacy vs Experimental ORTTrainer APIs ############ ############################################################################### -@pytest.mark.parametrize("seed,device", [ - (1234, 'cuda') -]) +@pytest.mark.parametrize("seed,device", [(1234, "cuda")]) def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): # Common data rtol = 1e-7 @@ -1067,14 +1267,12 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): torch.manual_seed(seed) set_seed(seed) optim_config = optim.LambConfig() - opts = orttrainer.ORTTrainerOptions({ - 'device' : { - 'id' : device - }, - 'debug' : { - 'deterministic_compute': True - }, - }) + opts = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "debug": {"deterministic_compute": True}, + } + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) # Training loop @@ -1086,8 +1284,9 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): torch.manual_seed(seed) set_seed(seed) model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", None, lr_desc, - device, _use_deterministic_compute=True) + legacy_trainer = Legacy_ORTTrainer( + model, my_loss, model_desc, "LambOptimizer", None, lr_desc, device, _use_deterministic_compute=True + ) # Training loop for i in range(total_steps): data, targets = batcher_fn(train_data, i) @@ -1097,9 +1296,12 @@ def testORTTrainerLegacyAndExperimentalWeightsCheck(seed, device): _test_helpers.assert_legacy_onnx_weights(trainer, legacy_trainer, rtol=rtol) -@pytest.mark.parametrize("seed,device", [ - (321, 'cuda'), -]) +@pytest.mark.parametrize( + "seed,device", + [ + (321, "cuda"), + ], +) def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): # Common data total_steps = 128 @@ -1108,11 +1310,15 @@ def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): torch.manual_seed(seed) set_seed(seed) loss_scaler = amp.DynamicLossScaler() - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'mixed_precision' : { - 'enabled' : True, - 'loss_scaler' : loss_scaler}, - 'debug' : {'deterministic_compute' : True,}}) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "mixed_precision": {"enabled": True, "loss_scaler": loss_scaler}, + "debug": { + "deterministic_compute": True, + }, + } + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -1129,12 +1335,19 @@ def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): torch.manual_seed(seed) set_seed(seed) model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - loss_scaler = Legacy_LossScaler('ort_test_input_loss_scalar', True) - legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", - None, lr_desc, device=device, - _use_deterministic_compute=True, - use_mixed_precision=True, - loss_scaler=loss_scaler) + loss_scaler = Legacy_LossScaler("ort_test_input_loss_scalar", True) + legacy_trainer = Legacy_ORTTrainer( + model, + my_loss, + model_desc, + "LambOptimizer", + None, + lr_desc, + device=device, + _use_deterministic_compute=True, + use_mixed_precision=True, + loss_scaler=loss_scaler, + ) # Training loop legacy_loss = [] legacy_preds_dtype = [] @@ -1150,12 +1363,15 @@ def testORTTrainerLegacyAndExperimentalPrecisionLossScaler(seed, device): _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) -@pytest.mark.parametrize("seed,device,gradient_accumulation_steps,total_steps", [ - (0, 'cuda', 1, 12), - (42, 'cuda', 3, 12), - (123, 'cuda', 7, 12), - (321, 'cuda', 12, 12), -]) +@pytest.mark.parametrize( + "seed,device,gradient_accumulation_steps,total_steps", + [ + (0, "cuda", 1, 12), + (42, "cuda", 3, 12), + (123, "cuda", 7, 12), + (321, "cuda", 12, 12), + ], +) def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradient_accumulation_steps, total_steps): # Common data torch.set_printoptions(precision=10) @@ -1163,9 +1379,13 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie # Setup experimental API torch.manual_seed(seed) set_seed(seed) - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps}, - 'debug' : {'deterministic_compute' : True}}) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + "debug": {"deterministic_compute": True}, + } + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -1180,10 +1400,17 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie torch.manual_seed(seed) set_seed(seed) model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", - None, lr_desc, device=device, - _use_deterministic_compute=True, - gradient_accumulation_steps=gradient_accumulation_steps) + legacy_trainer = Legacy_ORTTrainer( + model, + my_loss, + model_desc, + "LambOptimizer", + None, + lr_desc, + device=device, + _use_deterministic_compute=True, + gradient_accumulation_steps=gradient_accumulation_steps, + ) # Training loop legacy_loss = [] for i in range(total_steps): @@ -1195,34 +1422,112 @@ def testORTTrainerLegacyAndExperimentalGradientAccumulation(seed, device, gradie _test_helpers.assert_model_outputs(legacy_loss, experimental_loss) -@pytest.mark.parametrize("seed,device,optimizer_config,lr_scheduler, get_lr_this_step", [ - (0, 'cuda', optim.AdamConfig, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0, 'cuda', optim.LambConfig, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (0, 'cuda', optim.SGDConfig, optim.lr_scheduler.ConstantWarmupLRScheduler, _test_commons.legacy_constant_lr_scheduler), - (42, 'cuda', optim.AdamConfig, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (42, 'cuda', optim.LambConfig, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (42, 'cuda', optim.SGDConfig, optim.lr_scheduler.LinearWarmupLRScheduler, _test_commons.legacy_linear_lr_scheduler), - (123, 'cuda', optim.AdamConfig, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (123, 'cuda', optim.LambConfig, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (123, 'cuda', optim.SGDConfig, optim.lr_scheduler.CosineWarmupLRScheduler, _test_commons.legacy_cosine_lr_scheduler), - (321, 'cuda', optim.AdamConfig, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), - (321, 'cuda', optim.LambConfig, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), - (321, 'cuda', optim.SGDConfig, optim.lr_scheduler.PolyWarmupLRScheduler, _test_commons.legacy_poly_lr_scheduler), -]) +@pytest.mark.parametrize( + "seed,device,optimizer_config,lr_scheduler, get_lr_this_step", + [ + ( + 0, + "cuda", + optim.AdamConfig, + optim.lr_scheduler.ConstantWarmupLRScheduler, + _test_commons.legacy_constant_lr_scheduler, + ), + ( + 0, + "cuda", + optim.LambConfig, + optim.lr_scheduler.ConstantWarmupLRScheduler, + _test_commons.legacy_constant_lr_scheduler, + ), + ( + 0, + "cuda", + optim.SGDConfig, + optim.lr_scheduler.ConstantWarmupLRScheduler, + _test_commons.legacy_constant_lr_scheduler, + ), + ( + 42, + "cuda", + optim.AdamConfig, + optim.lr_scheduler.LinearWarmupLRScheduler, + _test_commons.legacy_linear_lr_scheduler, + ), + ( + 42, + "cuda", + optim.LambConfig, + optim.lr_scheduler.LinearWarmupLRScheduler, + _test_commons.legacy_linear_lr_scheduler, + ), + ( + 42, + "cuda", + optim.SGDConfig, + optim.lr_scheduler.LinearWarmupLRScheduler, + _test_commons.legacy_linear_lr_scheduler, + ), + ( + 123, + "cuda", + optim.AdamConfig, + optim.lr_scheduler.CosineWarmupLRScheduler, + _test_commons.legacy_cosine_lr_scheduler, + ), + ( + 123, + "cuda", + optim.LambConfig, + optim.lr_scheduler.CosineWarmupLRScheduler, + _test_commons.legacy_cosine_lr_scheduler, + ), + ( + 123, + "cuda", + optim.SGDConfig, + optim.lr_scheduler.CosineWarmupLRScheduler, + _test_commons.legacy_cosine_lr_scheduler, + ), + ( + 321, + "cuda", + optim.AdamConfig, + optim.lr_scheduler.PolyWarmupLRScheduler, + _test_commons.legacy_poly_lr_scheduler, + ), + ( + 321, + "cuda", + optim.LambConfig, + optim.lr_scheduler.PolyWarmupLRScheduler, + _test_commons.legacy_poly_lr_scheduler, + ), + ( + 321, + "cuda", + optim.SGDConfig, + optim.lr_scheduler.PolyWarmupLRScheduler, + _test_commons.legacy_poly_lr_scheduler, + ), + ], +) def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_config, lr_scheduler, get_lr_this_step): # Common data total_steps = 10 lr = 0.001 warmup = 0.5 cycles = 0.5 - power = 1. + power = 1.0 lr_end = 1e-7 torch.set_printoptions(precision=10) # Setup experimental API torch.manual_seed(seed) set_seed(seed) - if lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler: + if ( + lr_scheduler == optim.lr_scheduler.ConstantWarmupLRScheduler + or lr_scheduler == optim.lr_scheduler.LinearWarmupLRScheduler + ): lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup) elif lr_scheduler == optim.lr_scheduler.CosineWarmupLRScheduler: lr_scheduler = lr_scheduler(total_steps=total_steps, warmup=warmup, cycles=cycles) @@ -1231,9 +1536,9 @@ def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_confi else: raise RuntimeError("Invalid lr_scheduler") - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'debug' : {'deterministic_compute' : True}, - 'lr_scheduler' : lr_scheduler}) + options = orttrainer.ORTTrainerOptions( + {"device": {"id": device}, "debug": {"deterministic_compute": True}, "lr_scheduler": lr_scheduler} + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optimizer_config(lr=lr) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -1249,28 +1554,42 @@ def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_confi set_seed(seed) if optimizer_config == optim.AdamConfig: - legacy_optimizer_config = 'AdamOptimizer' + legacy_optimizer_config = "AdamOptimizer" elif optimizer_config == optim.LambConfig: - legacy_optimizer_config = 'LambOptimizer' + legacy_optimizer_config = "LambOptimizer" elif optimizer_config == optim.SGDConfig: - legacy_optimizer_config = 'SGDOptimizer' + legacy_optimizer_config = "SGDOptimizer" else: raise RuntimeError("Invalid optimizer_config") - if get_lr_this_step == _test_commons.legacy_constant_lr_scheduler or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler: + if ( + get_lr_this_step == _test_commons.legacy_constant_lr_scheduler + or get_lr_this_step == _test_commons.legacy_linear_lr_scheduler + ): get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup) elif get_lr_this_step == _test_commons.legacy_cosine_lr_scheduler: - get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles) + get_lr_this_step = partial( + get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, cycles=cycles + ) elif get_lr_this_step == _test_commons.legacy_poly_lr_scheduler: - get_lr_this_step = partial(get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end) + get_lr_this_step = partial( + get_lr_this_step, initial_lr=lr, total_steps=total_steps, warmup=warmup, power=power, lr_end=lr_end + ) else: raise RuntimeError("Invalid get_lr_this_step") model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, legacy_optimizer_config, - None, lr_desc, device=device, - _use_deterministic_compute=True, - get_lr_this_step=get_lr_this_step) + legacy_trainer = Legacy_ORTTrainer( + model, + my_loss, + model_desc, + legacy_optimizer_config, + None, + lr_desc, + device=device, + _use_deterministic_compute=True, + get_lr_this_step=get_lr_this_step, + ) # Training loop legacy_loss = [] for i in range(total_steps): @@ -1283,7 +1602,9 @@ def testORTTrainerLegacyAndExperimentalLRScheduler(seed, device, optimizer_confi def testLossScalerLegacyAndExperimentalFullCycle(): - info = orttrainer.TrainStepInfo(optimizer_config=optim.LambConfig(lr=0.001), all_finite=True, fetches=[], optimization_step=0, step=0) + info = orttrainer.TrainStepInfo( + optimizer_config=optim.LambConfig(lr=0.001), all_finite=True, fetches=[], optimization_step=0, step=0 + ) new_ls = amp.DynamicLossScaler() old_ls = Legacy_LossScaler("ort_test_input_loss_scaler", True) @@ -1358,6 +1679,7 @@ def testLossScalerLegacyAndExperimentalRandomAllFinite(): assert_allclose(new_ls.max_loss_scale, old_ls.max_loss_scale_) import random + out = [] for _ in range(1, 64): train_step_info.all_finite = bool(random.getrandbits(1)) @@ -1369,18 +1691,18 @@ def testLossScalerLegacyAndExperimentalRandomAllFinite(): out.append(new_loss_scale) assert new_loss_scale > 1e-7 + def testORTTrainerRunSymbolicShapeInfer(): # Common data seed = 0 total_steps = 12 - device = 'cuda' + device = "cuda" torch.set_printoptions(precision=10) # Setup without symbolic shape inference torch.manual_seed(seed) set_seed(seed) - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'debug' : {'deterministic_compute' : True}}) + options = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"deterministic_compute": True}}) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -1409,10 +1731,17 @@ def testORTTrainerRunSymbolicShapeInfer(): torch.manual_seed(seed) set_seed(seed) model, (model_desc, lr_desc), _, _, _, _, _ = _test_commons._load_pytorch_transformer_model(device, legacy_api=True) - legacy_trainer = Legacy_ORTTrainer(model, my_loss, model_desc, "LambOptimizer", - None, lr_desc, device=device, - run_symbolic_shape_infer=True, - _use_deterministic_compute=True) + legacy_trainer = Legacy_ORTTrainer( + model, + my_loss, + model_desc, + "LambOptimizer", + None, + lr_desc, + device=device, + run_symbolic_shape_infer=True, + _use_deterministic_compute=True, + ) # Training loop legacy_loss = [] for i in range(total_steps): @@ -1424,37 +1753,51 @@ def testORTTrainerRunSymbolicShapeInfer(): _test_helpers.assert_model_outputs(new_loss, expected_loss) _test_helpers.assert_model_outputs(legacy_loss, expected_loss) -@pytest.mark.parametrize("test_input", [ - ({ - 'distributed': {'enable_adasum': True}, - }) -]) + +@pytest.mark.parametrize( + "test_input", + [ + ( + { + "distributed": {"enable_adasum": True}, + } + ) + ], +) def testORTTrainerOptionsEnabledAdasumFlag(test_input): - ''' Test the enabled_adasum flag values when set enabled''' + """Test the enabled_adasum flag values when set enabled""" actual_values = orttrainer_options.ORTTrainerOptions(test_input) assert actual_values.distributed.enable_adasum == True -@pytest.mark.parametrize("test_input", [ - ({ - 'distributed': {'enable_adasum': False}, - }) -]) + +@pytest.mark.parametrize( + "test_input", + [ + ( + { + "distributed": {"enable_adasum": False}, + } + ) + ], +) def testORTTrainerOptionsDisabledAdasumFlag(test_input): - ''' Test the enabled_adasum flag values when set disabled''' + """Test the enabled_adasum flag values when set disabled""" actual_values = orttrainer_options.ORTTrainerOptions(test_input) assert actual_values.distributed.enable_adasum == False + def testORTTrainerUnusedInput(): class UnusedInputModel(torch.nn.Module): def __init__(self): super(UnusedInputModel, self).__init__() + def forward(self, x, y): return torch.mean(x) model = UnusedInputModel() - model_desc = {'inputs': [('x', [1]), ('y', [1])], 'outputs': [('loss', [], True)]} + model_desc = {"inputs": [("x", [1]), ("y", [1])], "outputs": [("loss", [], True)]} optim_config = optim.LambConfig(lr=0.001) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config) @@ -1464,46 +1807,44 @@ def forward(self, x, y): except RuntimeError: pytest.fail("RuntimeError doing train_step with an unused input.") -@pytest.mark.parametrize("debug_files", [ - {'model_after_graph_transforms_path': 'transformed.onnx', - 'model_with_gradient_graph_path': 'transformed_grad.onnx', - 'model_with_training_graph_path': 'training.onnx', - 'model_with_training_graph_after_optimization_path': 'training_optimized.onnx' - }, - {'model_after_graph_transforms_path': 'transformed.onnx', - 'model_with_training_graph_path': '' - }, - ]) + +@pytest.mark.parametrize( + "debug_files", + [ + { + "model_after_graph_transforms_path": "transformed.onnx", + "model_with_gradient_graph_path": "transformed_grad.onnx", + "model_with_training_graph_path": "training.onnx", + "model_with_training_graph_after_optimization_path": "training_optimized.onnx", + }, + {"model_after_graph_transforms_path": "transformed.onnx", "model_with_training_graph_path": ""}, + ], +) def testTrainingGraphExport(debug_files): - device = 'cuda' + device = "cuda" model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) with tempfile.TemporaryDirectory() as tempdir: debug_paths = {} - for k,v in debug_files.items(): + for k, v in debug_files.items(): debug_paths[k] = os.path.join(tempdir, v) - opts = orttrainer.ORTTrainerOptions( - { - "device": {"id": device}, - "debug": {"graph_save_paths": debug_paths} - } - ) + opts = orttrainer.ORTTrainerOptions({"device": {"id": device}, "debug": {"graph_save_paths": debug_paths}}) optim_config = optim.AdamConfig() trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) data, targets = batcher_fn(train_data, 0) trainer.train_step(data, targets) - for k,v in debug_files.items(): + for k, v in debug_files.items(): path = debug_paths[k] if len(v) > 0: assert os.path.isfile(path) saved_graph = onnx.load(path).graph - if k == 'model_with_training_graph_path': + if k == "model_with_training_graph_path": assert any("AdamOptimizer" in n.op_type for n in saved_graph.node) - elif k == 'model_with_gradient_graph_path': + elif k == "model_with_gradient_graph_path": assert any("Grad" in n.name for n in saved_graph.node) - elif k == 'model_after_graph_transforms_path': + elif k == "model_after_graph_transforms_path": assert any("LayerNormalization" in n.op_type for n in saved_graph.node) - elif k == 'model_with_training_graph_after_optimization_path': + elif k == "model_with_training_graph_after_optimization_path": assert any("FusedMatMul" in n.op_type for n in saved_graph.node) # remove saved file os.remove(path) @@ -1513,62 +1854,326 @@ def testTrainingGraphExport(debug_files): def _adam_max_norm_clip_data(): device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine + if device_capability_major == 7: # V100 for Dev machine return [ - (0, 'cuda', 1.0, 1, 12, { - 12: [10.592951, 10.067989, 9.619152, 9.245731, 8.881137, - 8.578644, 8.280573, 8.063023, 7.797933, 7.486215, 7.233806, 7.011791], - 14: [10.584141, 10.068119, 9.581743, 9.191472, 8.880169, 8.5352, - 8.311425, 8.061202, 7.773032, 7.523009, 7.258711, 7.02805]}), - (0, 'cuda', 0.1, 1, 12, { - 12: [10.592951, 10.068722, 9.620503, 9.247791, 8.883972, - 8.582286, 8.285027, 8.068308, 7.803638, 7.492318, 7.240352, 7.018665], - 14: [10.584141, 10.068845, 9.583107, 9.193537, 8.882966, 8.538839, - 8.315872, 8.066408, 7.778978, 7.529708, 7.265849, 7.035439]}), - (42, 'cuda', 1.0, 1, 12, { - 12: [10.647908, 10.144501, 9.672352, 9.306980, 8.956026, - 8.602655, 8.351079, 8.088144, 7.867220, 7.564082, 7.289846, 7.073726], - 14: [10.697515, 10.229034, 9.765422, 9.428294, 9.080612, 8.715208, - 8.459574, 8.169073, 7.940211, 7.654147, 7.390446, 7.166227]}), - (42, 'cuda', 0.1, 1, 12, { - 12: [10.647908, 10.145191, 9.673690, 9.309031, 8.959020, - 8.606632, 8.355836, 8.093478, 7.873327, 7.570731, 7.296772, 7.0809422], - 14: [10.697515, 10.22967, 9.766556, 9.430037, 9.083106, 8.718601, - 8.463726, 8.17396, 7.945755, 7.660188, 7.396963, 7.172944]}) + ( + 0, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.592951, + 10.067989, + 9.619152, + 9.245731, + 8.881137, + 8.578644, + 8.280573, + 8.063023, + 7.797933, + 7.486215, + 7.233806, + 7.011791, + ], + 14: [ + 10.584141, + 10.068119, + 9.581743, + 9.191472, + 8.880169, + 8.5352, + 8.311425, + 8.061202, + 7.773032, + 7.523009, + 7.258711, + 7.02805, + ], + }, + ), + ( + 0, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.592951, + 10.068722, + 9.620503, + 9.247791, + 8.883972, + 8.582286, + 8.285027, + 8.068308, + 7.803638, + 7.492318, + 7.240352, + 7.018665, + ], + 14: [ + 10.584141, + 10.068845, + 9.583107, + 9.193537, + 8.882966, + 8.538839, + 8.315872, + 8.066408, + 7.778978, + 7.529708, + 7.265849, + 7.035439, + ], + }, + ), + ( + 42, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.647908, + 10.144501, + 9.672352, + 9.306980, + 8.956026, + 8.602655, + 8.351079, + 8.088144, + 7.867220, + 7.564082, + 7.289846, + 7.073726, + ], + 14: [ + 10.697515, + 10.229034, + 9.765422, + 9.428294, + 9.080612, + 8.715208, + 8.459574, + 8.169073, + 7.940211, + 7.654147, + 7.390446, + 7.166227, + ], + }, + ), + ( + 42, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.647908, + 10.145191, + 9.673690, + 9.309031, + 8.959020, + 8.606632, + 8.355836, + 8.093478, + 7.873327, + 7.570731, + 7.296772, + 7.0809422, + ], + 14: [ + 10.697515, + 10.22967, + 9.766556, + 9.430037, + 9.083106, + 8.718601, + 8.463726, + 8.17396, + 7.945755, + 7.660188, + 7.396963, + 7.172944, + ], + }, + ), ] elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) return [ - (0, 'cuda', 1.0, 1, 12, { - 12: [10.618382, 10.08292 , 9.603334, 9.258133, 8.917768, 8.591574, - 8.318401, 8.042292, 7.783608, 7.50226 , 7.236041, 7.035602], - 14: [10.618382, 10.08292 , 9.603334, 9.258133, 8.917768, 8.591574, - 8.318401, 8.042292, 7.783608, 7.50226 , 7.236041, 7.035602]}), - (0, 'cuda', 0.1, 1, 12, { - 12: [10.618382, 10.083632, 9.604639, 9.260109, 8.920504, 8.595082, - 8.322799, 8.047493, 7.78929 , 7.508382, 7.242587, 7.042367], - 14: [10.618382, 10.083632, 9.604639, 9.260109, 8.920504, 8.595082, - 8.322799, 8.047493, 7.78929 , 7.508382, 7.242587, 7.042367]}), - (42, 'cuda', 1.0, 1, 12, { - 12: [10.68639 , 10.102986, 9.647681, 9.293091, 8.958928, 8.625297, - 8.351107, 8.079577, 7.840723, 7.543044, 7.284141, 7.072688], - 14: [10.68639 , 10.102986, 9.647681, 9.293091, 8.958928, 8.625297, - 8.351107, 8.079577, 7.840723, 7.543044, 7.284141, 7.072688]}), - (42, 'cuda', 0.1, 1, 12, { - 12: [10.68639 , 10.103672, 9.649025, 9.295167, 8.961777, 8.629059, - 8.355571, 8.084871, 7.846589, 7.549438, 7.290722, 7.079446], - 14: [10.697515, 10.22967, 9.766556, 9.430037, 9.083106, 8.718601, - 8.463726, 8.17396, 7.945755, 7.660188, 7.396963, 7.172944]}), + ( + 0, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.618382, + 10.08292, + 9.603334, + 9.258133, + 8.917768, + 8.591574, + 8.318401, + 8.042292, + 7.783608, + 7.50226, + 7.236041, + 7.035602, + ], + 14: [ + 10.618382, + 10.08292, + 9.603334, + 9.258133, + 8.917768, + 8.591574, + 8.318401, + 8.042292, + 7.783608, + 7.50226, + 7.236041, + 7.035602, + ], + }, + ), + ( + 0, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.618382, + 10.083632, + 9.604639, + 9.260109, + 8.920504, + 8.595082, + 8.322799, + 8.047493, + 7.78929, + 7.508382, + 7.242587, + 7.042367, + ], + 14: [ + 10.618382, + 10.083632, + 9.604639, + 9.260109, + 8.920504, + 8.595082, + 8.322799, + 8.047493, + 7.78929, + 7.508382, + 7.242587, + 7.042367, + ], + }, + ), + ( + 42, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.68639, + 10.102986, + 9.647681, + 9.293091, + 8.958928, + 8.625297, + 8.351107, + 8.079577, + 7.840723, + 7.543044, + 7.284141, + 7.072688, + ], + 14: [ + 10.68639, + 10.102986, + 9.647681, + 9.293091, + 8.958928, + 8.625297, + 8.351107, + 8.079577, + 7.840723, + 7.543044, + 7.284141, + 7.072688, + ], + }, + ), + ( + 42, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.68639, + 10.103672, + 9.649025, + 9.295167, + 8.961777, + 8.629059, + 8.355571, + 8.084871, + 7.846589, + 7.549438, + 7.290722, + 7.079446, + ], + 14: [ + 10.697515, + 10.22967, + 9.766556, + 9.430037, + 9.083106, + 8.718601, + 8.463726, + 8.17396, + 7.945755, + 7.660188, + 7.396963, + 7.172944, + ], + }, + ), ] -@pytest.mark.parametrize("seed,device,max_norm_clip,gradient_accumulation_steps,total_steps,expected_loss", _adam_max_norm_clip_data()) + + +@pytest.mark.parametrize( + "seed,device,max_norm_clip,gradient_accumulation_steps,total_steps,expected_loss", _adam_max_norm_clip_data() +) def testORTTrainerAdamMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): rtol = 1e-5 torch.manual_seed(seed) set_seed(seed) # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps}, - 'debug' : {'deterministic_compute' : True}}) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + "debug": {"deterministic_compute": True}, + } + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.AdamConfig(lr=0.001, max_norm_clip=max_norm_clip) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) @@ -1588,58 +2193,274 @@ def testORTTrainerAdamMaxNormClip(seed, device, max_norm_clip, gradient_accumula def _lamb_max_norm_clip_data(): device_capability_major = torch.cuda.get_device_capability()[0] - if device_capability_major == 7: # V100 for Dev machine + if device_capability_major == 7: # V100 for Dev machine return [ - (0, 'cuda', 1.0, 1, 12, { - 12: [10.592951, 10.487728, 10.422251, 10.350913, 10.244248, 10.213003, - 10.129222, 10.095112, 10.035983, 9.974586, 9.909771, 9.874278], - 14: [10.584141, 10.497192, 10.389251, 10.286045, 10.231354, 10.17018, - 10.066779, 10.048138, 9.958029, 9.8908, 9.82965, 9.755484]}), - (0, 'cuda', 0.1, 1, 12, { - 12: [10.592951, 10.452503, 10.349832, 10.245314, 10.106587, 10.046009, - 9.934781, 9.875164, 9.792067, 9.704592, 9.617104, 9.563070], - 14: [10.584141, 10.461154, 10.315399, 10.178979, 10.092329, 9.999928, - 9.869949, 9.824564, 9.707565, 9.61643, 9.532847, 9.439593]}), - (42, 'cuda', 1.0, 1, 12, { - 12: [10.647908, 10.566276, 10.476154, 10.406275, 10.311079, 10.240053, - 10.196469, 10.113955, 10.117376, 10.013077, 9.930301, 9.893368], - 14: [10.697515, 10.631279, 10.528757, 10.496689, 10.411219, 10.322109, - 10.297314, 10.215549, 10.149698, 10.087336, 10.010884, 9.934544]}), - (42, 'cuda', 0.1, 1, 12, { - 12: [10.647908, 10.531957, 10.405246, 10.302971, 10.176583, 10.075583, - 10.005772, 9.897825, 9.875748, 9.748932, 9.642885, 9.586762], - 14: [10.697515, 10.596729, 10.457815, 10.393475, 10.277581, 10.158909, - 10.108126, 10.000326, 9.912526, 9.826057, 9.727899, 9.633768]}) + ( + 0, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.592951, + 10.487728, + 10.422251, + 10.350913, + 10.244248, + 10.213003, + 10.129222, + 10.095112, + 10.035983, + 9.974586, + 9.909771, + 9.874278, + ], + 14: [ + 10.584141, + 10.497192, + 10.389251, + 10.286045, + 10.231354, + 10.17018, + 10.066779, + 10.048138, + 9.958029, + 9.8908, + 9.82965, + 9.755484, + ], + }, + ), + ( + 0, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.592951, + 10.452503, + 10.349832, + 10.245314, + 10.106587, + 10.046009, + 9.934781, + 9.875164, + 9.792067, + 9.704592, + 9.617104, + 9.563070, + ], + 14: [ + 10.584141, + 10.461154, + 10.315399, + 10.178979, + 10.092329, + 9.999928, + 9.869949, + 9.824564, + 9.707565, + 9.61643, + 9.532847, + 9.439593, + ], + }, + ), + ( + 42, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.647908, + 10.566276, + 10.476154, + 10.406275, + 10.311079, + 10.240053, + 10.196469, + 10.113955, + 10.117376, + 10.013077, + 9.930301, + 9.893368, + ], + 14: [ + 10.697515, + 10.631279, + 10.528757, + 10.496689, + 10.411219, + 10.322109, + 10.297314, + 10.215549, + 10.149698, + 10.087336, + 10.010884, + 9.934544, + ], + }, + ), + ( + 42, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.647908, + 10.531957, + 10.405246, + 10.302971, + 10.176583, + 10.075583, + 10.005772, + 9.897825, + 9.875748, + 9.748932, + 9.642885, + 9.586762, + ], + 14: [ + 10.697515, + 10.596729, + 10.457815, + 10.393475, + 10.277581, + 10.158909, + 10.108126, + 10.000326, + 9.912526, + 9.826057, + 9.727899, + 9.633768, + ], + }, + ), ] elif device_capability_major == 5: # M60 for CI machines (Python Packaging Pipeline) return [ - (0, 'cuda', 1.0, 1, 12, { - 12: [10.618382, 10.50222, 10.403347, 10.35298, 10.288447, 10.237399, - 10.184225, 10.089048, 10.008952, 9.972644, 9.897674, 9.84524], - 14: [0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4]}), - (0, 'cuda', 0.1, 1, 12, { - 12: [10.618382, 10.466732, 10.330871, 10.24715 , 10.150972, 10.069127, - 9.98974 , 9.870169, 9.763693, 9.704323, 9.605957, 9.533117], - 14: [1, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4]}), - (42, 'cuda', 1.0, 1, 12, { - 12: [10.68639 , 10.511692, 10.447308, 10.405255, 10.334866, 10.261473, - 10.169422, 10.107138, 10.069889, 9.97798, 9.928105, 9.896435], - 14: [2, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4]}), - (42, 'cuda', 0.1, 1, 12, { - 12: [10.68639 , 10.477489, 10.376671, 10.301725, 10.200718, 10.098477, - 9.97995 , 9.890104, 9.828899, 9.713555, 9.639567, 9.589856], - 14: [3, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4]}), + ( + 0, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.618382, + 10.50222, + 10.403347, + 10.35298, + 10.288447, + 10.237399, + 10.184225, + 10.089048, + 10.008952, + 9.972644, + 9.897674, + 9.84524, + ], + 14: [0, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], + }, + ), + ( + 0, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.618382, + 10.466732, + 10.330871, + 10.24715, + 10.150972, + 10.069127, + 9.98974, + 9.870169, + 9.763693, + 9.704323, + 9.605957, + 9.533117, + ], + 14: [1, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], + }, + ), + ( + 42, + "cuda", + 1.0, + 1, + 12, + { + 12: [ + 10.68639, + 10.511692, + 10.447308, + 10.405255, + 10.334866, + 10.261473, + 10.169422, + 10.107138, + 10.069889, + 9.97798, + 9.928105, + 9.896435, + ], + 14: [2, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], + }, + ), + ( + 42, + "cuda", + 0.1, + 1, + 12, + { + 12: [ + 10.68639, + 10.477489, + 10.376671, + 10.301725, + 10.200718, + 10.098477, + 9.97995, + 9.890104, + 9.828899, + 9.713555, + 9.639567, + 9.589856, + ], + 14: [3, 0, 0, 4, 0, 0, 0, 4, 0, 0, 0, 4], + }, + ), ] -@pytest.mark.parametrize("seed,device,max_norm_clip, gradient_accumulation_steps,total_steps,expected_loss", _lamb_max_norm_clip_data()) + + +@pytest.mark.parametrize( + "seed,device,max_norm_clip, gradient_accumulation_steps,total_steps,expected_loss", _lamb_max_norm_clip_data() +) def testORTTrainerLambMaxNormClip(seed, device, max_norm_clip, gradient_accumulation_steps, total_steps, expected_loss): rtol = 1e-3 torch.manual_seed(seed) set_seed(seed) # Setup ORTTrainer - options = orttrainer.ORTTrainerOptions({'device' : {'id' : device}, - 'batch' : {'gradient_accumulation_steps' : gradient_accumulation_steps}, - 'debug' : {'deterministic_compute' : True}}) + options = orttrainer.ORTTrainerOptions( + { + "device": {"id": device}, + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + "debug": {"deterministic_compute": True}, + } + ) model, model_desc, my_loss, batcher_fn, train_data, _, _ = _test_commons._load_pytorch_transformer_model(device) optim_config = optim.LambConfig(lr=0.001, max_norm_clip=max_norm_clip) trainer = orttrainer.ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=options) diff --git a/orttraining/orttraining/test/python/orttraining_test_ortvalue.py b/orttraining/orttraining/test/python/orttraining_test_ortvalue.py index cabeea1c076e1..164409939447f 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortvalue.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortvalue.py @@ -24,14 +24,14 @@ def C_ort_from_dlpack(dlpack): - if hasattr(C, 'ort_from_dlpack'): + if hasattr(C, "ort_from_dlpack"): return C.ort_from_dlpack(dlpack) return from_dlpack(dlpack) def _torch_tensor_from_dl_pack(dlpack, ortvalue, device): - torch_tensor = from_dlpack(dlpack) if device.type != 'ort' else C_ort_from_dlpack(dlpack) - return torch_tensor.to(torch.bool) if ortvalue.data_type() == 'tensor(bool)' else torch_tensor + torch_tensor = from_dlpack(dlpack) if device.type != "ort" else C_ort_from_dlpack(dlpack) + return torch_tensor.to(torch.bool) if ortvalue.data_type() == "tensor(bool)" else torch_tensor def _ortvalue_to_torch_tensor(ortvalue, device): @@ -44,7 +44,6 @@ def _ortvalues_to_torch_tensor(vect, device): class TestOrtValue(unittest.TestCase): - def testOrtValueDlPack_float32(self): numpy_arr_input = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32) ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(numpy_arr_input) @@ -94,7 +93,8 @@ def testOrtValueDlPack_bool(self): def testOrtValueVector_float32(self): narrays = [ np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), - np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.float32)] + np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.float32), + ] vect = OrtValueVector() for a in narrays: ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(a) @@ -109,7 +109,8 @@ def testOrtValueVector_float32(self): def testOrtValueVector_bool(self): narrays = [ np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.bool_), - np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.bool_)] + np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.bool_), + ] vect = OrtValueVector() for a in narrays: ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(a) @@ -123,7 +124,8 @@ def testOrtValueVector_bool(self): def OrtValueVectorDlPackOrtValue(self, my_to_tensor, tensor_type, device, dtype=np.float32): narrays = [ np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=dtype), - np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=dtype)] + np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=dtype), + ] vect = OrtValueVector() ptr = [] for a in narrays: @@ -149,7 +151,7 @@ def OrtValueVectorDlPackOrtValue(self, my_to_tensor, tensor_type, device, dtype= ptr2 = [] for av1, v2 in zip(narrays, converted_values): ptr2.append(v2.data_ptr()) - if hasattr(v2, 'cpu'): + if hasattr(v2, "cpu"): av2 = v2.cpu().numpy() else: av2 = v2.numpy() @@ -159,45 +161,50 @@ def OrtValueVectorDlPackOrtValue(self, my_to_tensor, tensor_type, device, dtype= def testOrtValueVectorDlPackOrtValue_cpu(self): def my_to_tensor(dlpack_structure): return C_OrtValue.from_dlpack(dlpack_structure, False) - self.OrtValueVectorDlPackOrtValue(my_to_tensor, C_OrtValue, 'cpu') + + self.OrtValueVectorDlPackOrtValue(my_to_tensor, C_OrtValue, "cpu") def testOrtValueVectorDlPackTorch_cpu(self): def my_to_tensor(dlpack_structure): return from_dlpack(dlpack_structure) - self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, 'cpu') + + self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, "cpu") def testOrtValueVectorDlPack_Torch_cpu(self): def my_to_tensor(dlpack_structure): return _from_dlpack(dlpack_structure) - self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, 'cpu') + + self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, "cpu") def testOrtValueVectorDlPackNone_cpu(self): - self.OrtValueVectorDlPackOrtValue(None, None, 'cpu') + self.OrtValueVectorDlPackOrtValue(None, None, "cpu") @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def testOrtValueVectorDlPackOrtValue_cuda(self): def my_to_tensor(dlpack_structure): return C_OrtValue.from_dlpack(dlpack_structure, False) - self.OrtValueVectorDlPackOrtValue(my_to_tensor, C_OrtValue, 'cuda') + + self.OrtValueVectorDlPackOrtValue(my_to_tensor, C_OrtValue, "cuda") @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def testOrtValueVectorDlPackTorch_cuda(self): def my_to_tensor(dlpack_structure): return from_dlpack(dlpack_structure) - self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, 'cuda') + + self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, "cuda") @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def testOrtValueVectorDlPack_Torch_cuda(self): def my_to_tensor(dlpack_structure): return _from_dlpack(dlpack_structure) - self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, 'cuda') + + self.OrtValueVectorDlPackOrtValue(my_to_tensor, torch.Tensor, "cuda") @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def testOrtValueVectorDlPackNone_cuda(self): - self.OrtValueVectorDlPackOrtValue(None, None, 'cuda') + self.OrtValueVectorDlPackOrtValue(None, None, "cuda") def test_ortmodule_dlpack(self): - class NeuralNetTanh(torch.nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNetTanh, self).__init__() @@ -221,7 +228,7 @@ def run_step(model, x): ort_model = ORTModule(copy.deepcopy(pt_model)) for step in range(10): - pt_x = torch.randn(N, D_in, device='cpu', requires_grad=True) + pt_x = torch.randn(N, D_in, device="cpu", requires_grad=True) ort_x = copy.deepcopy(pt_x) ort_prediction, ort_loss = run_step(ort_model, ort_x) pt_prediction, pt_loss = run_step(pt_model, pt_x) @@ -241,7 +248,7 @@ def forward(self, condition, x1, x2): out2 = torch.tensor(out1).to(torch.bool) return out1, out2 - device = 'cpu' + device = "cpu" N, D_in, D_out = 8, 16, 2 model = NeuralNetBoolInputOutput(D_in, D_out).to(device) model = ORTModule(model) @@ -256,16 +263,16 @@ def forward(self, condition, x1, x2): def _ortvalues_to_torch_tensor_ortvaluevector(self, device, tensor_type, new_impl, dtype=np.float32): narrays = [ np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=dtype), - np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=dtype)] + np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=dtype), + ] vect = OrtValueVector() ptr = [] for a in narrays: - ortvalue = onnxrt.OrtValue.ortvalue_from_numpy( - a, device.type if device.type != 'ort' else 'cpu') + ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(a, device.type if device.type != "ort" else "cpu") vect.push_back(ortvalue._ortvalue) ptr.append(ortvalue.data_ptr()) self.assertEqual(len(vect), 2) - if new_impl == 'list': + if new_impl == "list": tensors = _utils._ortvalues_to_torch_tensor_list(list(vect), device) elif new_impl: tensors = _utils._ortvalues_to_torch_tensor(vect, device) @@ -273,54 +280,54 @@ def _ortvalues_to_torch_tensor_ortvaluevector(self, device, tensor_type, new_imp tensors = _ortvalues_to_torch_tensor(vect, device) self.assertEqual(len(tensors), len(vect)) for t in tensors: - assert(isinstance(t, torch.Tensor)) + assert isinstance(t, torch.Tensor) self.assertEqual(ptr, [t.data_ptr() for t in tensors]) assert all(map(lambda v: isinstance(v, tensor_type), tensors)) def test_ortvalues_to_torch_tensor_ortvaluevector_cpu_new(self): - device = torch.device('cpu') + device = torch.device("cpu") self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, True) def test_ortvalues_to_torch_tensor_ortvaluevector_cpu_list(self): - device = torch.device('cpu') - self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, 'list') + device = torch.device("cpu") + self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, "list") def test_ortvalues_to_torch_tensor_ortvaluevector_cpu_old(self): - device = torch.device('cpu') + device = torch.device("cpu") self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, False) def test_ortvalues_to_torch_tensor_ortvaluevector_ort_new(self): - device = torch.device('ort') - if hasattr(C, 'ort_from_dlpack'): + device = torch.device("ort") + if hasattr(C, "ort_from_dlpack"): self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, True) else: with self.assertRaises(AttributeError): self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, True) def test_ortvalues_to_torch_tensor_ortvaluevector_ort_old(self): - device = torch.device('ort') + device = torch.device("ort") self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, False) @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def test_ortvalues_to_torch_tensor_ortvaluevector_cuda_new(self): - device = torch.device('cuda:0') + device = torch.device("cuda:0") self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, True) @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def test_ortvalues_to_torch_tensor_ortvaluevector_cuda_old(self): - device = torch.device('cuda:0') + device = torch.device("cuda:0") self._ortvalues_to_torch_tensor_ortvaluevector(device, torch.Tensor, False) def _ortvalues_to_torch_tensor_list(self, device, tensor_type, new_impl): narrays = [ np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=np.float32), - np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.float32)] + np.array([[6.0, 7.0], [8.0, 9.0], [1.0, 6.0]], dtype=np.float32), + ] vect = C.OrtValueVector() vect.reserve(len(narrays)) ptr = [] for a in narrays: - ortvalue = onnxrt.OrtValue.ortvalue_from_numpy( - a, device.type if device.type != 'ort' else 'cpu') + ortvalue = onnxrt.OrtValue.ortvalue_from_numpy(a, device.type if device.type != "ort" else "cpu") vect.push_back(ortvalue._ortvalue) ptr.append(ortvalue.data_ptr()) self.assertEqual(len(vect), 2) @@ -333,33 +340,33 @@ def _ortvalues_to_torch_tensor_list(self, device, tensor_type, new_impl): assert all(map(lambda v: isinstance(v, tensor_type), tensors)) def test_ortvalues_to_torch_tensor_list_cpu_new(self): - device = torch.device('cpu') + device = torch.device("cpu") self._ortvalues_to_torch_tensor_list(device, torch.Tensor, True) def test_ortvalues_to_torch_tensor_list_cpu_old(self): - device = torch.device('cpu') + device = torch.device("cpu") self._ortvalues_to_torch_tensor_list(device, torch.Tensor, False) def test_ortvalues_to_torch_tensor_list_ort_new(self): - device = torch.device('ort') - if hasattr(C, 'ort_from_dlpack'): + device = torch.device("ort") + if hasattr(C, "ort_from_dlpack"): self._ortvalues_to_torch_tensor_list(device, torch.Tensor, True) else: with self.assertRaises(AttributeError): self._ortvalues_to_torch_tensor_list(device, torch.Tensor, True) def test_ortvalues_to_torch_tensor_list_ort_old(self): - device = torch.device('ort') + device = torch.device("ort") self._ortvalues_to_torch_tensor_list(device, torch.Tensor, False) @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def test_ortvalues_to_torch_tensor_list_cuda_new(self): - device = torch.device('cuda:0') + device = torch.device("cuda:0") self._ortvalues_to_torch_tensor_list(device, torch.Tensor, True) @unittest.skipIf(not has_cuda, reason="No CUDA availabled.") def test_ortvalues_to_torch_tensor_list_cuda_old(self): - device = torch.device('cuda:0') + device = torch.device("cuda:0") self._ortvalues_to_torch_tensor_list(device, torch.Tensor, False) def test_element_type(self): diff --git a/orttraining/orttraining/test/python/orttraining_test_sampler.py b/orttraining/orttraining/test/python/orttraining_test_sampler.py index 32346c7b614b6..c47b721b7d100 100644 --- a/orttraining/orttraining/test/python/orttraining_test_sampler.py +++ b/orttraining/orttraining/test/python/orttraining_test_sampler.py @@ -6,6 +6,7 @@ from onnxruntime.training.utils.data import sampler import random + class MyDataset(torch.utils.data.Dataset): def __init__(self, samples): self.samples = samples @@ -18,25 +19,18 @@ def __len__(self): def test_load_balancing_data_sampler_balances_load(): - samples_and_complexities = \ - [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] + samples_and_complexities = [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] dataset = MyDataset(samples_and_complexities) def complexity_fn(sample): return sample[1] data_sampler0 = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=2, - rank=0, - shuffle=False) + dataset, complexity_fn=complexity_fn, world_size=2, rank=0, shuffle=False + ) data_sampler1 = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=2, - rank=1, - shuffle=False) + dataset, complexity_fn=complexity_fn, world_size=2, rank=1, shuffle=False + ) largest_complexity = -1 for index in data_sampler0: @@ -48,6 +42,7 @@ def complexity_fn(sample): assert samples_and_complexities[index][1] >= largest_complexity largest_complexity = samples_and_complexities[index][1] + def test_load_balancing_data_sampler_shuffles_and_balances_load(): complexities = [] for i in range(50): @@ -56,8 +51,7 @@ def test_load_balancing_data_sampler_shuffles_and_balances_load(): complexities.append(c) random.shuffle(complexities) - samples = \ - [torch.FloatTensor([val]) for val in range(100)] + samples = [torch.FloatTensor([val]) for val in range(100)] samples_and_complexities = list(zip(samples, complexities)) dataset = MyDataset(samples_and_complexities) @@ -65,25 +59,18 @@ def complexity_fn(sample): return sample[1] data_sampler0 = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=2, - rank=0, - shuffle=True) + dataset, complexity_fn=complexity_fn, world_size=2, rank=0, shuffle=True + ) data_sampler1 = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=2, - rank=1, - shuffle=True) + dataset, complexity_fn=complexity_fn, world_size=2, rank=1, shuffle=True + ) for index0, index1 in zip(data_sampler0, data_sampler1): - assert samples_and_complexities[index0][1] == \ - samples_and_complexities[index1][1] + assert samples_and_complexities[index0][1] == samples_and_complexities[index1][1] + def test_load_balancing_data_sampler_sorts_in_groups(): - samples_and_complexities = \ - [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] + samples_and_complexities = [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] dataset = MyDataset(samples_and_complexities) def complexity_fn(sample): @@ -92,23 +79,21 @@ def complexity_fn(sample): group_size = 8 samples_and_complexities_sorted = samples_and_complexities.copy() for begin_index in range(0, len(samples_and_complexities), group_size): - end_index = min(begin_index+group_size, len(samples_and_complexities)) - samples_and_complexities_sorted[begin_index:end_index] = sorted(samples_and_complexities_sorted[begin_index:end_index], key=lambda x: x[1]) + end_index = min(begin_index + group_size, len(samples_and_complexities)) + samples_and_complexities_sorted[begin_index:end_index] = sorted( + samples_and_complexities_sorted[begin_index:end_index], key=lambda x: x[1] + ) data_sampler = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=1, - rank=0, - shuffle=False, - group_size=8) + dataset, complexity_fn=complexity_fn, world_size=1, rank=0, shuffle=False, group_size=8 + ) for index, sorted_sample in zip(data_sampler, samples_and_complexities_sorted): assert samples_and_complexities[index][1] == sorted_sample[1] + def test_load_balancing_data_sampler_sorts_and_shuffles_in_groups(): - samples_and_complexities = \ - [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] + samples_and_complexities = [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] dataset = MyDataset(samples_and_complexities) def complexity_fn(sample): @@ -117,63 +102,55 @@ def complexity_fn(sample): group_size = 8 samples_and_complexities_sorted = samples_and_complexities.copy() for begin_index in range(0, len(samples_and_complexities), group_size): - end_index = min(begin_index+group_size, len(samples_and_complexities)) - samples_and_complexities_sorted[begin_index:end_index] = \ - sorted(samples_and_complexities_sorted[begin_index:end_index], key=lambda x: x[1]) + end_index = min(begin_index + group_size, len(samples_and_complexities)) + samples_and_complexities_sorted[begin_index:end_index] = sorted( + samples_and_complexities_sorted[begin_index:end_index], key=lambda x: x[1] + ) samples_and_complexities_sorted_and_shuffled = samples_and_complexities_sorted.copy() - shuffled_group_order = torch.randperm(( - len(samples_and_complexities)+group_size-1) // group_size, - generator=torch.Generator().manual_seed(0)).tolist() + shuffled_group_order = torch.randperm( + (len(samples_and_complexities) + group_size - 1) // group_size, generator=torch.Generator().manual_seed(0) + ).tolist() end = 0 for group_index in shuffled_group_order: - original_begin = group_index*group_size - original_end = min(original_begin+group_size, len(samples_and_complexities)) + original_begin = group_index * group_size + original_end = min(original_begin + group_size, len(samples_and_complexities)) begin = end - end = begin + (original_end-original_begin) - samples_and_complexities_sorted_and_shuffled[begin:end] = \ - samples_and_complexities_sorted[original_begin:original_end] + end = begin + (original_end - original_begin) + samples_and_complexities_sorted_and_shuffled[begin:end] = samples_and_complexities_sorted[ + original_begin:original_end + ] data_sampler = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=1, - rank=0, - shuffle=True, - group_size=8) + dataset, complexity_fn=complexity_fn, world_size=1, rank=0, shuffle=True, group_size=8 + ) for index, sorted_and_shuffled_sample in zip(data_sampler, samples_and_complexities_sorted_and_shuffled): assert samples_and_complexities[index][1] == sorted_and_shuffled_sample[1] + def test_load_balancing_batch_sampler_uses_data_sampler(): - samples_and_complexities = \ - [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] + samples_and_complexities = [(torch.FloatTensor([val]), torch.randint(0, 100, (1,)).item()) for val in range(100)] dataset = MyDataset(samples_and_complexities) def complexity_fn(sample): return sample[1] data_sampler = sampler.LoadBalancingDistributedSampler( - dataset, - complexity_fn=complexity_fn, - world_size=1, - rank=0, - shuffle=False) + dataset, complexity_fn=complexity_fn, world_size=1, rank=0, shuffle=False + ) batch_size = 12 + def batch_fn(indices): nonlocal batch_size batches = [] for batch_index_begin in range(0, len(indices), batch_size): - batch_index_end = min(batch_index_begin+batch_size, len(indices)) + batch_index_end = min(batch_index_begin + batch_size, len(indices)) batches.append(indices[batch_index_begin:batch_index_end]) return batches - batch_sampler = sampler.LoadBalancingDistributedBatchSampler( - data_sampler, - batch_fn - ) + batch_sampler = sampler.LoadBalancingDistributedBatchSampler(data_sampler, batch_fn) for batch in batch_sampler: - assert len(batch) == batch_size or \ - len(batch) == len(samples_and_complexities) % batch_size + assert len(batch) == batch_size or len(batch) == len(samples_and_complexities) % batch_size diff --git a/orttraining/orttraining/test/python/orttraining_test_transformers.py b/orttraining/orttraining/test/python/orttraining_test_transformers.py index 856ee59132ae0..e43d3599d68ca 100644 --- a/orttraining/orttraining/test/python/orttraining_test_transformers.py +++ b/orttraining/orttraining/test/python/orttraining_test_transformers.py @@ -9,7 +9,7 @@ import random import numpy as np from numpy.testing import assert_allclose -from transformers import (BertConfig, BertForPreTraining, BertModel) +from transformers import BertConfig, BertForPreTraining, BertModel from orttraining_test_data_loader import ids_tensor, BatchArgsOption from orttraining_test_utils import run_test, get_lr @@ -19,35 +19,35 @@ import torch -class BertModelTest(unittest.TestCase): +class BertModelTest(unittest.TestCase): class BertModelTester(object): - - def __init__(self, - parent, - batch_size=13, - seq_length=7, - is_training=True, - use_input_mask=True, - use_token_type_ids=True, - use_labels=True, - vocab_size=99, - hidden_size=32, - num_hidden_layers=5, - num_attention_heads=4, - intermediate_size=37, - hidden_act="gelu", - hidden_dropout_prob=0.1, - attention_probs_dropout_prob=0.1, - max_position_embeddings=512, - type_vocab_size=16, - type_sequence_label_size=2, - initializer_range=0.02, - num_labels=3, - num_choices=4, - scope=None, - device='cpu', - ): + def __init__( + self, + parent, + batch_size=13, + seq_length=7, + is_training=True, + use_input_mask=True, + use_token_type_ids=True, + use_labels=True, + vocab_size=99, + hidden_size=32, + num_hidden_layers=5, + num_attention_heads=4, + intermediate_size=37, + hidden_act="gelu", + hidden_dropout_prob=0.1, + attention_probs_dropout_prob=0.1, + max_position_embeddings=512, + type_vocab_size=16, + type_sequence_label_size=2, + initializer_range=0.02, + num_labels=3, + num_choices=4, + scope=None, + device="cpu", + ): self.parent = parent self.batch_size = batch_size self.seq_length = seq_length @@ -74,38 +74,99 @@ def __init__(self, # 1. superset of bert input/output descs # see BertPreTrainedModel doc - self.input_ids_desc = IODescription('input_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size) - self.attention_mask_desc = IODescription('attention_mask', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) - self.token_type_ids_desc = IODescription('token_type_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=2) - self.position_ids_desc = IODescription('position_ids', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.max_position_embeddings) - self.head_mask_desc = IODescription('head_mask', [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2) - self.inputs_embeds_desc = IODescription('inputs_embeds', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - - self.encoder_hidden_states_desc = IODescription('encoder_hidden_states', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - self.encoder_attention_mask_desc = IODescription('encoder_attention_mask', ['batch', 'max_seq_len_in_batch'], torch.float32) + self.input_ids_desc = IODescription( + "input_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size + ) + self.attention_mask_desc = IODescription( + "attention_mask", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 + ) + self.token_type_ids_desc = IODescription( + "token_type_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=2 + ) + self.position_ids_desc = IODescription( + "position_ids", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.max_position_embeddings + ) + self.head_mask_desc = IODescription( + "head_mask", [self.num_hidden_layers, self.num_attention_heads], torch.int64, num_classes=2 + ) + self.inputs_embeds_desc = IODescription( + "inputs_embeds", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 + ) + + self.encoder_hidden_states_desc = IODescription( + "encoder_hidden_states", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 + ) + self.encoder_attention_mask_desc = IODescription( + "encoder_attention_mask", ["batch", "max_seq_len_in_batch"], torch.float32 + ) # see BertForPreTraining doc - self.masked_lm_labels_desc = IODescription('masked_lm_labels', ['batch', 'max_seq_len_in_batch'], torch.int64, num_classes=self.vocab_size) - self.next_sentence_label_desc = IODescription('next_sentence_label', ['batch',], torch.int64, num_classes=2) + self.masked_lm_labels_desc = IODescription( + "masked_lm_labels", ["batch", "max_seq_len_in_batch"], torch.int64, num_classes=self.vocab_size + ) + self.next_sentence_label_desc = IODescription( + "next_sentence_label", + [ + "batch", + ], + torch.int64, + num_classes=2, + ) # outputs - self.loss_desc = IODescription('loss', [1,], torch.float32) - self.prediction_scores_desc = IODescription('prediction_scores', ['batch', 'max_seq_len_in_batch', self.vocab_size], torch.float32) - - self.seq_relationship_scores_desc = IODescription('seq_relationship_scores', ['batch', 2], torch.float32) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) - self.hidden_states_desc = IODescription('hidden_states', [self.num_hidden_layers, 'batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - self.attentions_desc = IODescription('attentions', [self.num_hidden_layers, 'batch', self.num_attention_heads, 'max_seq_len_in_batch', 'max_seq_len_in_batch'], torch.float32) - self.last_hidden_state_desc = IODescription('last_hidden_state', ['batch', 'max_seq_len_in_batch', self.hidden_size], torch.float32) - self.pooler_output_desc = IODescription('pooler_output', ['batch', self.hidden_size], torch.float32) + self.loss_desc = IODescription( + "loss", + [ + 1, + ], + torch.float32, + ) + self.prediction_scores_desc = IODescription( + "prediction_scores", ["batch", "max_seq_len_in_batch", self.vocab_size], torch.float32 + ) + + self.seq_relationship_scores_desc = IODescription( + "seq_relationship_scores", ["batch", 2], torch.float32 + ) # IODescription('seq_relationship_scores', ['batch', 'max_seq_len_in_batch', 2], torch.float32) + self.hidden_states_desc = IODescription( + "hidden_states", + [self.num_hidden_layers, "batch", "max_seq_len_in_batch", self.hidden_size], + torch.float32, + ) + self.attentions_desc = IODescription( + "attentions", + [ + self.num_hidden_layers, + "batch", + self.num_attention_heads, + "max_seq_len_in_batch", + "max_seq_len_in_batch", + ], + torch.float32, + ) + self.last_hidden_state_desc = IODescription( + "last_hidden_state", ["batch", "max_seq_len_in_batch", self.hidden_size], torch.float32 + ) + self.pooler_output_desc = IODescription("pooler_output", ["batch", self.hidden_size], torch.float32) def BertForPreTraining_descs(self): return ModelDescription( - [self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, self.masked_lm_labels_desc, self.next_sentence_label_desc], + [ + self.input_ids_desc, + self.attention_mask_desc, + self.token_type_ids_desc, + self.masked_lm_labels_desc, + self.next_sentence_label_desc, + ], # returns loss_desc if both masked_lm_labels_desc, next_sentence_label are provided # hidden_states_desc, attentions_desc shall be included according to config.output_attentions, config.output_hidden_states - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc, - #hidden_states_desc, attentions_desc - ]) + [ + self.loss_desc, + self.prediction_scores_desc, + self.seq_relationship_scores_desc, + # hidden_states_desc, attentions_desc + ], + ) def prepare_config_and_inputs(self): input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size).to(self.device) @@ -139,17 +200,27 @@ def prepare_config_and_inputs(self): max_position_embeddings=self.max_position_embeddings, type_vocab_size=self.type_vocab_size, is_decoder=False, - initializer_range=self.initializer_range) + initializer_range=self.initializer_range, + ) return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels - def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels, + def create_and_check_bert_for_pretraining( + self, + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, option_split_batch, option_use_internal_get_lr_this_step=[True], - option_use_internal_loss_scaler=[True]): + option_use_internal_loss_scaler=[True], + ): seed = 42 random.seed(seed) np.random.seed(seed) @@ -159,24 +230,47 @@ def create_and_check_bert_for_pretraining(self, config, input_ids, token_type_id model = BertForPreTraining(config=config) model.eval() - loss, prediction_scores, seq_relationship_score = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, - masked_lm_labels=token_labels, next_sentence_label=sequence_labels) - model_desc = ModelDescription([self.input_ids_desc, self.attention_mask_desc, self.token_type_ids_desc, - self.masked_lm_labels_desc, self.next_sentence_label_desc], - [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc]) + loss, prediction_scores, seq_relationship_score = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + masked_lm_labels=token_labels, + next_sentence_label=sequence_labels, + ) + model_desc = ModelDescription( + [ + self.input_ids_desc, + self.attention_mask_desc, + self.token_type_ids_desc, + self.masked_lm_labels_desc, + self.next_sentence_label_desc, + ], + [self.loss_desc, self.prediction_scores_desc, self.seq_relationship_scores_desc], + ) from collections import namedtuple - MyArgs = namedtuple("MyArgs", - "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len") + + MyArgs = namedtuple( + "MyArgs", "local_rank world_size max_steps learning_rate warmup_proportion batch_size seq_len" + ) dataset_len = 100 epochs = 8 - max_steps = epochs * dataset_len - args = MyArgs(local_rank=0, world_size=1, max_steps=max_steps, learning_rate=0.00001, warmup_proportion=0.01, batch_size=13, seq_len=7) + max_steps = epochs * dataset_len + args = MyArgs( + local_rank=0, + world_size=1, + max_steps=max_steps, + learning_rate=0.00001, + warmup_proportion=0.01, + batch_size=13, + seq_len=7, + ) def get_lr_this_step(global_step): return get_lr(args, global_step) - loss_scaler = LossScaler('loss_scale_input_name', True, up_scale_window=2000) + + loss_scaler = LossScaler("loss_scale_input_name", True, up_scale_window=2000) for fp16 in option_fp16: for allreduce_post_accumulation in option_allreduce_post_accumulation: @@ -194,16 +288,27 @@ def get_lr_this_step(global_step): torch.cuda.manual_seed_all(seed) onnxruntime.set_seed(seed) - old_api_loss_ort, old_api_prediction_scores_ort, old_api_seq_relationship_score_ort =\ - run_test( - model, model_desc, self.device, args, gradient_accumulation_steps, fp16, - allreduce_post_accumulation, - get_lr_this_step, use_internal_get_lr_this_step, - loss_scaler, use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=False) + ( + old_api_loss_ort, + old_api_prediction_scores_ort, + old_api_seq_relationship_score_ort, + ) = run_test( + model, + model_desc, + self.device, + args, + gradient_accumulation_steps, + fp16, + allreduce_post_accumulation, + get_lr_this_step, + use_internal_get_lr_this_step, + loss_scaler, + use_internal_loss_scaler, + split_batch, + dataset_len, + epochs, + use_new_api=False, + ) random.seed(seed) np.random.seed(seed) @@ -211,21 +316,33 @@ def get_lr_this_step(global_step): torch.cuda.manual_seed_all(seed) onnxruntime.set_seed(seed) if use_internal_get_lr_this_step and use_internal_loss_scaler: - new_api_loss_ort, new_api_prediction_scores_ort, new_api_seq_relationship_score_ort =\ - run_test( - model, model_desc, self.device, args, gradient_accumulation_steps, fp16, - allreduce_post_accumulation, - get_lr_this_step, use_internal_get_lr_this_step, - loss_scaler, use_internal_loss_scaler, - split_batch, - dataset_len, - epochs, - use_new_api=True) - + ( + new_api_loss_ort, + new_api_prediction_scores_ort, + new_api_seq_relationship_score_ort, + ) = run_test( + model, + model_desc, + self.device, + args, + gradient_accumulation_steps, + fp16, + allreduce_post_accumulation, + get_lr_this_step, + use_internal_get_lr_this_step, + loss_scaler, + use_internal_loss_scaler, + split_batch, + dataset_len, + epochs, + use_new_api=True, + ) + assert_allclose(old_api_loss_ort, new_api_loss_ort) assert_allclose(old_api_prediction_scores_ort, new_api_prediction_scores_ort) - assert_allclose(old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort) - + assert_allclose( + old_api_seq_relationship_score_ort, new_api_seq_relationship_score_ort + ) def setUp(self): self.model_tester = BertModelTest.BertModelTester(self) @@ -244,7 +361,8 @@ def test_for_pretraining_mixed_precision(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_mixed_precision_with_gradient_accumulation(self): # It would be better to test both with/without mixed precision and allreduce_post_accumulation. @@ -260,7 +378,8 @@ def test_for_pretraining_mixed_precision_with_gradient_accumulation(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_all(self): # This test is not stable because it create and run ORTSession multiple times. @@ -277,7 +396,8 @@ def test_for_pretraining_full_precision_all(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_list_input(self): option_fp16 = [False] @@ -290,7 +410,8 @@ def test_for_pretraining_full_precision_list_input(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_dict_input(self): option_fp16 = [False] @@ -303,7 +424,8 @@ def test_for_pretraining_full_precision_dict_input(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_list_and_dict_input(self): option_fp16 = [False] @@ -316,7 +438,8 @@ def test_for_pretraining_full_precision_list_and_dict_input(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_grad_accumulation_list_input(self): option_fp16 = [False] @@ -329,7 +452,8 @@ def test_for_pretraining_full_precision_grad_accumulation_list_input(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_grad_accumulation_dict_input(self): option_fp16 = [False] @@ -342,7 +466,8 @@ def test_for_pretraining_full_precision_grad_accumulation_dict_input(self): option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(self): option_fp16 = [False] @@ -355,7 +480,9 @@ def test_for_pretraining_full_precision_grad_accumulation_list_and_dict_input(se option_fp16, option_allreduce_post_accumulation, option_gradient_accumulation_steps, - option_split_batch) + option_split_batch + ) + if __name__ == "__main__": unittest.main() diff --git a/orttraining/orttraining/test/python/orttraining_test_utils.py b/orttraining/orttraining/test/python/orttraining_test_utils.py index 763e16cc47dfb..0ec501b0d72d4 100644 --- a/orttraining/orttraining/test/python/orttraining_test_utils.py +++ b/orttraining/orttraining/test/python/orttraining_test_utils.py @@ -5,46 +5,58 @@ from orttraining_test_data_loader import create_ort_test_dataloader, BatchArgsOption, split_batch from orttraining_test_bert_postprocess import postprocess_model -from onnxruntime.training import _utils, amp, optim, orttrainer, TrainStepInfo,\ - model_desc_validation as md_val,\ - orttrainer_options as orttrainer_options +from onnxruntime.training import ( + _utils, + amp, + optim, + orttrainer, + TrainStepInfo, + model_desc_validation as md_val, + orttrainer_options as orttrainer_options, +) from onnxruntime.training.optim import _LRScheduler + def warmup_cosine(x, warmup=0.002): if x < warmup: - return x/warmup + return x / warmup return 0.5 * (1.0 + torch.cos(math.pi * x)) + def warmup_constant(x, warmup=0.002): if x < warmup: - return x/warmup + return x / warmup return 1.0 + def warmup_linear(x, warmup=0.002): if x < warmup: - return x/warmup - return max((x - 1. )/ (warmup - 1.), 0.) + return x / warmup + return max((x - 1.0) / (warmup - 1.0), 0.0) + def warmup_poly(x, warmup=0.002, degree=0.5): if x < warmup: - return x/warmup - return (1.0 - x)**degree + return x / warmup + return (1.0 - x) ** degree SCHEDULES = { - 'warmup_cosine':warmup_cosine, - 'warmup_constant':warmup_constant, - 'warmup_linear':warmup_linear, - 'warmup_poly':warmup_poly, + "warmup_cosine": warmup_cosine, + "warmup_constant": warmup_constant, + "warmup_linear": warmup_linear, + "warmup_poly": warmup_poly, } -def get_lr(args, training_steps, schedule='warmup_poly'): + +def get_lr(args, training_steps, schedule="warmup_poly"): if args.max_steps == -1: return args.learning_rate schedule_fct = SCHEDULES[schedule] return args.learning_rate * schedule_fct(training_steps / args.max_steps, args.warmup_proportion) + def map_optimizer_attributes(name): no_decay_keys = ["bias", "gamma", "beta", "LayerNorm"] no_decay = any(no_decay_key in name for no_decay_key in no_decay_keys) @@ -53,6 +65,7 @@ def map_optimizer_attributes(name): else: return {"alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} + class WrapLRScheduler(_LRScheduler): def __init__(self, get_lr_this_step): super().__init__() @@ -62,70 +75,133 @@ def get_lr(self, train_step_info): return [self.get_lr_this_step(train_step_info.optimization_step)] -def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16, - allreduce_post_accumulation, get_lr_this_step, use_internal_get_lr_this_step, loss_scaler, use_internal_loss_scaler, - batch_args_option, dataset_len, epochs, use_new_api): +def run_test( + model, + model_desc, + device, + args, + gradient_accumulation_steps, + fp16, + allreduce_post_accumulation, + get_lr_this_step, + use_internal_get_lr_this_step, + loss_scaler, + use_internal_loss_scaler, + batch_args_option, + dataset_len, + epochs, + use_new_api, +): dataloader = create_ort_test_dataloader(model_desc.inputs_, args.batch_size, args.seq_len, dataset_len, device) if use_new_api: - assert use_internal_loss_scaler, 'new api should always use internal loss scaler' + assert use_internal_loss_scaler, "new api should always use internal loss scaler" new_api_lr_scheduler = WrapLRScheduler(get_lr_this_step) new_api_loss_scaler = amp.DynamicLossScaler() if fp16 else None - options = orttrainer.ORTTrainerOptions({'batch' : { - 'gradient_accumulation_steps' : gradient_accumulation_steps}, - 'device': {'id': device}, - 'mixed_precision': { - 'enabled': fp16, - 'loss_scaler': new_api_loss_scaler}, - 'debug': {'deterministic_compute': True, }, - 'utils': { - 'grad_norm_clip': True}, - 'distributed': {'allreduce_post_accumulation': True}, - 'lr_scheduler': new_api_lr_scheduler - }) + options = orttrainer.ORTTrainerOptions( + { + "batch": {"gradient_accumulation_steps": gradient_accumulation_steps}, + "device": {"id": device}, + "mixed_precision": {"enabled": fp16, "loss_scaler": new_api_loss_scaler}, + "debug": { + "deterministic_compute": True, + }, + "utils": {"grad_norm_clip": True}, + "distributed": {"allreduce_post_accumulation": True}, + "lr_scheduler": new_api_lr_scheduler, + } + ) param_optimizer = list(model.named_parameters()) - params = [{ - 'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6}, { - 'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "alpha": 0.9, "beta": 0.999, "lambda": 0.0, "epsilon": 1e-6} - ] + params = [ + { + "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], + "alpha": 0.9, + "beta": 0.999, + "lambda": 0.0, + "epsilon": 1e-6, + }, + { + "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], + "alpha": 0.9, + "beta": 0.999, + "lambda": 0.0, + "epsilon": 1e-6, + }, + ] vocab_size = 99 new_model_desc = { - 'inputs': [ - ('input_ids', ['batch', 'max_seq_len_in_batch'],), - ('attention_mask', ['batch', 'max_seq_len_in_batch'],), - ('token_type_ids', ['batch', 'max_seq_len_in_batch'],), - ('masked_lm_labels', ['batch', 'max_seq_len_in_batch'],), - ('next_sentence_label', ['batch',])], - 'outputs': [ - ('loss', [1,], True), - ('prediction_scores', ['batch', 'max_seq_len_in_batch', vocab_size]), - ('seq_relationship_scores', ['batch', 2])]} + "inputs": [ + ( + "input_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "attention_mask", + ["batch", "max_seq_len_in_batch"], + ), + ( + "token_type_ids", + ["batch", "max_seq_len_in_batch"], + ), + ( + "masked_lm_labels", + ["batch", "max_seq_len_in_batch"], + ), + ( + "next_sentence_label", + [ + "batch", + ], + ), + ], + "outputs": [ + ( + "loss", + [ + 1, + ], + True, + ), + ("prediction_scores", ["batch", "max_seq_len_in_batch", vocab_size]), + ("seq_relationship_scores", ["batch", 2]), + ], + } optim_config = optim.LambConfig(params=params, lr=2e-5) model = orttrainer.ORTTrainer(model, new_model_desc, optim_config, options=options) - print ("running with new frontend API") + print("running with new frontend API") else: - model = ORTTrainer(model, None, model_desc, "LambOptimizer", + model = ORTTrainer( + model, + None, + model_desc, + "LambOptimizer", map_optimizer_attributes=map_optimizer_attributes, - learning_rate_description=IODescription('Learning_Rate', [1,], torch.float32), + learning_rate_description=IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ), device=device, _enable_internal_postprocess=True, gradient_accumulation_steps=gradient_accumulation_steps, # BertLAMB default initial settings: b1=0.9, b2=0.999, e=1e-6 - world_rank=args.local_rank, world_size=args.world_size, + world_rank=args.local_rank, + world_size=args.world_size, use_mixed_precision=fp16, allreduce_post_accumulation=allreduce_post_accumulation, get_lr_this_step=get_lr_this_step if use_internal_get_lr_this_step else None, loss_scaler=loss_scaler if use_internal_loss_scaler else None, _opset_version=14, - _use_deterministic_compute=True) - print ("running with old frontend API") + _use_deterministic_compute=True, + ) + print("running with old frontend API") # trainig loop eval_batch = None @@ -145,22 +221,26 @@ def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16, if batch_args_option == BatchArgsOption.List: if not use_internal_get_lr_this_step: - batch = batch + [learning_rate, ] + batch = batch + [ + learning_rate, + ] if not use_internal_loss_scaler and fp16: - batch = batch + [loss_scale, ] + batch = batch + [ + loss_scale, + ] outputs = model.train_step(*batch) elif batch_args_option == BatchArgsOption.Dict: args, kwargs = split_batch(batch, model_desc.inputs_, 0) if not use_internal_get_lr_this_step: - kwargs['Learning_Rate'] = learning_rate + kwargs["Learning_Rate"] = learning_rate if not use_internal_loss_scaler and fp16: kwargs[model.loss_scale_input_name] = loss_scale outputs = model.train_step(*args, **kwargs) else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs + args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs args, kwargs = split_batch(batch, model_desc.inputs_, args_count) if not use_internal_get_lr_this_step: - kwargs['Learning_Rate'] = learning_rate + kwargs["Learning_Rate"] = learning_rate if not use_internal_loss_scaler and fp16: kwargs[model.loss_scale_input_name] = loss_scale outputs = model.train_step(*args, **kwargs) @@ -172,7 +252,7 @@ def run_test(model, model_desc, device, args, gradient_accumulation_steps, fp16, args, kwargs = split_batch(batch, model_desc.inputs_, 0) outputs = model.eval_step(*args, **kwargs) else: - args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs + args_count = int(len(model_desc.inputs_) / 2) # approx helf args, half kwargs args, kwargs = split_batch(batch, model_desc.inputs_, args_count) outputs = model.eval_step(*args, **kwargs) diff --git a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py index ada4e573d5a96..0185670dac79f 100644 --- a/orttraining/orttraining/test/python/orttraining_transformer_trainer.py +++ b/orttraining/orttraining/test/python/orttraining_transformer_trainer.py @@ -24,9 +24,15 @@ from orttraining_test_bert_postprocess import postprocess_model from onnxruntime.capi.ort_trainer import ORTTrainer, LossScaler, ModelDescription, IODescription -from onnxruntime.training import _utils, amp, optim, orttrainer, TrainStepInfo,\ - model_desc_validation as md_val,\ - orttrainer_options as orttrainer_options +from onnxruntime.training import ( + _utils, + amp, + optim, + orttrainer, + TrainStepInfo, + model_desc_validation as md_val, + orttrainer_options as orttrainer_options, +) from onnxruntime.training.optim import LinearWarmupLRScheduler, _LRScheduler try: @@ -56,6 +62,7 @@ def set_seed(seed: int): torch.cuda.manual_seed_all(seed) onnxruntime.set_seed(seed) + class EvalPrediction(NamedTuple): predictions: np.ndarray label_ids: np.ndarray @@ -71,14 +78,12 @@ class TrainOutput(NamedTuple): global_step: int training_loss: float -def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr): +def get_linear_schedule_with_warmup(num_warmup_steps, num_training_steps, base_lr): def lr_lambda_linear(current_step): if current_step < num_warmup_steps: return float(current_step) / float(max(1, num_warmup_steps)) - return max( - 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)) - ) + return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))) def lambda_lr_get_lr(current_global_step): # LambdaLR increment self.last_epoch at evert sept() @@ -88,8 +93,7 @@ def lambda_lr_get_lr(current_global_step): class ORTTransformerTrainer: - """ - """ + """ """ model: PreTrainedModel args: TrainingArguments @@ -105,10 +109,9 @@ def __init__( train_dataset: Dataset, eval_dataset: Dataset, compute_metrics: Callable[[EvalPrediction], Dict], - world_size: Optional[int] = 1 + world_size: Optional[int] = 1, ): - """ - """ + """ """ self.model = model self.model_desc = model_desc @@ -127,7 +130,9 @@ def get_train_dataloader(self) -> DataLoader: if self.train_dataset is None: raise ValueError("Trainer: training requires a train_dataset.") train_sampler = ( - SequentialSampler(self.train_dataset) if self.args.local_rank == -1 else DistributedSampler(self.train_dataset) + SequentialSampler(self.train_dataset) + if self.args.local_rank == -1 + else DistributedSampler(self.train_dataset) ) return DataLoader( self.train_dataset, @@ -153,7 +158,6 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: collate_fn=self.data_collator.collate_batch, ) - def train(self): """ Main training entry point. @@ -169,38 +173,44 @@ def train(self): t_total = int(len(train_dataloader) // self.args.gradient_accumulation_steps * self.args.num_train_epochs) num_train_epochs = self.args.num_train_epochs - lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps/float(t_total)) + lr_scheduler = orttrainer.optim.LinearWarmupLRScheduler(t_total, self.args.warmup_steps / float(t_total)) loss_scaler = amp.DynamicLossScaler() if self.args.fp16 else None device = self.args.device.type - device = f'{device}:{self.args.device.index}' if self.args.device.index else f'{device}:0' - options = orttrainer.ORTTrainerOptions({'batch' : { - 'gradient_accumulation_steps' : self.args.gradient_accumulation_steps}, - 'device': {'id': device}, - 'mixed_precision': { - 'enabled': self.args.fp16, - 'loss_scaler': loss_scaler}, - 'debug': {'deterministic_compute': True, }, - 'utils': { - 'grad_norm_clip': False}, - 'distributed': { - # we are running single node multi gpu test. thus world_rank = local_rank - # and world_size = self.args.n_gpu - 'world_rank': max(0, self.args.local_rank), - 'world_size': int(self.world_size), - 'local_rank': max(0, self.args.local_rank), - 'allreduce_post_accumulation': True}, - 'lr_scheduler': lr_scheduler - }) + device = f"{device}:{self.args.device.index}" if self.args.device.index else f"{device}:0" + options = orttrainer.ORTTrainerOptions( + { + "batch": {"gradient_accumulation_steps": self.args.gradient_accumulation_steps}, + "device": {"id": device}, + "mixed_precision": {"enabled": self.args.fp16, "loss_scaler": loss_scaler}, + "debug": { + "deterministic_compute": True, + }, + "utils": {"grad_norm_clip": False}, + "distributed": { + # we are running single node multi gpu test. thus world_rank = local_rank + # and world_size = self.args.n_gpu + "world_rank": max(0, self.args.local_rank), + "world_size": int(self.world_size), + "local_rank": max(0, self.args.local_rank), + "allreduce_post_accumulation": True, + }, + "lr_scheduler": lr_scheduler, + } + ) param_optimizer = list(self.model.named_parameters()) - params = [{ - 'params': [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], - "weight_decay_mode": 1, }, { - 'params': [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], - "weight_decay_mode": 1, } - ] + params = [ + { + "params": [n for n, p in param_optimizer if "bias" in n or "LayerNorm.weight" in n], + "weight_decay_mode": 1, + }, + { + "params": [n for n, p in param_optimizer if not ("bias" in n or "LayerNorm.weight" in n)], + "weight_decay_mode": 1, + }, + ] optim_config = optim.AdamConfig(params=params, lr=2e-5, do_bias_correction=True) self.model = orttrainer.ORTTrainer(self.model, self.model_desc, optim_config, options=options) @@ -226,7 +236,10 @@ def train(self): tr_loss = 0.0 logging_loss = 0.0 train_iterator = trange( - epochs_trained, int(num_train_epochs), desc="Epoch", disable=self.args.local_rank not in [-1, 0], + epochs_trained, + int(num_train_epochs), + desc="Epoch", + disable=self.args.local_rank not in [-1, 0], ) for epoch in train_iterator: @@ -241,8 +254,7 @@ def train(self): tr_loss += self._training_step(self.model, inputs) if (step + 1) % self.args.gradient_accumulation_steps == 0 or ( - len(epoch_iterator) <= self.args.gradient_accumulation_steps - and (step + 1) == len(epoch_iterator) + len(epoch_iterator) <= self.args.gradient_accumulation_steps and (step + 1) == len(epoch_iterator) ): global_step += 1 @@ -274,8 +286,7 @@ def train(self): logger.info("\n\nTraining completed. \n\n") return TrainOutput(global_step, tr_loss / global_step) - def _training_step( - self, model, inputs: Dict[str, torch.Tensor]) -> float: + def _training_step(self, model, inputs: Dict[str, torch.Tensor]) -> float: for k, v in inputs.items(): inputs[k] = v.to(self.args.device) @@ -313,9 +324,7 @@ def predict(self, test_dataset: Dataset) -> PredictionOutput: test_dataloader = self.get_test_dataloader(test_dataset) return self._prediction_loop(test_dataloader, description="Prediction") - def _prediction_loop( - self, dataloader: DataLoader, description: str - ) -> PredictionOutput: + def _prediction_loop(self, dataloader: DataLoader, description: str) -> PredictionOutput: """ Prediction/evaluation loop, shared by `evaluate()` and `predict()`. diff --git a/orttraining/orttraining/test/python/perf_log/ort_module_perf_test_tools.py b/orttraining/orttraining/test/python/perf_log/ort_module_perf_test_tools.py index 0bd5e70e3cc3c..b7b619a92e53b 100644 --- a/orttraining/orttraining/test/python/perf_log/ort_module_perf_test_tools.py +++ b/orttraining/orttraining/test/python/perf_log/ort_module_perf_test_tools.py @@ -8,12 +8,14 @@ import argparse from datetime import datetime + def get_repo_commit(repo_path): repo = git.Repo(repo_path, search_parent_directories=True) sha = repo.head.object.hexsha short_sha = repo.git.rev_parse(sha, short=4) return short_sha + create_table_script = "CREATE TABLE perf_test_training_ort_module_data (\ id int(11) NOT NULL AUTO_INCREMENT,\ Model varchar(64) COLLATE utf8_bin DEFAULT NULL,\ @@ -103,10 +105,10 @@ def get_repo_commit(repo_path): # Obtain connection string information from the portal def connect_to_perf_dashboard_db(mysql_server_name, power_bi_user_name, password, database): config = { - 'host': mysql_server_name, - 'user': power_bi_user_name, - 'password': password, - 'database': database, + "host": mysql_server_name, + "user": power_bi_user_name, + "password": password, + "database": database, } try: @@ -121,76 +123,87 @@ def connect_to_perf_dashboard_db(mysql_server_name, power_bi_user_name, password else: print(err) -def log_perf_metrics(perf_metrics, - mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, perf_repo_path=None): + +def log_perf_metrics( + perf_metrics, mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, perf_repo_path=None +): if perf_repo_path: - perf_metrics['CommitId'] = get_repo_commit(perf_repo_path) + perf_metrics["CommitId"] = get_repo_commit(perf_repo_path) else: - perf_metrics['CommitId'] = get_repo_commit(os.path.realpath(__file__)) + perf_metrics["CommitId"] = get_repo_commit(os.path.realpath(__file__)) connect_and_insert_perf_metrics( - mysql_server_name, - power_bi_user_name, - power_bi_password, - power_bi_database, - perf_metrics) + mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, perf_metrics + ) + -required_attributes_for_perf_metrics = ['model_name', 'optimizer', 'batch_size', 'epochs', 'train_steps', - 'sequence_length'] +required_attributes_for_perf_metrics = [ + "model_name", + "optimizer", + "batch_size", + "epochs", + "train_steps", + "sequence_length", +] -def calculate_and_log_perf_metrics(args, start_time, - mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, ort_repo_path=None): + +def calculate_and_log_perf_metrics( + args, start_time, mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, ort_repo_path=None +): completion_time = datetime.datetime.now() perf_metrics_duration = completion_time - start_time for attribute in required_attributes_for_perf_metrics: if not hasattr(args, attribute): - raise ValueError('args does not contain all attributes needed to calculate perf metrics. \ - Please prepare perf_metrics and call log_perf_metrics instead') - - perf_metrics = {} - perf_metrics['Model'] = args.model_name - perf_metrics['BatchId'] = 'NA' - perf_metrics['ModelName'] = args.model_name - perf_metrics['DisplayName'] = args.model_name - perf_metrics['UseMixedPrecision'] = args.fp16 if hasattr(args, 'fp16') else False - perf_metrics['UseAutoCast'] = args.use_auto_cast if hasattr(args, 'use_auto_cast') else False - perf_metrics['UseDeepSpeed'] = args.use_deep_speed if hasattr(args, 'use_deep_speed') else False - perf_metrics['Optimizer'] = args.optimizer - perf_metrics['BatchSize'] = args.batch_size - perf_metrics['SeqLen'] = args.sequence_length - perf_metrics['PredictionsPerSeq'] = args.prediction_per_seq if hasattr(args, 'prediction_per_seq') else 0 - perf_metrics['NumOfBatches'] = args.epochs * args.train_steps - perf_metrics['WeightUpdateSteps'] = args.epochs * args.train_steps - perf_metrics['Round'] = 0 # NA - perf_metrics['GradAccSteps'] = args.gradient_accumulation_steps - - perf_metrics['AvgTimePerBatch'] = \ - perf_metrics_duration.microseconds / args.train_steps - - perf_metrics['Throughput'] = \ - args.batch_size * args.train_steps / perf_metrics_duration.seconds - - perf_metrics['StabilizedThroughput'] = 0 # TODO - perf_metrics['EndToEndThroughput'] = 0 # TODO - perf_metrics['TotalTime'] = perf_metrics_duration.seconds - - perf_metrics['AvgCPU'] = 0 # TODO - perf_metrics['Memory'] = 0 # TODO - perf_metrics['RunConfig'] = 'na' - perf_metrics['Time'] = completion_time.strftime("%Y-%m-%d %H:%M:%S") - - log_perf_metrics(perf_metrics, mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, - ort_repo_path) + raise ValueError( + "args does not contain all attributes needed to calculate perf metrics. \ + Please prepare perf_metrics and call log_perf_metrics instead" + ) + + perf_metrics = {} + perf_metrics["Model"] = args.model_name + perf_metrics["BatchId"] = "NA" + perf_metrics["ModelName"] = args.model_name + perf_metrics["DisplayName"] = args.model_name + perf_metrics["UseMixedPrecision"] = args.fp16 if hasattr(args, "fp16") else False + perf_metrics["UseAutoCast"] = args.use_auto_cast if hasattr(args, "use_auto_cast") else False + perf_metrics["UseDeepSpeed"] = args.use_deep_speed if hasattr(args, "use_deep_speed") else False + perf_metrics["Optimizer"] = args.optimizer + perf_metrics["BatchSize"] = args.batch_size + perf_metrics["SeqLen"] = args.sequence_length + perf_metrics["PredictionsPerSeq"] = args.prediction_per_seq if hasattr(args, "prediction_per_seq") else 0 + perf_metrics["NumOfBatches"] = args.epochs * args.train_steps + perf_metrics["WeightUpdateSteps"] = args.epochs * args.train_steps + perf_metrics["Round"] = 0 # NA + perf_metrics["GradAccSteps"] = args.gradient_accumulation_steps + + perf_metrics["AvgTimePerBatch"] = perf_metrics_duration.microseconds / args.train_steps + + perf_metrics["Throughput"] = args.batch_size * args.train_steps / perf_metrics_duration.seconds + + perf_metrics["StabilizedThroughput"] = 0 # TODO + perf_metrics["EndToEndThroughput"] = 0 # TODO + perf_metrics["TotalTime"] = perf_metrics_duration.seconds + + perf_metrics["AvgCPU"] = 0 # TODO + perf_metrics["Memory"] = 0 # TODO + perf_metrics["RunConfig"] = "na" + perf_metrics["Time"] = completion_time.strftime("%Y-%m-%d %H:%M:%S") + + log_perf_metrics( + perf_metrics, mysql_server_name, power_bi_user_name, power_bi_password, power_bi_database, ort_repo_path + ) + def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--mysql_server_name', help='Perf dashboard mysql server name') - parser.add_argument('--power_bi_user_name', help='Power BI user name') - parser.add_argument('--password', help='password', default=None) - parser.add_argument('--database', help='The dashboard database') + parser.add_argument("--mysql_server_name", help="Perf dashboard mysql server name") + parser.add_argument("--power_bi_user_name", help="Power BI user name") + parser.add_argument("--password", help="password", default=None) + parser.add_argument("--database", help="The dashboard database") return parser.parse_args() + def connect_and_insert_perf_metrics(mysql_server_name, power_bi_user_name, password, database, perf_metrics): conn = connect_to_perf_dashboard_db(mysql_server_name, power_bi_user_name, password, database) # https://dev.mysql.com/doc/connector-python/en/connector-python-api-mysqlcursor-execute.html @@ -200,7 +213,8 @@ def connect_and_insert_perf_metrics(mysql_server_name, power_bi_user_name, passw conn.close() print("perf_metrics logged into power-bi database.") -if __name__ == '__main__': + +if __name__ == "__main__": args = parse_arguments() conn = connect_to_perf_dashboard_db(args.mysql_server_name, args.power_bi_user_name, args.password, args.database) conn.cursor().execute(create_table_script) diff --git a/orttraining/orttraining/test/python/utils_multiple_choice.py b/orttraining/orttraining/test/python/utils_multiple_choice.py index 381761998c610..562ecbf8c496d 100644 --- a/orttraining/orttraining/test/python/utils_multiple_choice.py +++ b/orttraining/orttraining/test/python/utils_multiple_choice.py @@ -61,6 +61,7 @@ class Split(Enum): dev = "dev" test = "test" + class DataProcessor: """Base class for data converters for multiple choice data sets.""" @@ -80,6 +81,7 @@ def get_labels(self): """Gets the list of labels for this data set.""" raise NotImplementedError() + class MultipleChoiceDataset(Dataset): """ This will be superseded by a framework-agnostic approach @@ -102,7 +104,12 @@ def __init__( cached_features_file = os.path.join( data_dir, - "cached_{}_{}_{}_{}".format(mode.value, tokenizer.__class__.__name__, str(max_seq_length), task,), + "cached_{}_{}_{}_{}".format( + mode.value, + tokenizer.__class__.__name__, + str(max_seq_length), + task, + ), ) # Make sure only the first process in distributed training processes the dataset, @@ -142,6 +149,7 @@ def __len__(self): def __getitem__(self, i) -> InputFeatures: return self.features[i] + class SwagProcessor(DataProcessor): """Processor for the SWAG data set.""" @@ -192,6 +200,7 @@ def _create_examples(self, lines: List[List[str]], type: str): return examples + def convert_examples_to_features( examples: List[InputExample], label_list: List[str], diff --git a/orttraining/pytorch_frontend_examples/mnist_training.py b/orttraining/pytorch_frontend_examples/mnist_training.py index e52b7ceb2842b..bea27011b3118 100644 --- a/orttraining/pytorch_frontend_examples/mnist_training.py +++ b/orttraining/pytorch_frontend_examples/mnist_training.py @@ -16,11 +16,13 @@ from onnxruntime.capi.ort_trainer import IODescription, ModelDescription, ORTTrainer from mpi4py import MPI + try: from onnxruntime.capi._pybind_state import set_cuda_device_id except ImportError: pass + class NeuralNet(nn.Module): def __init__(self, input_size, hidden_size, num_classes): super(NeuralNet, self).__init__() @@ -34,9 +36,11 @@ def forward(self, x): out = self.fc2(out) return out + def my_loss(x, target): return F.nll_loss(F.log_softmax(x, dim=1), target) + def train_with_trainer(args, trainer, device, train_loader, epoch): for batch_idx, (data, target) in enumerate(train_loader): data, target = data.to(device), target.to(device) @@ -47,9 +51,16 @@ def train_with_trainer(args, trainer, device, train_loader, epoch): # Since the output corresponds to [loss_desc, probability_desc], the first value is taken as loss. if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss[0])) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss[0], + ) + ) + # TODO: comple this once ORT training can do evaluation. def test_with_trainer(args, trainer, device, test_loader): @@ -59,66 +70,90 @@ def test_with_trainer(args, trainer, device, test_loader): for data, target in test_loader: data, target = data.to(device), target.to(device) data = data.reshape(data.shape[0], -1) - output = F.log_softmax(trainer.eval_step(data, fetches=['probability']), dim=1) - test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss - pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability + output = F.log_softmax(trainer.eval_step(data, fetches=["probability"]), dim=1) + test_loss += F.nll_loss(output, target, reduction="sum").item() # sum up batch loss + pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) + def mnist_model_description(): - input_desc = IODescription('input1', ['batch', 784], torch.float32) - label_desc = IODescription('label', ['batch', ], torch.int64, num_classes=10) - loss_desc = IODescription('loss', [], torch.float32) - probability_desc = IODescription('probability', ['batch', 10], torch.float32) + input_desc = IODescription("input1", ["batch", 784], torch.float32) + label_desc = IODescription( + "label", + [ + "batch", + ], + torch.int64, + num_classes=10, + ) + loss_desc = IODescription("loss", [], torch.float32) + probability_desc = IODescription("probability", ["batch", 10], torch.float32) return ModelDescription([input_desc, label_desc], [loss_desc, probability_desc]) -def main(): -#Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') +def main(): + # Training settings + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--batch-size", type=int, default=64, metavar="N", help="input batch size for training (default: 64)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=10, metavar="N", help="number of epochs to train (default: 10)") + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) args = parser.parse_args() use_cuda = not args.no_cuda and torch.cuda.is_available() torch.manual_seed(args.seed) - kwargs = {'num_workers': 0, 'pin_memory': True} + kwargs = {"num_workers": 0, "pin_memory": True} train_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True, **kwargs) + datasets.MNIST( + "../data", + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.batch_size, + shuffle=True, + **kwargs + ) test_loader = torch.utils.data.DataLoader( - datasets.MNIST('../data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True, **kwargs) - + datasets.MNIST( + "../data", + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + **kwargs + ) comm = MPI.COMM_WORLD - args.local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK']) if ('OMPI_COMM_WORLD_LOCAL_RANK' in os.environ) else 0 - args.world_rank = int(os.environ['OMPI_COMM_WORLD_RANK']) if ('OMPI_COMM_WORLD_RANK' in os.environ) else 0 - args.world_size=comm.Get_size() + args.local_rank = ( + int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"]) if ("OMPI_COMM_WORLD_LOCAL_RANK" in os.environ) else 0 + ) + args.world_rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) if ("OMPI_COMM_WORLD_RANK" in os.environ) else 0 + args.world_size = comm.Get_size() if use_cuda: torch.cuda.set_device(args.local_rank) device = torch.device("cuda", args.local_rank) @@ -134,25 +169,34 @@ def main(): model_desc = mnist_model_description() # use log_interval as gradient accumulate steps - trainer = ORTTrainer(model, - my_loss, - model_desc, - "SGDOptimizer", - None, - IODescription('Learning_Rate', [1,], torch.float32), - device, - 1, - args.world_rank, - args.world_size, - use_mixed_precision=False, - allreduce_post_accumulation=True) - print('\nBuild ort model done.') + trainer = ORTTrainer( + model, + my_loss, + model_desc, + "SGDOptimizer", + None, + IODescription( + "Learning_Rate", + [ + 1, + ], + torch.float32, + ), + device, + 1, + args.world_rank, + args.world_size, + use_mixed_precision=False, + allreduce_post_accumulation=True, + ) + print("\nBuild ort model done.") for epoch in range(1, args.epochs + 1): train_with_trainer(args, trainer, device, train_loader, epoch) import pdb + test_with_trainer(args, trainer, device, test_loader) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/orttraining/tools/amdgpu/script/rocprof.py b/orttraining/tools/amdgpu/script/rocprof.py index 90710effe2087..dc91d13606fb0 100644 --- a/orttraining/tools/amdgpu/script/rocprof.py +++ b/orttraining/tools/amdgpu/script/rocprof.py @@ -4,36 +4,39 @@ import csv parser = argparse.ArgumentParser() -parser.add_argument('--input', type=str) +parser.add_argument("--input", type=str) args = parser.parse_args() + def get_gpu_lines(path): lines = [] - with open(path, newline='') as f: - reader = csv.reader(f, delimiter=',') + with open(path, newline="") as f: + reader = csv.reader(f, delimiter=",") for row in reader: - if row[2].find('TotalDurationNs') < 0 : + if row[2].find("TotalDurationNs") < 0: lines.append(row) return lines + activities = [ - ('nccl', lambda x : x.find('nccl') >= 0), - ('gemm', lambda x : x.find('Cijk_') >= 0), - ('memcpy', lambda x : x.find('CUDA mem') >= 0), - ('adam', lambda x : x.lower().find('adam') >= 0), - ('lamb', lambda x : x.lower().find('lamb') >= 0 or x.lower().find('multi_tensor_apply') >= 0), - ('dropout', lambda x : x.lower().find('dropout') >= 0 or x.find('curand') >= 0), - ('layernorm', lambda x : x.find('LayerNorm') >= 0 or x.find('cuCompute') >= 0), - ('reduce', lambda x : x.find('reduce') >= 0), - ('softmax', lambda x : x.lower().find('softmax') >= 0), - ('transpose', lambda x : x.lower().find('transpose') >= 0), - ('element-wise', lambda x : x.lower().find('elementwise') >= 0 or x.find('DivGrad') >= 0), - ('jit', lambda x : x.startswith('kernel_')), - ('misc', lambda x : True), + ("nccl", lambda x: x.find("nccl") >= 0), + ("gemm", lambda x: x.find("Cijk_") >= 0), + ("memcpy", lambda x: x.find("CUDA mem") >= 0), + ("adam", lambda x: x.lower().find("adam") >= 0), + ("lamb", lambda x: x.lower().find("lamb") >= 0 or x.lower().find("multi_tensor_apply") >= 0), + ("dropout", lambda x: x.lower().find("dropout") >= 0 or x.find("curand") >= 0), + ("layernorm", lambda x: x.find("LayerNorm") >= 0 or x.find("cuCompute") >= 0), + ("reduce", lambda x: x.find("reduce") >= 0), + ("softmax", lambda x: x.lower().find("softmax") >= 0), + ("transpose", lambda x: x.lower().find("transpose") >= 0), + ("element-wise", lambda x: x.lower().find("elementwise") >= 0 or x.find("DivGrad") >= 0), + ("jit", lambda x: x.startswith("kernel_")), + ("misc", lambda x: True), ] + def group_gpu_activity(lines): - groups = { name : [] for name,_ in activities } + groups = {name: [] for name, _ in activities} for line in lines: for name, check in activities: if check(line[0]): @@ -41,26 +44,41 @@ def group_gpu_activity(lines): break return groups + def get_seconds(time): - return float(time.replace('us','')) / (1000.0 * 1000.0 * 1000.0) + return float(time.replace("us", "")) / (1000.0 * 1000.0 * 1000.0) + def gpu_percent_time(activities): - return sum([float(a[4].replace('%','')) for a in activities]) + return sum([float(a[4].replace("%", "")) for a in activities]) + def gpu_absolute_time(activities): return sum([get_seconds(a[2]) for a in activities]) + def gpu_kernel_calls(activities): return sum([int(a[1]) for a in activities]) + lines = get_gpu_lines(args.input) groups = group_gpu_activity(lines) for name in groups: activities = groups[name] - print('{}: N={}, calls={}, absolute={:.3f}s, percent={:.2f}%'.format(name, len(activities), gpu_kernel_calls(activities), gpu_absolute_time(activities), gpu_percent_time(activities))) + print( + "{}: N={}, calls={}, absolute={:.3f}s, percent={:.2f}%".format( + name, + len(activities), + gpu_kernel_calls(activities), + gpu_absolute_time(activities), + gpu_percent_time(activities), + ) + ) total = [item for name in groups for item in groups[name]] -print('Total: N={}, calls={}, absolute={:.3f}s, percent={:.2f}%'.format(len(total), gpu_kernel_calls(total), gpu_absolute_time(total), gpu_percent_time(total))) - - +print( + "Total: N={}, calls={}, absolute={:.3f}s, percent={:.2f}%".format( + len(total), gpu_kernel_calls(total), gpu_absolute_time(total), gpu_percent_time(total) + ) +) diff --git a/orttraining/tools/ci_test/compare_huggingface.py b/orttraining/tools/ci_test/compare_huggingface.py index c70ede9a26752..c484cfb56adcb 100755 --- a/orttraining/tools/ci_test/compare_huggingface.py +++ b/orttraining/tools/ci_test/compare_huggingface.py @@ -6,37 +6,47 @@ expect = sys.argv[2] with open(actual) as file_actual: - json_actual = json.loads(file_actual.read()) + json_actual = json.loads(file_actual.read()) with open(expect) as file_expect: - json_expect = json.loads(file_expect.read()) + json_expect = json.loads(file_expect.read()) + def almost_equal(x, y, threshold=0.05): - return abs(x-y) < threshold + return abs(x - y) < threshold + # loss curve tail match loss_tail_length = 4 loss_tail_matches = collections.deque(maxlen=loss_tail_length) -logged_steps = len(json_actual['steps']) -for i in range(logged_steps): - step_actual = json_actual['steps'][i] - step_expect = json_expect['steps'][i] - - is_match = step_actual['step'] == step_expect['step'] - is_match = is_match if almost_equal(step_actual['loss'], step_expect['loss']) else False - loss_tail_matches.append(is_match) - - print('step {} loss actual {:.6f} expected {:.6f} match {}'.format( - step_actual['step'], step_actual['loss'], step_expect['loss'], - is_match if logged_steps - i <= loss_tail_length else 'n/a')) - -success = all(loss_tail_matches) +logged_steps = len(json_actual["steps"]) +for i in range(logged_steps): + step_actual = json_actual["steps"][i] + step_expect = json_expect["steps"][i] + + is_match = step_actual["step"] == step_expect["step"] + is_match = is_match if almost_equal(step_actual["loss"], step_expect["loss"]) else False + loss_tail_matches.append(is_match) + + print( + "step {} loss actual {:.6f} expected {:.6f} match {}".format( + step_actual["step"], + step_actual["loss"], + step_expect["loss"], + is_match if logged_steps - i <= loss_tail_length else "n/a", + ) + ) + +success = all(loss_tail_matches) # performance match threshold = 0.95 -is_performant = json_actual['samples_per_second'] >= threshold*json_expect['samples_per_second'] +is_performant = json_actual["samples_per_second"] >= threshold * json_expect["samples_per_second"] success = success if is_performant else False -print('samples_per_second actual {:.3f} expected {:.3f} in-range {}'.format( - json_actual['samples_per_second'], json_expect['samples_per_second'], is_performant)) +print( + "samples_per_second actual {:.3f} expected {:.3f} in-range {}".format( + json_actual["samples_per_second"], json_expect["samples_per_second"], is_performant + ) +) -assert(success) +assert success diff --git a/orttraining/tools/ci_test/compare_results.py b/orttraining/tools/ci_test/compare_results.py index 6a5ed77fa45ea..ba76b9eaf414c 100644 --- a/orttraining/tools/ci_test/compare_results.py +++ b/orttraining/tools/ci_test/compare_results.py @@ -9,59 +9,69 @@ Comparison = collections.namedtuple("Comparison", ["name", "fn"]) + class Comparisons: - @staticmethod - def eq(): - return Comparison( - name="equal to", - fn=(lambda actual, expected: actual == expected)) - - @staticmethod - def float_le(tolerance=None): - actual_tolerance = 0.0 if tolerance is None else tolerance - return Comparison( - name="less than or equal to" + - (" (tolerance: {})".format(str(actual_tolerance)) - if tolerance is not None else ""), - fn=(lambda actual, expected: float(actual) <= float(expected) + actual_tolerance)) + @staticmethod + def eq(): + return Comparison(name="equal to", fn=(lambda actual, expected: actual == expected)) + + @staticmethod + def float_le(tolerance=None): + actual_tolerance = 0.0 if tolerance is None else tolerance + return Comparison( + name="less than or equal to" + + (" (tolerance: {})".format(str(actual_tolerance)) if tolerance is not None else ""), + fn=(lambda actual, expected: float(actual) <= float(expected) + actual_tolerance), + ) + def _printf_stderr(fmt, *args): - print(fmt.format(*args), file=sys.stderr) + print(fmt.format(*args), file=sys.stderr) + def _read_results_file(results_path): - with open(results_path) as results_file: - csv_reader = csv.DictReader(results_file) - return [row for row in csv_reader] + with open(results_path) as results_file: + csv_reader = csv.DictReader(results_file) + return [row for row in csv_reader] + def _compare_results(expected_results, actual_results, field_comparisons): - if len(field_comparisons) == 0: - return True - - if len(expected_results) != len(actual_results): - _printf_stderr("Expected and actual result sets have different sizes.") - return False - - mismatch_detected = False - for row_idx, (expected_row, actual_row) in enumerate(zip(expected_results, actual_results)): - for field_name, comparison in field_comparisons.items(): - actual, expected = actual_row[field_name], expected_row[field_name] - if not comparison.fn(actual, expected): - _printf_stderr("Comparison '{}' failed for {} in row {}, actual: {}, expected: {}", - comparison.name, field_name, row_idx, actual, expected) - mismatch_detected = True - return not mismatch_detected + if len(field_comparisons) == 0: + return True + + if len(expected_results) != len(actual_results): + _printf_stderr("Expected and actual result sets have different sizes.") + return False + + mismatch_detected = False + for row_idx, (expected_row, actual_row) in enumerate(zip(expected_results, actual_results)): + for field_name, comparison in field_comparisons.items(): + actual, expected = actual_row[field_name], expected_row[field_name] + if not comparison.fn(actual, expected): + _printf_stderr( + "Comparison '{}' failed for {} in row {}, actual: {}, expected: {}", + comparison.name, + field_name, + row_idx, + actual, + expected, + ) + mismatch_detected = True + return not mismatch_detected + def compare_results_files(expected_results_path: str, actual_results_path: str, field_comparisons: dict): - expected_results = _read_results_file(expected_results_path) - actual_results = _read_results_file(actual_results_path) + expected_results = _read_results_file(expected_results_path) + actual_results = _read_results_file(actual_results_path) - comparison_result = _compare_results( - expected_results, actual_results, field_comparisons) + comparison_result = _compare_results(expected_results, actual_results, field_comparisons) - if not comparison_result: - with open(expected_results_path) as expected_results_file, \ - open(actual_results_path) as actual_results_file: - _printf_stderr("===== Expected results =====\n{}\n===== Actual results =====\n{}", - expected_results_file.read(), actual_results_file.read()) + if not comparison_result: + with open(expected_results_path) as expected_results_file, open(actual_results_path) as actual_results_file: + _printf_stderr( + "===== Expected results =====\n{}\n===== Actual results =====\n{}", + expected_results_file.read(), + actual_results_file.read(), + ) - return comparison_result + return comparison_result diff --git a/orttraining/tools/ci_test/download_azure_blob_archive.py b/orttraining/tools/ci_test/download_azure_blob_archive.py index 564dcc8007ce4..32743f63281df 100755 --- a/orttraining/tools/ci_test/download_azure_blob_archive.py +++ b/orttraining/tools/ci_test/download_azure_blob_archive.py @@ -19,49 +19,51 @@ from util import get_azcopy # noqa: E402 + def _download(azcopy_path, url, local_path): - subprocess.run([azcopy_path, "cp", "--log-level", "NONE", url, local_path], check=True) + subprocess.run([azcopy_path, "cp", "--log-level", "NONE", url, local_path], check=True) + def _get_sha256_digest(file_path): - alg = hashlib.sha256() - read_bytes_length = 8192 + alg = hashlib.sha256() + read_bytes_length = 8192 + + with open(file_path, mode="rb") as archive: + while True: + read_bytes = archive.read(read_bytes_length) + if len(read_bytes) == 0: + break + alg.update(read_bytes) - with open(file_path, mode="rb") as archive: - while True: - read_bytes = archive.read(read_bytes_length) - if len(read_bytes) == 0: break - alg.update(read_bytes) + return alg.hexdigest() - return alg.hexdigest() def _check_file_sha256_digest(path, expected_digest): - actual_digest = _get_sha256_digest(path) - match = actual_digest.lower() == expected_digest.lower() - if not match: - raise RuntimeError( - "SHA256 digest mismatch, expected: {}, actual: {}".format( - expected_digest.lower(), actual_digest.lower())) + actual_digest = _get_sha256_digest(path) + match = actual_digest.lower() == expected_digest.lower() + if not match: + raise RuntimeError( + "SHA256 digest mismatch, expected: {}, actual: {}".format(expected_digest.lower(), actual_digest.lower()) + ) + def main(): - parser = argparse.ArgumentParser( - description="Downloads an Azure blob archive.") - parser.add_argument("--azure_blob_url", required=True, - help="The Azure blob URL.") - parser.add_argument("--target_dir", required=True, - help="The destination directory.") - parser.add_argument("--archive_sha256_digest", - help="The SHA256 digest of the archive. Verified if provided.") - args = parser.parse_args() - - with tempfile.TemporaryDirectory() as temp_dir, get_azcopy() as azcopy_path: - archive_path = os.path.join(temp_dir, "archive.zip") - print("Downloading archive from '{}'...".format(args.azure_blob_url)) - _download(azcopy_path, args.azure_blob_url, archive_path) - if args.archive_sha256_digest: - _check_file_sha256_digest(archive_path, args.archive_sha256_digest) - print("Extracting to '{}'...".format(args.target_dir)) - shutil.unpack_archive(archive_path, args.target_dir) - print("Done.") + parser = argparse.ArgumentParser(description="Downloads an Azure blob archive.") + parser.add_argument("--azure_blob_url", required=True, help="The Azure blob URL.") + parser.add_argument("--target_dir", required=True, help="The destination directory.") + parser.add_argument("--archive_sha256_digest", help="The SHA256 digest of the archive. Verified if provided.") + args = parser.parse_args() + + with tempfile.TemporaryDirectory() as temp_dir, get_azcopy() as azcopy_path: + archive_path = os.path.join(temp_dir, "archive.zip") + print("Downloading archive from '{}'...".format(args.azure_blob_url)) + _download(azcopy_path, args.azure_blob_url, archive_path) + if args.archive_sha256_digest: + _check_file_sha256_digest(archive_path, args.archive_sha256_digest) + print("Extracting to '{}'...".format(args.target_dir)) + shutil.unpack_archive(archive_path, args.target_dir) + print("Done.") + if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/orttraining/tools/ci_test/run_batch_size_test.py b/orttraining/tools/ci_test/run_batch_size_test.py index b735f71bb777a..4a7ec51062914 100755 --- a/orttraining/tools/ci_test/run_batch_size_test.py +++ b/orttraining/tools/ci_test/run_batch_size_test.py @@ -13,71 +13,96 @@ def parse_args(): parser = argparse.ArgumentParser(description="Runs a BERT batch size test.") parser.add_argument("--binary_dir", required=True, help="Path to the ORT binary directory.") parser.add_argument("--model_root", required=True, help="Path to the model root directory.") - parser.add_argument("--gpu_sku", choices=['V100_16G', 'MI100_32G'], default='V100_16G', required=False, - help="GPU model (e.g. V100_16G, MI100_32G).") + parser.add_argument( + "--gpu_sku", + choices=["V100_16G", "MI100_32G"], + default="V100_16G", + required=False, + help="GPU model (e.g. V100_16G, MI100_32G).", + ) return parser.parse_args() def main(): args = parse_args() - Config = collections.namedtuple("Config", ["enable_mixed_precision", - "sequence_length", - "max_batch_size", - "max_predictions_per_seq", - "additional_options"]) + Config = collections.namedtuple( + "Config", + [ + "enable_mixed_precision", + "sequence_length", + "max_batch_size", + "max_predictions_per_seq", + "additional_options", + ], + ) configs = {} - configs['V100_16G'] = [ + configs["V100_16G"] = [ Config(True, 128, 76, 20, ""), Config(True, 512, 11, 80, ""), Config(False, 128, 39, 20, ""), Config(False, 512, 6, 80, ""), - # BertLarge Phase 1 recompute Config(True, 128, 91, 20, "--gelu_recompute"), Config(True, 128, 83, 20, "--attn_dropout_recompute"), Config(True, 128, 344, 20, "--transformer_layer_recompute"), - # BertLarge Phase 2 recompute Config(True, 512, 12, 80, "--gelu_recompute"), Config(True, 512, 14, 80, "--attn_dropout_recompute"), Config(True, 512, 50, 80, "--transformer_layer_recompute"), ] - configs['MI100_32G'] = [ + configs["MI100_32G"] = [ Config(True, 128, 192, 20, ""), Config(True, 512, 26, 80, ""), Config(False, 128, 108, 20, ""), Config(False, 512, 16, 80, ""), ] - + # run BERT training for config in configs[args.gpu_sku]: - print("##### testing name - {}-{} #####".format("fp16" if config.enable_mixed_precision else "fp32", - config.sequence_length)) + print( + "##### testing name - {}-{} #####".format( + "fp16" if config.enable_mixed_precision else "fp32", config.sequence_length + ) + ) cmds = [ os.path.join(args.binary_dir, "onnxruntime_training_bert"), - "--model_name", os.path.join( + "--model_name", + os.path.join( args.model_root, - "nv/bert-large/bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12"), - "--train_batch_size", str(config.max_batch_size), - "--mode", "perf", - "--max_seq_length", str(config.sequence_length), - "--num_train_steps", "10", - "--display_loss_steps", "5", - "--optimizer", "adam", - "--learning_rate", "5e-4", - "--warmup_ratio", "0.1", - "--warmup_mode", "Linear", - "--gradient_accumulation_steps", "1", + "nv/bert-large/bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12", + ), + "--train_batch_size", + str(config.max_batch_size), + "--mode", + "perf", + "--max_seq_length", + str(config.sequence_length), + "--num_train_steps", + "10", + "--display_loss_steps", + "5", + "--optimizer", + "adam", + "--learning_rate", + "5e-4", + "--warmup_ratio", + "0.1", + "--warmup_mode", + "Linear", + "--gradient_accumulation_steps", + "1", "--max_predictions_per_seq=20", "--allreduce_in_fp16", - "--lambda", "0", + "--lambda", + "0", "--use_nccl", - "--seed", "42", + "--seed", + "42", "--enable_grad_norm_clip=false", - config.additional_options + config.additional_options, ] if config.enable_mixed_precision: diff --git a/orttraining/tools/ci_test/run_bert_perf_test.py b/orttraining/tools/ci_test/run_bert_perf_test.py index 5d673a139dc7c..8f6a59c1fd883 100644 --- a/orttraining/tools/ci_test/run_bert_perf_test.py +++ b/orttraining/tools/ci_test/run_bert_perf_test.py @@ -11,76 +11,103 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) + def parse_args(): - parser = argparse.ArgumentParser(description="Runs BERT performance tests.") - parser.add_argument("--binary_dir", required=True, - help="Path to the ORT binary directory.") - parser.add_argument("--training_data_root", required=True, - help="Path to the training data root directory.") - parser.add_argument("--model_root", required=True, - help="Path to the model root directory.") - parser.add_argument("--gpu_sku", choices=['V100_16G', 'MI100_32G'], default='V100_16G', required=False, - help="GPU model (e.g. V100_16G, MI100_32G).") - return parser.parse_args() + parser = argparse.ArgumentParser(description="Runs BERT performance tests.") + parser.add_argument("--binary_dir", required=True, help="Path to the ORT binary directory.") + parser.add_argument("--training_data_root", required=True, help="Path to the training data root directory.") + parser.add_argument("--model_root", required=True, help="Path to the model root directory.") + parser.add_argument( + "--gpu_sku", + choices=["V100_16G", "MI100_32G"], + default="V100_16G", + required=False, + help="GPU model (e.g. V100_16G, MI100_32G).", + ) + return parser.parse_args() + # using the same params from "GitHub Master Merge Schedule" in OneNotes def main(): args = parse_args() - Config = namedtuple('Config', ['use_mixed_precision', 'max_seq_length', 'batch_size', 'max_predictions_per_seq', 'expected_perf']) + Config = namedtuple( + "Config", ["use_mixed_precision", "max_seq_length", "batch_size", "max_predictions_per_seq", "expected_perf"] + ) configs = {} - configs['V100_16G'] = [ + configs["V100_16G"] = [ Config(True, 128, 76, 20, -1.0), Config(True, 512, 11, 80, -1.0), Config(False, 128, 39, 20, -1.0), - Config(False, 512, 6, 80, -1.0) + Config(False, 512, 6, 80, -1.0), ] - configs['MI100_32G'] = [ + configs["MI100_32G"] = [ Config(True, 128, 128, 20, 240), ] # run BERT training for c in configs[args.gpu_sku]: - model = 'bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12' - precision_prefix = ('fp16' if c.use_mixed_precision else 'fp32') - print("######## testing name - " + ('fp16-' if c.use_mixed_precision else 'fp32-') + str(c.max_seq_length) + " ##############") + model = "bert-large-uncased_L_24_H_1024_A_16_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12" + precision_prefix = "fp16" if c.use_mixed_precision else "fp32" + print( + "######## testing name - " + + ("fp16-" if c.use_mixed_precision else "fp32-") + + str(c.max_seq_length) + + " ##############" + ) cmds = [ os.path.join(args.binary_dir, "onnxruntime_training_bert"), - "--model_name", os.path.join( - args.model_root, "nv/bert-large/{}".format(model)), - "--train_data_dir", os.path.join( - args.training_data_root, str(c.max_seq_length), "books_wiki_en_corpus/train"), - "--test_data_dir", os.path.join( - args.training_data_root, str(c.max_seq_length), "books_wiki_en_corpus/test"), - "--train_batch_size", str(c.batch_size), - "--mode", "train", - "--max_seq_length", str(c.max_seq_length), - "--num_train_steps", "640", - "--display_loss_steps", "5", - "--optimizer", "Lamb", - "--learning_rate", "3e-3", - "--warmup_ratio", "0.2843", - "--warmup_mode", "Poly", - "--gradient_accumulation_steps", "1", - "--max_predictions_per_seq", str(c.max_predictions_per_seq), - "--lambda", "0", + "--model_name", + os.path.join(args.model_root, "nv/bert-large/{}".format(model)), + "--train_data_dir", + os.path.join(args.training_data_root, str(c.max_seq_length), "books_wiki_en_corpus/train"), + "--test_data_dir", + os.path.join(args.training_data_root, str(c.max_seq_length), "books_wiki_en_corpus/test"), + "--train_batch_size", + str(c.batch_size), + "--mode", + "train", + "--max_seq_length", + str(c.max_seq_length), + "--num_train_steps", + "640", + "--display_loss_steps", + "5", + "--optimizer", + "Lamb", + "--learning_rate", + "3e-3", + "--warmup_ratio", + "0.2843", + "--warmup_mode", + "Poly", + "--gradient_accumulation_steps", + "1", + "--max_predictions_per_seq", + str(c.max_predictions_per_seq), + "--lambda", + "0", "--use_nccl", - "--perf_output_dir", os.path.join(SCRIPT_DIR, "results"), + "--perf_output_dir", + os.path.join(SCRIPT_DIR, "results"), ] - if c.use_mixed_precision: + if c.use_mixed_precision: cmds.append("--use_mixed_precision"), cmds.append("--allreduce_in_fp16"), subprocess.run(cmds).check_returncode() if c.expected_perf > 0.0: - json_filename = 'onnxruntime_perf_metrics_{}.onnx_bert_{}_{}_Lamb.json'.format(model, precision_prefix, c.max_seq_length) - with open(os.path.join(SCRIPT_DIR, 'results', json_filename)) as json_file: + json_filename = "onnxruntime_perf_metrics_{}.onnx_bert_{}_{}_Lamb.json".format( + model, precision_prefix, c.max_seq_length + ) + with open(os.path.join(SCRIPT_DIR, "results", json_filename)) as json_file: results = json.load(json_file) - assert(results['EndToEndThroughput'] > 0.98*c.expected_perf) - + assert results["EndToEndThroughput"] > 0.98 * c.expected_perf + return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/orttraining/tools/ci_test/run_convergence_test.py b/orttraining/tools/ci_test/run_convergence_test.py index abaacb5734e65..568e3c4cd9c4c 100755 --- a/orttraining/tools/ci_test/run_convergence_test.py +++ b/orttraining/tools/ci_test/run_convergence_test.py @@ -12,75 +12,96 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) + def parse_args(): - parser = argparse.ArgumentParser(description="Runs a BERT convergence test.") - parser.add_argument("--binary_dir", required=True, - help="Path to the ORT binary directory.") - parser.add_argument("--training_data_root", required=True, - help="Path to the training data root directory.") - parser.add_argument("--model_root", required=True, - help="Path to the model root directory.") - parser.add_argument("--gpu_sku", choices=['V100_16G', 'MI100_32G'], default='V100_16G', required=False, - help="GPU model (e.g. V100_16G, MI100_32G).") - return parser.parse_args() + parser = argparse.ArgumentParser(description="Runs a BERT convergence test.") + parser.add_argument("--binary_dir", required=True, help="Path to the ORT binary directory.") + parser.add_argument("--training_data_root", required=True, help="Path to the training data root directory.") + parser.add_argument("--model_root", required=True, help="Path to the model root directory.") + parser.add_argument( + "--gpu_sku", + choices=["V100_16G", "MI100_32G"], + default="V100_16G", + required=False, + help="GPU model (e.g. V100_16G, MI100_32G).", + ) + return parser.parse_args() + def main(): - args = parse_args() + args = parse_args() + + with tempfile.TemporaryDirectory() as output_dir: + convergence_test_output_path = os.path.join(output_dir, "convergence_test_out.csv") - with tempfile.TemporaryDirectory() as output_dir: - convergence_test_output_path = os.path.join( - output_dir, "convergence_test_out.csv") + # run BERT training + subprocess.run( + [ + os.path.join(args.binary_dir, "onnxruntime_training_bert"), + "--model_name", + os.path.join( + args.model_root, + "nv/bert-base/bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12", + ), + "--train_data_dir", + os.path.join(args.training_data_root, "128/books_wiki_en_corpus/train"), + "--test_data_dir", + os.path.join(args.training_data_root, "128/books_wiki_en_corpus/test"), + "--train_batch_size", + "64", + "--mode", + "train", + "--num_train_steps", + "800", + "--display_loss_steps", + "5", + "--optimizer", + "adam", + "--learning_rate", + "5e-4", + "--warmup_ratio", + "0.1", + "--warmup_mode", + "Linear", + "--gradient_accumulation_steps", + "16", + "--max_predictions_per_seq=20", + "--use_mixed_precision", + "--use_deterministic_compute", + "--allreduce_in_fp16", + "--lambda", + "0", + "--use_nccl", + "--convergence_test_output_file", + convergence_test_output_path, + "--seed", + "42", + "--enable_grad_norm_clip=false", + ] + ).check_returncode() - # run BERT training - subprocess.run([ - os.path.join(args.binary_dir, "onnxruntime_training_bert"), - "--model_name", os.path.join( - args.model_root, "nv/bert-base/bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12"), - "--train_data_dir", os.path.join( - args.training_data_root, "128/books_wiki_en_corpus/train"), - "--test_data_dir", os.path.join( - args.training_data_root, "128/books_wiki_en_corpus/test"), - "--train_batch_size", "64", - "--mode", "train", - "--num_train_steps", "800", - "--display_loss_steps", "5", - "--optimizer", "adam", - "--learning_rate", "5e-4", - "--warmup_ratio", "0.1", - "--warmup_mode", "Linear", - "--gradient_accumulation_steps", "16", - "--max_predictions_per_seq=20", - "--use_mixed_precision", - "--use_deterministic_compute", - "--allreduce_in_fp16", - "--lambda", "0", - "--use_nccl", - "--convergence_test_output_file", convergence_test_output_path, - "--seed", "42", - "--enable_grad_norm_clip=false", - ]).check_returncode() + # reference data + if args.gpu_sku == "MI100_32G": + reference_csv = "bert_base.convergence.baseline.mi100.csv" + elif args.gpu_sku == "V100_16G": + reference_csv = "bert_base.convergence.baseline.csv" + else: + raise ValueError("Unrecognized gpu_sku {}".format(args.gpu_sku)) - # reference data - if args.gpu_sku == 'MI100_32G': - reference_csv = "bert_base.convergence.baseline.mi100.csv" - elif args.gpu_sku == 'V100_16G': - reference_csv = "bert_base.convergence.baseline.csv" - else: - raise ValueError('Unrecognized gpu_sku {}'.format(args.gpu_sku)) + # verify output + comparison_result = compare_results_files( + expected_results_path=os.path.join(SCRIPT_DIR, "results", reference_csv), + actual_results_path=convergence_test_output_path, + field_comparisons={ + "step": Comparisons.eq(), + "total_loss": Comparisons.float_le(1e-3), + "mlm_loss": Comparisons.float_le(1e-3), + "nsp_loss": Comparisons.float_le(1e-3), + }, + ) - # verify output - comparison_result = compare_results_files( - expected_results_path=os.path.join( - SCRIPT_DIR, "results", reference_csv), - actual_results_path=convergence_test_output_path, - field_comparisons={ - "step": Comparisons.eq(), - "total_loss": Comparisons.float_le(1e-3), - "mlm_loss": Comparisons.float_le(1e-3), - "nsp_loss": Comparisons.float_le(1e-3), - }) + return 0 if comparison_result else 1 - return 0 if comparison_result else 1 if __name__ == "__main__": - sys.exit(main()) + sys.exit(main()) diff --git a/orttraining/tools/ci_test/run_gpt2_perf_test.py b/orttraining/tools/ci_test/run_gpt2_perf_test.py index 3e39ffd9a6c32..8c0ac1953feed 100644 --- a/orttraining/tools/ci_test/run_gpt2_perf_test.py +++ b/orttraining/tools/ci_test/run_gpt2_perf_test.py @@ -10,51 +10,62 @@ SCRIPT_DIR = os.path.realpath(os.path.dirname(__file__)) + def parse_args(): - parser = argparse.ArgumentParser(description="Runs GPT-2 performance tests.") - parser.add_argument("--binary_dir", required=True, - help="Path to the ORT binary directory.") - parser.add_argument("--training_data_root", required=True, - help="Path to the training data root directory.") - parser.add_argument("--model_root", required=True, - help="Path to the model root directory.") - return parser.parse_args() + parser = argparse.ArgumentParser(description="Runs GPT-2 performance tests.") + parser.add_argument("--binary_dir", required=True, help="Path to the ORT binary directory.") + parser.add_argument("--training_data_root", required=True, help="Path to the training data root directory.") + parser.add_argument("--model_root", required=True, help="Path to the model root directory.") + return parser.parse_args() + # TODO - review to finalize params def main(): args = parse_args() - Config = namedtuple('Config', ['use_mixed_precision', 'max_seq_length', 'batch_size']) - configs = [ - Config(True, 1024, 1), - Config(False, 1024, 1) - ] + Config = namedtuple("Config", ["use_mixed_precision", "max_seq_length", "batch_size"]) + configs = [Config(True, 1024, 1), Config(False, 1024, 1)] # run GPT-2 training for c in configs: - print("######## testing name - " + ('fp16-' if c.use_mixed_precision else 'fp32-') + str(c.max_seq_length) + " ##############") + print( + "######## testing name - " + + ("fp16-" if c.use_mixed_precision else "fp32-") + + str(c.max_seq_length) + + " ##############" + ) cmds = [ os.path.join(args.binary_dir, "onnxruntime_training_gpt2"), - "--model_name", os.path.join( - args.model_root, "megatron-gpt2_hidden-size-1024_num-layers-24_vocab-size-50257_num-attention-heads-16_max-position-embeddings-1024_optimized_opset12"), - "--train_data_dir", os.path.join( - args.training_data_root, "train"), - "--test_data_dir", os.path.join( - args.training_data_root, "test"), - "--train_batch_size", str(c.batch_size), - "--mode", "train", - "--max_seq_length", str(c.max_seq_length), - "--num_train_steps", "640", - "--gradient_accumulation_steps", "1", - "--perf_output_dir", os.path.join(SCRIPT_DIR, "results"), + "--model_name", + os.path.join( + args.model_root, + "megatron-gpt2_hidden-size-1024_num-layers-24_vocab-size-50257_num-attention-heads-16_max-position-embeddings-1024_optimized_opset12", + ), + "--train_data_dir", + os.path.join(args.training_data_root, "train"), + "--test_data_dir", + os.path.join(args.training_data_root, "test"), + "--train_batch_size", + str(c.batch_size), + "--mode", + "train", + "--max_seq_length", + str(c.max_seq_length), + "--num_train_steps", + "640", + "--gradient_accumulation_steps", + "1", + "--perf_output_dir", + os.path.join(SCRIPT_DIR, "results"), ] - if c.use_mixed_precision: + if c.use_mixed_precision: cmds.append("--use_mixed_precision"), subprocess.run(cmds).check_returncode() return 0 + if __name__ == "__main__": sys.exit(main()) diff --git a/orttraining/tools/scripts/experiment.py b/orttraining/tools/scripts/experiment.py index 5292b021ddec5..0e3e2ceead465 100644 --- a/orttraining/tools/scripts/experiment.py +++ b/orttraining/tools/scripts/experiment.py @@ -16,27 +16,80 @@ from azureml.core.runconfig import MpiConfiguration, RunConfiguration parser = argparse.ArgumentParser() -parser.add_argument('--subscription', type=str, default='ea482afa-3a32-437c-aa10-7de928a9e793') # AI Platform GPU - MLPerf -parser.add_argument('--resource_group', type=str, default='onnx_training', help='Azure resource group containing the AzureML Workspace') -parser.add_argument('--workspace', type=str, default='ort_training_dev', help='AzureML Workspace to run the Experiment in') -parser.add_argument('--compute_target', type=str, default='onnx-training', help='AzureML Compute target to run the Experiment on') -parser.add_argument('--experiment', type=str, default='BERT-ONNX', help='Name of the AzureML Experiment') -parser.add_argument('--tags', type=str, default=None, help='Tags to be added to the submitted run (--tag1=value1 --tag2=value2 --tag3=value3)') - -parser.add_argument('--datastore', type=str, default='bert_premium', help='AzureML Datastore to be mounted into the Experiment') -parser.add_argument('--train_dir', type=str, default='book/train', help='Path in the AzureML Datastore containing the train files') -parser.add_argument('--test_dir', type=str, default='book/test', help='Path in the AzureML Datastore containing the test files') -parser.add_argument('--train_dir2', type=str, default=None, help='Path in the AzureML Datastore containing the train files for phase 2') -parser.add_argument('--test_dir2', type=str, default=None, help='Path in the AzureML Datastore containing the test files for phase 2') - -parser.add_argument('--container', type=str, default='onnxtraining.azurecr.io/azureml/bert:latest-openmpi4.0.0-cuda10.1-cudnn7-ubuntu16.04', help='Docker container to use to run the Experiment') -parser.add_argument('--container_registry_resource_group', type=str, default='onnx_training', help='Azure resource group containing the Azure Container Registry (if not public)') - -parser.add_argument('--node_count', type=int, default=1, help='Number of nodes to use for the Experiment. If greater than 1, an MPI distributed job will be run.') -parser.add_argument('--gpu_count', type=int, default=1, help='Number of GPUs to use per node. If greater than 1, an MPI distributed job will be run.') - -parser.add_argument('--model_name', type=str, default='bert_L-24_H-1024_A-16_V_30528_optimized_layer_norm', help='Model to be trained (must exist in the AzureML Datastore)') -parser.add_argument('--script_params', type=str, default='', help='Training script parameters (--param1=value1 --param2=value2 --param3=value3)') +parser.add_argument( + "--subscription", type=str, default="ea482afa-3a32-437c-aa10-7de928a9e793" +) # AI Platform GPU - MLPerf +parser.add_argument( + "--resource_group", type=str, default="onnx_training", help="Azure resource group containing the AzureML Workspace" +) +parser.add_argument( + "--workspace", type=str, default="ort_training_dev", help="AzureML Workspace to run the Experiment in" +) +parser.add_argument( + "--compute_target", type=str, default="onnx-training", help="AzureML Compute target to run the Experiment on" +) +parser.add_argument("--experiment", type=str, default="BERT-ONNX", help="Name of the AzureML Experiment") +parser.add_argument( + "--tags", + type=str, + default=None, + help="Tags to be added to the submitted run (--tag1=value1 --tag2=value2 --tag3=value3)", +) + +parser.add_argument( + "--datastore", type=str, default="bert_premium", help="AzureML Datastore to be mounted into the Experiment" +) +parser.add_argument( + "--train_dir", type=str, default="book/train", help="Path in the AzureML Datastore containing the train files" +) +parser.add_argument( + "--test_dir", type=str, default="book/test", help="Path in the AzureML Datastore containing the test files" +) +parser.add_argument( + "--train_dir2", type=str, default=None, help="Path in the AzureML Datastore containing the train files for phase 2" +) +parser.add_argument( + "--test_dir2", type=str, default=None, help="Path in the AzureML Datastore containing the test files for phase 2" +) + +parser.add_argument( + "--container", + type=str, + default="onnxtraining.azurecr.io/azureml/bert:latest-openmpi4.0.0-cuda10.1-cudnn7-ubuntu16.04", + help="Docker container to use to run the Experiment", +) +parser.add_argument( + "--container_registry_resource_group", + type=str, + default="onnx_training", + help="Azure resource group containing the Azure Container Registry (if not public)", +) + +parser.add_argument( + "--node_count", + type=int, + default=1, + help="Number of nodes to use for the Experiment. If greater than 1, an MPI distributed job will be run.", +) +parser.add_argument( + "--gpu_count", + type=int, + default=1, + help="Number of GPUs to use per node. If greater than 1, an MPI distributed job will be run.", +) + +parser.add_argument( + "--model_name", + type=str, + default="bert_L-24_H-1024_A-16_V_30528_optimized_layer_norm", + help="Model to be trained (must exist in the AzureML Datastore)", +) +parser.add_argument( + "--script_params", + type=str, + default="", + help="Training script parameters (--param1=value1 --param2=value2 --param3=value3)", +) args = parser.parse_args() # Get the AzureML Workspace to run the Experiment in @@ -50,67 +103,70 @@ # Construct common script parameters script_params = { - '--model_name': ds.path(args.model_name).as_download(), - '--train_data_dir': ds.path(args.train_dir).as_mount(), - '--test_data_dir': ds.path(args.test_dir).as_mount(), + "--model_name": ds.path(args.model_name).as_download(), + "--train_data_dir": ds.path(args.train_dir).as_mount(), + "--test_data_dir": ds.path(args.test_dir).as_mount(), } # Optional phase2 script parameters if args.train_dir2: - script_params['--train_data_dir_phase2'] = ds.path(args.train_dir2).as_mount() + script_params["--train_data_dir_phase2"] = ds.path(args.train_dir2).as_mount() if args.test_dir2: - script_params['--test_data_dir_phase2'] = ds.path(args.test_dir2).as_mount() + script_params["--test_data_dir_phase2"] = ds.path(args.test_dir2).as_mount() # Allow additional custom script parameters -for params in args.script_params.split(' '): - key, value = params.split('=') - script_params[key] = value +for params in args.script_params.split(" "): + key, value = params.split("=") + script_params[key] = value # Allow custom tags on the run tags = {} if args.tags: - for tag in args.tags.split(' '): - key, value = tag.split('=') - tags[key] = value + for tag in args.tags.split(" "): + key, value = tag.split("=") + tags[key] = value # Get container registry information (if private) container_image = args.container registry_details = None -acr = re.match('^((\w+).azurecr.io)/(.*)', args.container) +acr = re.match("^((\w+).azurecr.io)/(.*)", args.container) if acr: - # Extract the relevant parts from the container image - # e.g. onnxtraining.azurecr.io/azureml/bert:latest - registry_address = acr.group(1) # onnxtraining.azurecr.io - registry_name = acr.group(2) # onnxtraining - container_image = acr.group(3) # azureml/bert:latest - - registry_client = get_client_from_cli_profile(ContainerRegistryManagementClient, subscription_id=args.subscription) - registry_credentials = registry_client.registries.list_credentials(args.container_registry_resource_group, registry_name) - - registry_details = ContainerRegistry() - registry_details.address = registry_address - registry_details.username = registry_credentials.username - registry_details.password = registry_credentials.passwords[0].value + # Extract the relevant parts from the container image + # e.g. onnxtraining.azurecr.io/azureml/bert:latest + registry_address = acr.group(1) # onnxtraining.azurecr.io + registry_name = acr.group(2) # onnxtraining + container_image = acr.group(3) # azureml/bert:latest + + registry_client = get_client_from_cli_profile(ContainerRegistryManagementClient, subscription_id=args.subscription) + registry_credentials = registry_client.registries.list_credentials( + args.container_registry_resource_group, registry_name + ) + + registry_details = ContainerRegistry() + registry_details.address = registry_address + registry_details.username = registry_credentials.username + registry_details.password = registry_credentials.passwords[0].value # MPI configuration if executing a distributed run mpi = MpiConfiguration() mpi.process_count_per_node = args.gpu_count # AzureML Estimator that describes how to run the Experiment -estimator = Estimator(source_directory='./', - script_params=script_params, - compute_target=compute_target, - node_count=args.node_count, - distributed_training=mpi, - image_registry_details=registry_details, - use_docker=True, - custom_docker_image=container_image, - entry_script='train.py', - inputs=[ds.path('./').as_mount()] - ) +estimator = Estimator( + source_directory="./", + script_params=script_params, + compute_target=compute_target, + node_count=args.node_count, + distributed_training=mpi, + image_registry_details=registry_details, + use_docker=True, + custom_docker_image=container_image, + entry_script="train.py", + inputs=[ds.path("./").as_mount()], +) # Start the AzureML Experiment experiment = Experiment(workspace=ws, name=args.experiment) run = experiment.submit(estimator, tags) -print('Experiment running at: {}'.format(run.get_portal_url())) +print("Experiment running at: {}".format(run.get_portal_url())) diff --git a/orttraining/tools/scripts/gpt2_model_transform.py b/orttraining/tools/scripts/gpt2_model_transform.py index 9c6e3b26dc769..9e018a34069e5 100644 --- a/orttraining/tools/scripts/gpt2_model_transform.py +++ b/orttraining/tools/scripts/gpt2_model_transform.py @@ -12,15 +12,17 @@ exit(1) input_model_name = sys.argv[1] -output_model_name = input_model_name[:-5] + '_optimized.onnx' +output_model_name = input_model_name[:-5] + "_optimized.onnx" model = onnx.load(input_model_name) + def add_name(model): i = 0 for node in model.graph.node: - node.name = '%s_%d' %(node.op_type, i) - i += 1 + node.name = "%s_%d" % (node.op_type, i) + i += 1 + def find_input_node(model, arg): result = [] @@ -28,7 +30,8 @@ def find_input_node(model, arg): for output in node.output: if output == arg: result.append(node) - return result[0] if len(result)== 1 else None + return result[0] if len(result) == 1 else None + def find_output_node(model, arg): result = [] @@ -38,18 +41,21 @@ def find_output_node(model, arg): result.append(node) return result[0] if len(result) == 1 else None + def find_initializer(model, arg): for initializer in model.graph.initializer: if initializer.name == arg: return initializer return None + def find_input(model, arg): for graph_input in model.graph.input: if graph_input.name == arg: return graph_input return None + def find_all_fused_nodes(model, concat_node): result = [] candidate = [concat_node] @@ -57,7 +63,7 @@ def find_all_fused_nodes(model, concat_node): node = candidate[0] candidate.pop(0) result.append(node) - if node.op_type == 'Shape': + if node.op_type == "Shape": continue for input in node.input: input_node = find_input_node(model, input) @@ -65,6 +71,7 @@ def find_all_fused_nodes(model, concat_node): candidate.append(input_node) return result + def get_node_index(model, node): i = 0 while i < len(model.graph.node): @@ -73,13 +80,14 @@ def get_node_index(model, node): i += 1 return i if i < len(model.graph.node) else None -def add_const(model, name, output, t_value = None, f_value = None): + +def add_const(model, name, output, t_value=None, f_value=None): const_node = model.graph.node.add() - const_node.op_type = 'Constant' + const_node.op_type = "Constant" const_node.name = name const_node.output.extend([output]) attr = const_node.attribute.add() - attr.name = 'value' + attr.name = "value" if t_value is not None: attr.type = 4 attr.t.CopyFrom(t_value) @@ -88,61 +96,63 @@ def add_const(model, name, output, t_value = None, f_value = None): attr.f = f_value return const_node + def process_concat(model): new_nodes = {} delete_nodes = [] for node in model.graph.node: - if node.op_type != 'Concat': + if node.op_type != "Concat": continue skip = False input_nodes = [] for input in node.input: concat_input_node = find_input_node(model, input) - if concat_input_node.op_type != 'Unsqueeze': + if concat_input_node.op_type != "Unsqueeze": skip = True input_nodes.append(concat_input_node) if skip == True: continue - #figure out target shape + # figure out target shape shape = [] for input_node in input_nodes: const_input = find_input_node(model, input_node.input[0]) - if const_input.op_type != 'Constant': + if const_input.op_type != "Constant": shape.append(0) else: attr = const_input.attribute assert len(attr) == 1 - assert attr[0].name == 'value' + assert attr[0].name == "value" assert attr[0].type == 4 data = numpy_helper.to_array(attr[0].t) shape.append(np.asscalar(data)) - print('concat node: %s, new_shape is: %s' % (node.name, shape)) + print("concat node: %s, new_shape is: %s" % (node.name, shape)) - #find out the nodes need to be deleted. + # find out the nodes need to be deleted. fuse_nodes = find_all_fused_nodes(model, node) reshape_node = find_output_node(model, node.output[0]) - assert reshape_node.op_type == 'Reshape' + assert reshape_node.op_type == "Reshape" new_nodes[get_node_index(model, reshape_node)] = shape for n in fuse_nodes: delete_nodes.append(get_node_index(model, n)) - #insert new shape to reshape + # insert new shape to reshape index = 0 for reshape_node_index in new_nodes: shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64)) - const_node = add_const(model, 'concat_shape_node_%d' % index, 'concat_shape_%d' % index, shape_tensor) - index+=1 + const_node = add_const(model, "concat_shape_node_%d" % index, "concat_shape_%d" % index, shape_tensor) + index += 1 reshape_node = model.graph.node[reshape_node_index] reshape_node.input[1] = const_node.output[0] - #delete nodes + # delete nodes delete_nodes.sort(reverse=True) for delete_node in delete_nodes: del model.graph.node[delete_node] + def replace_input_arg(model, arg, new_arg): for node in model.graph.node: i = 0 @@ -151,6 +161,7 @@ def replace_input_arg(model, arg, new_arg): node.input[i] = new_arg i += 1 + def find_weight_index(model, name): index = 0 for w in model.graph.initializer: @@ -159,6 +170,7 @@ def find_weight_index(model, name): index += 1 return None + def find_input_index(model, name): index = 0 for w in model.graph.input: @@ -167,10 +179,11 @@ def find_input_index(model, name): index += 1 return None + def fix_transpose(model): transpose = [] for node in model.graph.node: - if node.op_type == 'Transpose': + if node.op_type == "Transpose": weight = find_initializer(model, node.input[0]) if weight is not None: result = [] @@ -181,7 +194,7 @@ def fix_transpose(model): if len(result) > 1: continue perm = node.attribute[0] - assert perm.name == 'perm' + assert perm.name == "perm" perm = perm.ints assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0 transpose.append((get_node_index(model, node), weight)) @@ -199,7 +212,7 @@ def fix_transpose(model): del model.graph.node[t[0]] old_ws = [] - old_graph_inputs=[] + old_graph_inputs = [] for t in transpose: if find_output_node(model, t[1].name) is None: old_ws.append(find_weight_index(model, t[1].name)) @@ -217,20 +230,23 @@ def fix_transpose(model): print(model.graph.initializer[w_i].name) del model.graph.initializer[w_i] + def process_dropout(model): dropouts = [] index = 0 for node in model.graph.node: - if node.op_type == 'Dropout': + if node.op_type == "Dropout": new_dropout = model.graph.node.add() - new_dropout.op_type = 'TrainableDropout' - new_dropout.name = 'TrainableDropout_%d' % index - #make ratio node + new_dropout.op_type = "TrainableDropout" + new_dropout.name = "TrainableDropout_%d" % index + # make ratio node ratio = np.asarray([node.attribute[0].f], dtype=np.float32) print(ratio.shape) ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const(model, 'dropout_node_ratio_%d' % index, 'dropout_node_ratio_%d' % index, t_value=ratio_value) - print (ratio_node) + ratio_node = add_const( + model, "dropout_node_ratio_%d" % index, "dropout_node_ratio_%d" % index, t_value=ratio_value + ) + print(ratio_node) new_dropout.input.extend([node.input[0], ratio_node.output[0]]) new_dropout.output.extend(node.output) dropouts.append(get_node_index(model, node)) @@ -239,10 +255,11 @@ def process_dropout(model): for d in dropouts: del model.graph.node[d] + def remove_input_ids_check_subgraph(model): aten_node = None for node in model.graph.node: - if node.op_type == 'ATen': + if node.op_type == "ATen": aten_node = node for i in node.input: input_node = find_input_node(model, i) @@ -260,9 +277,7 @@ def remove_input_ids_check_subgraph(model): removed_nodes.extend(get_nodes_to_remove(or_node.input[0])) removed_nodes.extend(get_nodes_to_remove(or_node.input[1])) - removed_nodes.extend([ - cast_node, cast_node2, or_node, aten_node - ]) + removed_nodes.extend([cast_node, cast_node2, or_node, aten_node]) remove_node_index = [] for n in removed_nodes: @@ -274,6 +289,7 @@ def remove_input_ids_check_subgraph(model): print("Removing useless node ", model.graph.node[d].name) del model.graph.node[d] + def get_nodes_to_remove(input_id): cast_node3 = find_input_node(model, input_id) not_node3 = find_input_node(model, cast_node3.input[0]) @@ -289,14 +305,15 @@ def get_nodes_to_remove(input_id): break return [cast_node3, not_node3, less_node, const_node] + def fix_split(model): # having split attribute, Split op shape inferencing bring 0, so we remove them. for node in model.graph.node: - if node.op_type == 'Split': + if node.op_type == "Split": index = 0 need_remove = False for attr in node.attribute: - if attr.name == 'split': + if attr.name == "split": need_remove = True break index += 1 @@ -304,18 +321,20 @@ def fix_split(model): print("Removing attribute split for ", node.name) del node.attribute[index] + def align_attention_mask_dim(model): for model_input in model.graph.input: if model_input.name == "attention_mask": model_input.type.tensor_type.shape.dim[0].dim_param = "batch" -#add name to nodes + +# add name to nodes add_name(model) -#replace garther&concat to reshape +# replace garther&concat to reshape process_concat(model) -#constant fold transpose +# constant fold transpose fix_transpose(model) -#replace dropout with trainable dropout +# replace dropout with trainable dropout process_dropout(model) remove_input_ids_check_subgraph(model) @@ -324,9 +343,9 @@ def align_attention_mask_dim(model): align_attention_mask_dim(model) -#set opset version to 10 +# set opset version to 10 model.opset_import[0].version = 10 f = open(output_model_name, "wb") f.write(model.SerializeToString()) -f.close() \ No newline at end of file +f.close() diff --git a/orttraining/tools/scripts/layer_norm_transform.py b/orttraining/tools/scripts/layer_norm_transform.py index 6355118709ff8..0ad4ea2559207 100644 --- a/orttraining/tools/scripts/layer_norm_transform.py +++ b/orttraining/tools/scripts/layer_norm_transform.py @@ -4,16 +4,18 @@ import onnx import numpy as np + def find_node(graph_proto, op_type): nodes = [] map_input_node = {} for node in graph_proto.node: if node.op_type == op_type: - node_input = node.input[1] if op_type == 'Div' or op_type == 'Mul' else node.input[0] + node_input = node.input[1] if op_type == "Div" or op_type == "Mul" else node.input[0] nodes.append(node) map_input_node[node_input] = node return nodes, map_input_node + def gen_attribute(key, value): attr = AttributeProto() attr.name = key @@ -21,48 +23,49 @@ def gen_attribute(key, value): attr.type = AttributeProto.INTS return attr + def main(): if len(sys.argv) < 2: print("Please give model path...") return model_file_path = sys.argv[1] - #model_file_path = os.path.dirname(sys.argv[1:]) + # model_file_path = os.path.dirname(sys.argv[1:]) print("model_file_path: " + model_file_path) model_file_name = os.path.basename(model_file_path) - print("model_file_name: "+ model_file_name) + print("model_file_name: " + model_file_name) - new_model_file_path = model_file_path[:-5] + '_layer_norm.onnx' + new_model_file_path = model_file_path[:-5] + "_layer_norm.onnx" print(new_model_file_path) model_proto = onnx.load(model_file_path) - #print(model_proto) + # print(model_proto) graph_proto = model_proto.graph - #print(graph_proto) - #print(graph_proto.input) - - nodes_Div, map_input_Div = find_node(graph_proto, 'Div') - #print(map_input_Div) - nodes_Sqrt, map_input_Sqrt = find_node(graph_proto, 'Sqrt') - #print(map_input_Sqrt) - nodes_Add, map_input_Add = find_node(graph_proto, 'Add') - #print(map_input_Add) - nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, 'ReduceMean') - #print(map_input_ReduceMean) - nodes_Pow, map_input_Pow = find_node(graph_proto, 'Pow') - #print(map_input_Pow) - nodes_Mul, map_input_Mul = find_node(graph_proto, 'Mul') + # print(graph_proto) + # print(graph_proto.input) + + nodes_Div, map_input_Div = find_node(graph_proto, "Div") + # print(map_input_Div) + nodes_Sqrt, map_input_Sqrt = find_node(graph_proto, "Sqrt") + # print(map_input_Sqrt) + nodes_Add, map_input_Add = find_node(graph_proto, "Add") + # print(map_input_Add) + nodes_ReduceMean, map_input_ReduceMean = find_node(graph_proto, "ReduceMean") + # print(map_input_ReduceMean) + nodes_Pow, map_input_Pow = find_node(graph_proto, "Pow") + # print(map_input_Pow) + nodes_Mul, map_input_Mul = find_node(graph_proto, "Mul") # find right side Sub nodes_Sub = [] map_input_Sub = {} for node in graph_proto.node: - if node.op_type == 'Sub': + if node.op_type == "Sub": if node.output[0] in map_input_Pow: nodes_Sub.append(node) map_input_Sub[node.input[1]] = node - #print(map_input_Sub) + # print(map_input_Sub) # find first ReduceMean first_ReduceMean = [] @@ -71,16 +74,16 @@ def main(): if node.output[0] in map_input_Sub: first_ReduceMean.append(node) first_ReduceMean_outputs.append(node.output[0]) - #print(first_ReduceMean) + # print(first_ReduceMean) # find constant node nodes_Constant = [] map_output_Constant = {} for node in graph_proto.node: - if node.op_type == 'Constant': + if node.op_type == "Constant": nodes_Constant.append(node) map_output_Constant[node.output[0]] = node - #print(map_input_Sub) + # print(map_input_Sub) id = 0 removed_nodes = [] @@ -110,25 +113,27 @@ def main(): removed_nodes.append(node_Mul) removed_nodes.append(node_Add1) removed_nodes.append(map_output_Constant[node_pow.input[1]]) - #print(map_output_Constant[node_Add.input[1]]) + # print(map_output_Constant[node_Add.input[1]]) removed_nodes.append(map_output_Constant[node_Add.input[1]]) layer_norm_output.append(node_Add1.output[0]) id = id + 1 - layer_norm_output.append('saved_mean_' + str(id)) + layer_norm_output.append("saved_mean_" + str(id)) id = id + 1 - layer_norm_output.append('saved_inv_std_var_' + str(id)) - layer_norm = helper.make_node("LayerNormalization", - layer_norm_input, - layer_norm_output, - "LayerNormalization_" + str(id), - None, - axis = node_reduce.attribute[0].ints[0], - epsilon = 9.999999960041972e-13) + layer_norm_output.append("saved_inv_std_var_" + str(id)) + layer_norm = helper.make_node( + "LayerNormalization", + layer_norm_input, + layer_norm_output, + "LayerNormalization_" + str(id), + None, + axis=node_reduce.attribute[0].ints[0], + epsilon=9.999999960041972e-13, + ) layer_norm_nodes.append(layer_norm) # remove left side Subs for node in graph_proto.node: - if node.op_type == 'Sub': + if node.op_type == "Sub": if node.input[1] in first_ReduceMean_outputs: removed_nodes.append(node) @@ -143,12 +148,13 @@ def main(): graph_proto.ClearField("node") graph_proto.node.extend(all_nodes) - with open(new_model_file_path, 'wb') as f: + with open(new_model_file_path, "wb") as f: f.write(model_proto.SerializeToString()) # Use ORT to verify the converted model. Notice that you must use python package from the # training branch because training requires some extra ops. import onnxruntime as ort + # We convert model to accept variable-length batch size, so it can be any positive integer. batch = 3 # This should match --max_seq_length when calling nv_run_pretraining.py. @@ -157,22 +163,23 @@ def main(): vocab_size = 30528 # Create a fake data point. - vocab_size = 30528 # It shoudl match the value from BERT config file. + vocab_size = 30528 # It shoudl match the value from BERT config file. input_ids = np.random.randint(low=0, high=vocab_size, size=(batch, sq_length), dtype=np.int64) segment_ids = np.random.randint(low=0, high=2, size=(batch, sq_length), dtype=np.int64) input_mask = np.ones((batch, sq_length), dtype=np.int64) # Do forward using the original model. sess = ort.InferenceSession(model_file_path, providers=ort.get_available_providers()) - result = sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) + result = sess.run(None, {"input1": input_ids, "input2": segment_ids, "input3": input_mask}) # Do forward using the new model. new_sess = ort.InferenceSession(new_model_file_path, providers=ort.get_available_providers()) - new_result = new_sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) + new_result = new_sess.run(None, {"input1": input_ids, "input2": segment_ids, "input3": input_mask}) # Compare the outcomes from the two models. - print(np.linalg.norm(result[0]-new_result[0])) - print(np.linalg.norm(result[1]-new_result[1])) + print(np.linalg.norm(result[0] - new_result[0])) + print(np.linalg.norm(result[1] - new_result[1])) + if __name__ == "__main__": main() diff --git a/orttraining/tools/scripts/model_transform.py b/orttraining/tools/scripts/model_transform.py index de23df13a1963..8c0be5b08c04a 100644 --- a/orttraining/tools/scripts/model_transform.py +++ b/orttraining/tools/scripts/model_transform.py @@ -10,15 +10,17 @@ exit(1) input_model_name = sys.argv[1] -output_model_name = input_model_name[:-5] + '_optimized.onnx' +output_model_name = input_model_name[:-5] + "_optimized.onnx" model = onnx.load(input_model_name) + def add_name(model): i = 0 for node in model.graph.node: - node.name = '%s_%d' %(node.op_type, i) - i += 1 + node.name = "%s_%d" % (node.op_type, i) + i += 1 + def find_input_node(model, arg): result = [] @@ -26,7 +28,8 @@ def find_input_node(model, arg): for output in node.output: if output == arg: result.append(node) - return result[0] if len(result)== 1 else None + return result[0] if len(result) == 1 else None + def find_output_node(model, arg): result = [] @@ -36,12 +39,14 @@ def find_output_node(model, arg): result.append(node) return result[0] if len(result) == 1 else None + def find_input(model, arg): for initializer in model.graph.initializer: if initializer.name == arg: return initializer return None + def find_all_fused_nodes(model, concat_node): result = [] candidate = [concat_node] @@ -49,7 +54,7 @@ def find_all_fused_nodes(model, concat_node): node = candidate[0] candidate.pop(0) result.append(node) - if node.op_type == 'Shape': + if node.op_type == "Shape": continue for input in node.input: input_node = find_input_node(model, input) @@ -57,21 +62,23 @@ def find_all_fused_nodes(model, concat_node): candidate.append(input_node) return result + def get_node_index(model, node): i = 0 while i < len(model.graph.node): if model.graph.node[i] == node: - break; + break i += 1 - return i if i < len(model.graph.node) else None; + return i if i < len(model.graph.node) else None -def add_const(model, name, output, t_value = None, f_value = None): + +def add_const(model, name, output, t_value=None, f_value=None): const_node = model.graph.node.add() - const_node.op_type = 'Constant' + const_node.op_type = "Constant" const_node.name = name const_node.output.extend([output]) attr = const_node.attribute.add() - attr.name = 'value' + attr.name = "value" if t_value is not None: attr.type = 4 attr.t.CopyFrom(t_value) @@ -80,103 +87,110 @@ def add_const(model, name, output, t_value = None, f_value = None): attr.f = f_value return const_node + def process_concat(model): new_nodes = {} delete_nodes = [] for node in model.graph.node: - if node.op_type == 'Concat': + if node.op_type == "Concat": input_nodes = [] for input in node.input: input_nodes.append(find_input_node(model, input)) - #figure out target shape + # figure out target shape shape = [] for input_node in input_nodes: - assert input_node.op_type == 'Unsqueeze' + assert input_node.op_type == "Unsqueeze" const_input = find_input_node(model, input_node.input[0]) - if const_input.op_type != 'Constant': + if const_input.op_type != "Constant": shape.append(0) else: attr = const_input.attribute assert len(attr) == 1 - assert attr[0].name == 'value' + assert attr[0].name == "value" assert attr[0].type == 4 data = numpy_helper.to_array(attr[0].t) shape.append(np.asscalar(data)) - print('concat node: %s, new_shape is: %s' % (node.name, shape)) - #find out the nodes need to be deleted. + print("concat node: %s, new_shape is: %s" % (node.name, shape)) + # find out the nodes need to be deleted. fuse_nodes = find_all_fused_nodes(model, node) reshape_node = find_output_node(model, node.output[0]) - assert reshape_node.op_type == 'Reshape' + assert reshape_node.op_type == "Reshape" new_nodes[get_node_index(model, reshape_node)] = shape for n in fuse_nodes: delete_nodes.append(get_node_index(model, n)) - #insert new shape to reshape + # insert new shape to reshape index = 0 for reshape_node_index in new_nodes: shape_tensor = numpy_helper.from_array(np.asarray(new_nodes[reshape_node_index], dtype=np.int64)) - const_node = add_const(model, 'concat_shape_node_%d' % index, 'concat_shape_%d' % index, shape_tensor) - index+=1 + const_node = add_const(model, "concat_shape_node_%d" % index, "concat_shape_%d" % index, shape_tensor) + index += 1 reshape_node = model.graph.node[reshape_node_index] reshape_node.input[1] = const_node.output[0] - #delete nodes + # delete nodes delete_nodes.sort(reverse=True) for delete_node in delete_nodes: del model.graph.node[delete_node] + def add_cast(model, name, input, output, type): cast_node = model.graph.node.add() cast_node.name = name - cast_node.op_type = 'Cast' + cast_node.op_type = "Cast" attr = cast_node.attribute.add() - attr.name = 'to' + attr.name = "to" attr.type = 2 attr.i = type cast_node.input.extend([input]) cast_node.output.extend([output]) return cast_node + def fix_expand(model): - #find expand node + # find expand node expand_node = None for node in model.graph.node: - if node.op_type == 'Expand': + if node.op_type == "Expand": expand_node = node break assert expand_node is not None const_expand_input = find_input_node(model, expand_node.input[0]) - assert const_expand_input.op_type == 'Constant' + assert const_expand_input.op_type == "Constant" shape_node = find_input_node(model, expand_node.input[1]) - assert shape_node.op_type == 'Shape' - #insert cast --> min --> cast - cast_1 = add_cast(model, 'new_cast_01', shape_node.output[0], 'to_min_01', 1) + assert shape_node.op_type == "Shape" + # insert cast --> min --> cast + cast_1 = add_cast(model, "new_cast_01", shape_node.output[0], "to_min_01", 1) min_target = numpy_helper.from_array(np.asarray([1, 9999], dtype=np.float32)) - min_target_node = add_const(model, 'op_min_node_10', 'op_min_ends_expand_10', min_target) + min_target_node = add_const(model, "op_min_node_10", "op_min_ends_expand_10", min_target) min_node = model.graph.node.add() - min_node.name = 'new_min_01' - min_node.op_type = 'Min' + min_node.name = "new_min_01" + min_node.op_type = "Min" min_node.input.extend([cast_1.output[0], min_target_node.output[0]]) - min_node.output.extend(['from_min_01']) - cast_2 = add_cast(model, 'new_cast_02', min_node.output[0], 'to_slice_01', 7) - #insert slice + min_node.output.extend(["from_min_01"]) + cast_2 = add_cast(model, "new_cast_02", min_node.output[0], "to_slice_01", 7) + # insert slice position = numpy_helper.from_array(np.expand_dims(np.arange(512, dtype=np.int64), axis=0)) - position_node = add_const(model, 'position_01_node', 'position_01', position) - start_extend = numpy_helper.from_array(np.asarray([0, 0], dtype=np.int64), 'start_expand_10') - start_extend_node = add_const(model, 'start_expand_10_node', 'start_expand_10', start_extend) - axes = numpy_helper.from_array(np.asarray([0, 1], dtype=np.int64), 'axes_expand_10') - axes_node = add_const(model, 'axes_expand_10_node', 'axes_expand_10', axes) + position_node = add_const(model, "position_01_node", "position_01", position) + start_extend = numpy_helper.from_array(np.asarray([0, 0], dtype=np.int64), "start_expand_10") + start_extend_node = add_const(model, "start_expand_10_node", "start_expand_10", start_extend) + axes = numpy_helper.from_array(np.asarray([0, 1], dtype=np.int64), "axes_expand_10") + axes_node = add_const(model, "axes_expand_10_node", "axes_expand_10", axes) slice_node = model.graph.node.add() - slice_node.name = 'new_slice_01' - slice_node.op_type = 'Slice' - slice_node.input.extend([position_node.output[0], start_extend_node.output[0], cast_2.output[0], axes_node.output[0]]) - slice_node.output.extend(['from_slice_01']) - #connect to expand + slice_node.name = "new_slice_01" + slice_node.op_type = "Slice" + slice_node.input.extend( + [position_node.output[0], start_extend_node.output[0], cast_2.output[0], axes_node.output[0]] + ) + slice_node.output.extend(["from_slice_01"]) + # connect to expand expand_node.input[0] = slice_node.output[0] - #delete the const input + # delete the const input del model.graph.node[get_node_index(model, const_expand_input)] + def fix_dim(model): del model.graph.input[3:] + def replace_input_arg(model, arg, new_arg): for node in model.graph.node: i = 0 @@ -185,6 +199,7 @@ def replace_input_arg(model, arg, new_arg): node.input[i] = new_arg i += 1 + def find_weight_index(model, name): index = 0 for w in model.graph.initializer: @@ -193,10 +208,11 @@ def find_weight_index(model, name): index += 1 return None + def fix_transpose(model): transpose = [] for node in model.graph.node: - if node.op_type == 'Transpose': + if node.op_type == "Transpose": weight = find_input(model, node.input[0]) if weight is not None: result = [] @@ -207,7 +223,7 @@ def fix_transpose(model): if len(result) > 1: continue perm = node.attribute[0] - assert perm.name == 'perm' + assert perm.name == "perm" perm = perm.ints assert len(perm) == 2 and perm[0] == 1 and perm[1] == 0 transpose.append((get_node_index(model, node), weight)) @@ -232,20 +248,23 @@ def fix_transpose(model): for w_i in old_ws: del model.graph.initializer[w_i] + def process_dropout(model): dropouts = [] index = 0 for node in model.graph.node: - if node.op_type == 'Dropout': + if node.op_type == "Dropout": new_dropout = model.graph.node.add() - new_dropout.op_type = 'TrainableDropout' - new_dropout.name = 'TrainableDropout_%d' % index - #make ratio node + new_dropout.op_type = "TrainableDropout" + new_dropout.name = "TrainableDropout_%d" % index + # make ratio node ratio = np.asarray([node.attribute[0].f], dtype=np.float32) print(ratio.shape) ratio_value = numpy_helper.from_array(ratio) - ratio_node = add_const(model, 'dropout_node_ratio_%d' % index, 'dropout_node_ratio_%d' % index, t_value=ratio_value) - print (ratio_node) + ratio_node = add_const( + model, "dropout_node_ratio_%d" % index, "dropout_node_ratio_%d" % index, t_value=ratio_value + ) + print(ratio_node) new_dropout.input.extend([node.input[0], ratio_node.output[0]]) new_dropout.output.extend(node.output) dropouts.append(get_node_index(model, node)) @@ -254,28 +273,30 @@ def process_dropout(model): for d in dropouts: del model.graph.node[d] + # Also need to set following line differently for differnt verison of bert # expand_out.name = '412' def add_expand_shape(model): expand_out = model.graph.value_info.add() - expand_out.name = '74' #'410' # 74 for base model + expand_out.name = "74" #'410' # 74 for base model expand_out.type.CopyFrom(model.graph.input[0].type) -#add name to nodes + +# add name to nodes add_name(model) -#replace garther&concat to reshape +# replace garther&concat to reshape process_concat(model) -#fix the expand with dynamic shape +# fix the expand with dynamic shape fix_expand(model) -#use dynamic batch/sequence +# use dynamic batch/sequence fix_dim(model) -#constant fold transpose +# constant fold transpose fix_transpose(model) -#replace dropout with trainable dropout +# replace dropout with trainable dropout process_dropout(model) -#add output shape of expand +# add output shape of expand add_expand_shape(model) -#set opset version to 10 +# set opset version to 10 model.opset_import[0].version = 10 f = open(output_model_name, "wb") @@ -285,6 +306,7 @@ def add_expand_shape(model): # Use ORT to verify the converted model. Notice that you must use python package from the # training branch because training requires some extra ops. import onnxruntime as ort + # We convert model to accept variable-length batch size, so it can be any positive integer. batch = 3 # This should match --max_seq_length when calling nv_run_pretraining.py. @@ -299,12 +321,12 @@ def add_expand_shape(model): # Do forward using the original model. sess = ort.InferenceSession(input_model_name, providers=ort.get_available_providers()) -result = sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) +result = sess.run(None, {"input1": input_ids, "input2": segment_ids, "input3": input_mask}) # Do forward using the new model. new_sess = ort.InferenceSession(output_model_name, providers=ort.get_available_providers()) -new_result = new_sess.run(None, {'input1': input_ids, 'input2': segment_ids, 'input3': input_mask}) +new_result = new_sess.run(None, {"input1": input_ids, "input2": segment_ids, "input3": input_mask}) # Compare the outcomes from the two models. -print(np.linalg.norm(result[0]-new_result[0])) -print(np.linalg.norm(result[1]-new_result[1])) +print(np.linalg.norm(result[0] - new_result[0])) +print(np.linalg.norm(result[1] - new_result[1])) diff --git a/orttraining/tools/scripts/nv_run_pretraining.py b/orttraining/tools/scripts/nv_run_pretraining.py index c7c03be161c08..3e51a8886ecb6 100644 --- a/orttraining/tools/scripts/nv_run_pretraining.py +++ b/orttraining/tools/scripts/nv_run_pretraining.py @@ -52,9 +52,9 @@ from concurrent.futures import ProcessPoolExecutor -logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', - datefmt='%m/%d/%Y %H:%M:%S', - level=logging.INFO) +logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO +) logger = logging.getLogger(__name__) @@ -62,32 +62,41 @@ def create_pretraining_dataset(input_file, max_pred_length, shared_list, args): train_data = pretraining_dataset(input_file=input_file, max_pred_length=max_pred_length) train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader(train_data, sampler=train_sampler, - batch_size=args.train_batch_size * args.n_gpu, num_workers=4, - pin_memory=True) + train_dataloader = DataLoader( + train_data, sampler=train_sampler, batch_size=args.train_batch_size * args.n_gpu, num_workers=4, pin_memory=True + ) # shared_list["0"] = (train_dataloader, input_file) return train_dataloader, input_file -class pretraining_dataset(Dataset): +class pretraining_dataset(Dataset): def __init__(self, input_file, max_pred_length): self.input_file = input_file self.max_pred_length = max_pred_length f = h5py.File(input_file, "r") - keys = ['input_ids', 'input_mask', 'segment_ids', 'masked_lm_positions', 'masked_lm_ids', - 'next_sentence_labels'] + keys = [ + "input_ids", + "input_mask", + "segment_ids", + "masked_lm_positions", + "masked_lm_ids", + "next_sentence_labels", + ] self.inputs = [np.asarray(f[key][:]) for key in keys] f.close() def __len__(self): - 'Denotes the total number of samples' + "Denotes the total number of samples" return len(self.inputs[0]) def __getitem__(self, index): [input_ids, input_mask, segment_ids, masked_lm_positions, masked_lm_ids, next_sentence_labels] = [ - torch.from_numpy(input[index].astype(np.int64)) if indice < 5 else torch.from_numpy( - np.asarray(input[index].astype(np.int64))) for indice, input in enumerate(self.inputs)] + torch.from_numpy(input[index].astype(np.int64)) + if indice < 5 + else torch.from_numpy(np.asarray(input[index].astype(np.int64))) + for indice, input in enumerate(self.inputs) + ] masked_lm_labels = torch.ones(input_ids.shape, dtype=torch.long) * -1 index = self.max_pred_length @@ -97,136 +106,125 @@ def __getitem__(self, index): index = padded_mask_indices[0].item() masked_lm_labels[masked_lm_positions[:index]] = masked_lm_ids[:index] - return [input_ids, segment_ids, input_mask, - masked_lm_labels, next_sentence_labels] + return [input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels] + def parse_arguments(): parser = argparse.ArgumentParser() ## Required parameters - parser.add_argument("--input_dir", - default=None, - type=str, - required=True, - help="The input data dir. Should contain .hdf5 files for the task.") - - parser.add_argument("--config_file", - default=None, - type=str, - required=True, - help="The BERT model config") - - parser.add_argument("--bert_model", default="bert-large-uncased", type=str, - help="Bert pre-trained model selected in the list: bert-base-uncased, " - "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.") - - parser.add_argument("--output_dir", - default=None, - type=str, - required=True, - help="The output directory where the model checkpoints will be written.") + parser.add_argument( + "--input_dir", + default=None, + type=str, + required=True, + help="The input data dir. Should contain .hdf5 files for the task.", + ) + + parser.add_argument("--config_file", default=None, type=str, required=True, help="The BERT model config") + + parser.add_argument( + "--bert_model", + default="bert-large-uncased", + type=str, + help="Bert pre-trained model selected in the list: bert-base-uncased, " + "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.", + ) + + parser.add_argument( + "--output_dir", + default=None, + type=str, + required=True, + help="The output directory where the model checkpoints will be written.", + ) ## Other parameters - parser.add_argument("--max_seq_length", - default=512, - type=int, - help="The maximum total input sequence length after WordPiece tokenization. \n" - "Sequences longer than this will be truncated, and sequences shorter \n" - "than this will be padded.") - parser.add_argument("--max_predictions_per_seq", - default=80, - type=int, - help="The maximum total of masked tokens in input sequence") - parser.add_argument("--train_batch_size", - default=32, - type=int, - help="Total batch size for training.") - parser.add_argument("--learning_rate", - default=5e-5, - type=float, - help="The initial learning rate for Adam.") - parser.add_argument("--num_train_epochs", - default=3.0, - type=float, - help="Total number of training epochs to perform.") - parser.add_argument("--max_steps", - default=1000, - type=float, - help="Total number of training steps to perform.") - parser.add_argument("--warmup_proportion", - default=0.01, - type=float, - help="Proportion of training to perform linear learning rate warmup for. " - "E.g., 0.1 = 10%% of training.") - parser.add_argument("--local_rank", - type=int, - default=-1, - help="local_rank for distributed training on gpus") - parser.add_argument('--seed', - type=int, - default=42, - help="random seed for initialization") - parser.add_argument('--gradient_accumulation_steps', - type=int, - default=1, - help="Number of updates steps to accumualte before performing a backward/update pass.") - parser.add_argument('--fp16', - default=False, - action='store_true', - help="Whether to use 16-bit float precision instead of 32-bit") - parser.add_argument('--loss_scale', - type=float, default=0.0, - help='Loss scaling, positive power of 2 values can improve fp16 convergence.') - parser.add_argument('--log_freq', - type=float, default=50.0, - help='frequency of logging loss.') - parser.add_argument('--checkpoint_activations', - default=False, - action='store_true', - help="Whether to use gradient checkpointing") - parser.add_argument("--resume_from_checkpoint", - default=False, - action='store_true', - help="Whether to resume training from checkpoint.") - parser.add_argument('--resume_step', - type=int, - default=-1, - help="Step to resume training from.") - parser.add_argument('--num_steps_per_checkpoint', - type=int, - default=100, - help="Number of update steps until a model checkpoint is saved to disk.") - parser.add_argument('--phase2', - default=False, - action='store_true', - help="Whether to train with seq len 512") - parser.add_argument('--allreduce_post_accumulation', - default=False, - action='store_true', - help="Whether to do allreduces during gradient accumulation steps.") - parser.add_argument('--allreduce_post_accumulation_fp16', - default=False, - action='store_true', - help="Whether to do fp16 allreduce post accumulation.") - parser.add_argument('--accumulate_into_fp16', - default=False, - action='store_true', - help="Whether to use fp16 gradient accumulators.") - parser.add_argument('--phase1_end_step', - type=int, - default=7038, - help="Number of training steps in Phase1 - seq len 128") - parser.add_argument("--do_train", - default=False, - action='store_true', - help="Whether to run training.") + parser.add_argument( + "--max_seq_length", + default=512, + type=int, + help="The maximum total input sequence length after WordPiece tokenization. \n" + "Sequences longer than this will be truncated, and sequences shorter \n" + "than this will be padded.", + ) + parser.add_argument( + "--max_predictions_per_seq", default=80, type=int, help="The maximum total of masked tokens in input sequence" + ) + parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.") + parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.") + parser.add_argument( + "--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform." + ) + parser.add_argument("--max_steps", default=1000, type=float, help="Total number of training steps to perform.") + parser.add_argument( + "--warmup_proportion", + default=0.01, + type=float, + help="Proportion of training to perform linear learning rate warmup for. " "E.g., 0.1 = 10%% of training.", + ) + parser.add_argument("--local_rank", type=int, default=-1, help="local_rank for distributed training on gpus") + parser.add_argument("--seed", type=int, default=42, help="random seed for initialization") + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumualte before performing a backward/update pass.", + ) + parser.add_argument( + "--fp16", default=False, action="store_true", help="Whether to use 16-bit float precision instead of 32-bit" + ) + parser.add_argument( + "--loss_scale", + type=float, + default=0.0, + help="Loss scaling, positive power of 2 values can improve fp16 convergence.", + ) + parser.add_argument("--log_freq", type=float, default=50.0, help="frequency of logging loss.") + parser.add_argument( + "--checkpoint_activations", default=False, action="store_true", help="Whether to use gradient checkpointing" + ) + parser.add_argument( + "--resume_from_checkpoint", + default=False, + action="store_true", + help="Whether to resume training from checkpoint.", + ) + parser.add_argument("--resume_step", type=int, default=-1, help="Step to resume training from.") + parser.add_argument( + "--num_steps_per_checkpoint", + type=int, + default=100, + help="Number of update steps until a model checkpoint is saved to disk.", + ) + parser.add_argument("--phase2", default=False, action="store_true", help="Whether to train with seq len 512") + parser.add_argument( + "--allreduce_post_accumulation", + default=False, + action="store_true", + help="Whether to do allreduces during gradient accumulation steps.", + ) + parser.add_argument( + "--allreduce_post_accumulation_fp16", + default=False, + action="store_true", + help="Whether to do fp16 allreduce post accumulation.", + ) + parser.add_argument( + "--accumulate_into_fp16", default=False, action="store_true", help="Whether to use fp16 gradient accumulators." + ) + parser.add_argument( + "--phase1_end_step", type=int, default=7038, help="Number of training steps in Phase1 - seq len 128" + ) + parser.add_argument("--do_train", default=False, action="store_true", help="Whether to run training.") args = parser.parse_args() return args + def setup_training(args): - assert (torch.cuda.is_available()) + assert torch.cuda.is_available() if args.local_rank == -1: device = torch.device("cuda") @@ -236,24 +234,31 @@ def setup_training(args): device = torch.device("cuda", args.local_rank) args.n_gpu = 1 # Initializes the distributed backend which will take care of sychronizing nodes/GPUs - torch.distributed.init_process_group(backend='nccl', init_method='env://') + torch.distributed.init_process_group(backend="nccl", init_method="env://") logger.info("device %s n_gpu %d distributed training %r", device, args.n_gpu, bool(args.local_rank != -1)) if args.gradient_accumulation_steps < 1: - raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format( - args.gradient_accumulation_steps)) + raise ValueError( + "Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(args.gradient_accumulation_steps) + ) if args.train_batch_size % args.gradient_accumulation_steps != 0: - raise ValueError("Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( - args.gradient_accumulation_steps, args.train_batch_size)) + raise ValueError( + "Invalid gradient_accumulation_steps parameter: {}, batch size {} should be divisible".format( + args.gradient_accumulation_steps, args.train_batch_size + ) + ) args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps if not args.do_train: raise ValueError(" `do_train` must be True.") - if not args.resume_from_checkpoint and os.path.exists(args.output_dir) and ( - os.listdir(args.output_dir) and os.listdir(args.output_dir) != ['logfile.txt']): + if ( + not args.resume_from_checkpoint + and os.path.exists(args.output_dir) + and (os.listdir(args.output_dir) and os.listdir(args.output_dir) != ["logfile.txt"]) + ): raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir)) if not args.resume_from_checkpoint: @@ -261,6 +266,7 @@ def setup_training(args): return device, args + def prepare_model_and_optimizer(args, device): # Prepare model @@ -277,11 +283,11 @@ def prepare_model_and_optimizer(args, device): else: if args.resume_step == -1: model_names = [f for f in os.listdir(args.output_dir) if f.endswith(".pt")] - args.resume_step = max([int(x.split('.pt')[0].split('_')[1].strip()) for x in model_names]) + args.resume_step = max([int(x.split(".pt")[0].split("_")[1].strip()) for x in model_names]) global_step = args.resume_step checkpoint = torch.load(os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)), map_location="cpu") - model.load_state_dict(checkpoint['model'], strict=False) + model.load_state_dict(checkpoint["model"], strict=False) if args.phase2: global_step -= args.phase1_end_step if is_main_process(): @@ -289,8 +295,8 @@ def prepare_model_and_optimizer(args, device): model.to(device) param_optimizer = list(model.named_parameters()) - no_decay = ['bias', 'gamma', 'beta', 'LayerNorm'] - + no_decay = ["bias", "gamma", "beta", "LayerNorm"] + optimizer_grouped_parameters = [] names = [] @@ -298,58 +304,68 @@ def prepare_model_and_optimizer(args, device): for n, p in param_optimizer: count += 1 if not any(nd in n for nd in no_decay): - optimizer_grouped_parameters.append({'params': [p], 'weight_decay': 0.01, 'name': n}) - names.append({'params': [n], 'weight_decay': 0.01}) + optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.01, "name": n}) + names.append({"params": [n], "weight_decay": 0.01}) if any(nd in n for nd in no_decay): - optimizer_grouped_parameters.append({'params': [p], 'weight_decay': 0.00, 'name': n}) - names.append({'params': [n], 'weight_decay': 0.00}) + optimizer_grouped_parameters.append({"params": [p], "weight_decay": 0.00, "name": n}) + names.append({"params": [n], "weight_decay": 0.00}) - optimizer = BertLAMB(optimizer_grouped_parameters, - lr=args.learning_rate, - warmup=args.warmup_proportion, - t_total=args.max_steps) + optimizer = BertLAMB( + optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup_proportion, t_total=args.max_steps + ) if args.fp16: if args.loss_scale == 0: # optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True) - model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale="dynamic", - master_weights=False if args.accumulate_into_fp16 else True) + model, optimizer = amp.initialize( + model, + optimizer, + opt_level="O2", + loss_scale="dynamic", + master_weights=False if args.accumulate_into_fp16 else True, + ) else: # optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale) - model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale, - master_weights=False if args.accumulate_into_fp16 else True) + model, optimizer = amp.initialize( + model, + optimizer, + opt_level="O2", + loss_scale=args.loss_scale, + master_weights=False if args.accumulate_into_fp16 else True, + ) amp._amp_state.loss_scalers[0]._loss_scale = 2**20 if args.resume_from_checkpoint: if args.phase2: - keys = list(checkpoint['optimizer']['state'].keys()) - #Override hyperparameters from Phase 1 + keys = list(checkpoint["optimizer"]["state"].keys()) + # Override hyperparameters from Phase 1 for key in keys: - checkpoint['optimizer']['state'][key]['step'] = global_step - for iter, item in enumerate(checkpoint['optimizer']['param_groups']): - checkpoint['optimizer']['param_groups'][iter]['t_total'] = args.max_steps - checkpoint['optimizer']['param_groups'][iter]['warmup'] = args.warmup_proportion - checkpoint['optimizer']['param_groups'][iter]['lr'] = args.learning_rate - optimizer.load_state_dict(checkpoint['optimizer']) # , strict=False) - - # Restore AMP master parameters + checkpoint["optimizer"]["state"][key]["step"] = global_step + for iter, item in enumerate(checkpoint["optimizer"]["param_groups"]): + checkpoint["optimizer"]["param_groups"][iter]["t_total"] = args.max_steps + checkpoint["optimizer"]["param_groups"][iter]["warmup"] = args.warmup_proportion + checkpoint["optimizer"]["param_groups"][iter]["lr"] = args.learning_rate + optimizer.load_state_dict(checkpoint["optimizer"]) # , strict=False) + + # Restore AMP master parameters if args.fp16: optimizer._lazy_init_maybe_master_weights() optimizer._amp_stash.lazy_init_called = True - optimizer.load_state_dict(checkpoint['optimizer']) - for param, saved_param in zip(amp.master_params(optimizer), checkpoint['master params']): + optimizer.load_state_dict(checkpoint["optimizer"]) + for param, saved_param in zip(amp.master_params(optimizer), checkpoint["master params"]): param.data.copy_(saved_param.data) if args.local_rank != -1: if not args.allreduce_post_accumulation: model = DDP(model, message_size=250000000, gradient_predivide_factor=torch.distributed.get_world_size()) else: - flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,) ) + flat_dist_call([param.data for param in model.parameters()], torch.distributed.broadcast, (0,)) elif args.n_gpu > 1: model = torch.nn.DataParallel(model) return model, optimizer, checkpoint, global_step + def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): if args.allreduce_post_accumulation: @@ -360,22 +376,21 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): master_grads = [p.grad for p in amp.master_params(optimizer) if p.grad is not None] flat_grad_size = sum(p.numel() for p in master_grads) allreduce_dtype = torch.float16 if args.allreduce_post_accumulation_fp16 else torch.float32 - flat_raw = torch.empty(flat_grad_size, device='cuda', dtype=allreduce_dtype) + flat_raw = torch.empty(flat_grad_size, device="cuda", dtype=allreduce_dtype) # 2. combine unflattening and predivision of unscaled 'raw' gradient allreduced_views = apex_C.unflatten(flat_raw, master_grads) overflow_buf.zero_() - amp_C.multi_tensor_scale(65536, + amp_C.multi_tensor_scale( + 65536, overflow_buf, [master_grads, allreduced_views], - scaler.loss_scale() / (torch.distributed.get_world_size() * args.gradient_accumulation_steps)) + scaler.loss_scale() / (torch.distributed.get_world_size() * args.gradient_accumulation_steps), + ) # 3. sum gradient across ranks. Because of the predivision, this averages the gradient torch.distributed.all_reduce(flat_raw) # 4. combine unscaling and unflattening of allreduced gradient overflow_buf.zero_() - amp_C.multi_tensor_scale(65536, - overflow_buf, - [allreduced_views, master_grads], - 1./scaler.loss_scale()) + amp_C.multi_tensor_scale(65536, overflow_buf, [allreduced_views, master_grads], 1.0 / scaler.loss_scale()) # 5. update loss scale scaler = _amp_state.loss_scalers[0] old_overflow_buf = scaler._overflow_buf @@ -389,10 +404,11 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): else: # Overflow detected, print message and clear gradients if is_main_process(): - print(("Rank {} :: Gradient overflow. Skipping step, " + - "reducing loss scale to {}").format( - torch.distributed.get_rank(), - scaler.loss_scale())) + print( + ("Rank {} :: Gradient overflow. Skipping step, " + "reducing loss scale to {}").format( + torch.distributed.get_rank(), scaler.loss_scale() + ) + ) if _amp_state.opt_properties.master_weights: for param in optimizer._amp_stash.all_fp32_from_fp16_params: param.grad = None @@ -400,13 +416,14 @@ def take_optimizer_step(args, optimizer, model, overflow_buf, global_step): param.grad = None else: optimizer.step() - #optimizer.zero_grad() + # optimizer.zero_grad() for param in model.parameters(): param.grad = None global_step += 1 return global_step + def main(): args = parse_arguments() @@ -444,50 +461,66 @@ def main(): while True: thread = None if not args.resume_from_checkpoint or epoch > 0 or args.phase2: - files = [os.path.join(args.input_dir, f) for f in os.listdir(args.input_dir) if - os.path.isfile(os.path.join(args.input_dir, f))] + files = [ + os.path.join(args.input_dir, f) + for f in os.listdir(args.input_dir) + if os.path.isfile(os.path.join(args.input_dir, f)) + ] files.sort() num_files = len(files) random.shuffle(files) f_start_id = 0 else: - f_start_id = checkpoint['files'][0] - files = checkpoint['files'][1:] + f_start_id = checkpoint["files"][0] + files = checkpoint["files"][1:] args.resume_from_checkpoint = False num_files = len(files) - print('File list is [' + ','.join(files) + '].') - + print("File list is [" + ",".join(files) + "].") shared_file_list = {} if torch.distributed.is_initialized() and torch.distributed.get_world_size() > num_files: remainder = torch.distributed.get_world_size() % num_files - data_file = files[(f_start_id*torch.distributed.get_world_size()+torch.distributed.get_rank() + remainder*f_start_id)%num_files] + data_file = files[ + ( + f_start_id * torch.distributed.get_world_size() + + torch.distributed.get_rank() + + remainder * f_start_id + ) + % num_files + ] else: - data_file = files[f_start_id%num_files] + data_file = files[f_start_id % num_files] previous_file = data_file - print('Create pretraining_dataset with file {}...'.format(data_file)) + print("Create pretraining_dataset with file {}...".format(data_file)) train_data = pretraining_dataset(data_file, args.max_predictions_per_seq) train_sampler = RandomSampler(train_data) - train_dataloader = DataLoader(train_data, sampler=train_sampler, - batch_size=args.train_batch_size * args.n_gpu, num_workers=4, - pin_memory=True) + train_dataloader = DataLoader( + train_data, + sampler=train_sampler, + batch_size=args.train_batch_size * args.n_gpu, + num_workers=4, + pin_memory=True, + ) # shared_file_list["0"] = (train_dataloader, data_file) overflow_buf = None if args.allreduce_post_accumulation: overflow_buf = torch.cuda.IntTensor([0]) - for f_id in range(f_start_id + 1 , len(files)): - + for f_id in range(f_start_id + 1, len(files)): + # torch.cuda.synchronize() - # f_start = time.time() + # f_start = time.time() if torch.distributed.is_initialized() and torch.distributed.get_world_size() > num_files: - data_file = files[(f_id*torch.distributed.get_world_size()+torch.distributed.get_rank() + remainder*f_id)%num_files] + data_file = files[ + (f_id * torch.distributed.get_world_size() + torch.distributed.get_rank() + remainder * f_id) + % num_files + ] else: - data_file = files[f_id%num_files] + data_file = files[f_id % num_files] logger.info("file no %s file %s" % (f_id, previous_file)) @@ -501,8 +534,10 @@ def main(): # args=(data_file, args.max_predictions_per_seq, shared_file_list, args, n_gpu) # ) # thread.start() - print('Submit new data file {0} for the next iteration...'.format(data_file)) - dataset_future = pool.submit(create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args) + print("Submit new data file {0} for the next iteration...".format(data_file)) + dataset_future = pool.submit( + create_pretraining_dataset, data_file, args.max_predictions_per_seq, shared_file_list, args + ) # torch.cuda.synchronize() # f_end = time.time() # print('[{}] : shard overhead {}'.format(torch.distributed.get_rank(), f_end - f_start)) @@ -516,37 +551,64 @@ def main(): batch = [t.to(device) for t in batch] input_ids, segment_ids, input_mask, masked_lm_labels, next_sentence_labels = batch if not is_model_exported: - onnx_path = os.path.join(args.output_dir, 'bert_for_pretraining_without_loss_' + config.to_string() + '.onnx') - lm_score, sq_score = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask) - torch.onnx.export(model, (input_ids, segment_ids, input_mask), onnx_path, - verbose = True, - #input_names = ['input_ids', 'token_type_ids', 'input_mask'], - input_names = ['input1', 'input2', 'input3'], - output_names = ['output1', 'output2'], - dynamic_axes={'input1': {0: 'batch'}, 'input2': {0: 'batch'}, 'input3': {0: 'batch'}, 'output1': {0: 'batch'}, 'output2': {0: 'batch'}}, - training=True) - is_model_exported = False - - import onnxruntime as ort - sess = ort.InferenceSession(onnx_path, providers=ort.get_available_providers()) - result = sess.run(None, {'input1': input_ids.cpu().numpy(), 'input2': segment_ids.cpu().numpy(), 'input3': input_mask.cpu().numpy()}) - - print('---ORT result---') - print(result[0]) - print(result[1]) - - print('---Pytorch result---') - print(lm_score) - print(sq_score) - - print('---ORT-Pytorch Diff---') - print(np.linalg.norm(result[0]-lm_score.detach().cpu().numpy())) - print(np.linalg.norm(result[1]-sq_score.detach().cpu().numpy())) - return - - loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, - masked_lm_labels=masked_lm_labels, next_sentence_label=next_sentence_labels, - checkpoint_activations=args.checkpoint_activations) + onnx_path = os.path.join( + args.output_dir, "bert_for_pretraining_without_loss_" + config.to_string() + ".onnx" + ) + lm_score, sq_score = model( + input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask + ) + torch.onnx.export( + model, + (input_ids, segment_ids, input_mask), + onnx_path, + verbose=True, + # input_names = ['input_ids', 'token_type_ids', 'input_mask'], + input_names=["input1", "input2", "input3"], + output_names=["output1", "output2"], + dynamic_axes={ + "input1": {0: "batch"}, + "input2": {0: "batch"}, + "input3": {0: "batch"}, + "output1": {0: "batch"}, + "output2": {0: "batch"}, + }, + training=True, + ) + is_model_exported = False + + import onnxruntime as ort + + sess = ort.InferenceSession(onnx_path, providers=ort.get_available_providers()) + result = sess.run( + None, + { + "input1": input_ids.cpu().numpy(), + "input2": segment_ids.cpu().numpy(), + "input3": input_mask.cpu().numpy(), + }, + ) + + print("---ORT result---") + print(result[0]) + print(result[1]) + + print("---Pytorch result---") + print(lm_score) + print(sq_score) + + print("---ORT-Pytorch Diff---") + print(np.linalg.norm(result[0] - lm_score.detach().cpu().numpy())) + print(np.linalg.norm(result[1] - sq_score.detach().cpu().numpy())) + return + + loss = model( + input_ids=input_ids, + token_type_ids=segment_ids, + attention_mask=input_mask, + masked_lm_labels=masked_lm_labels, + next_sentence_label=next_sentence_labels, + checkpoint_activations=args.checkpoint_activations, + ) if args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu. @@ -557,7 +619,9 @@ def main(): loss = loss / args.gradient_accumulation_steps divisor = 1.0 if args.fp16: - with amp.scale_loss(loss, optimizer, delay_overflow_check=args.allreduce_post_accumulation) as scaled_loss: + with amp.scale_loss( + loss, optimizer, delay_overflow_check=args.allreduce_post_accumulation + ) as scaled_loss: scaled_loss.backward() else: loss.backward() @@ -571,36 +635,49 @@ def main(): last_num_steps = args.log_freq if last_num_steps == 0 else last_num_steps average_loss = torch.tensor(average_loss, dtype=torch.float32).cuda() average_loss = average_loss / (last_num_steps * divisor) - if (torch.distributed.is_initialized()): + if torch.distributed.is_initialized(): average_loss /= torch.distributed.get_world_size() torch.distributed.all_reduce(average_loss) if is_main_process(): logger.info("Total Steps:{} Final Loss = {}".format(training_steps, average_loss.item())) elif training_steps % (args.log_freq * args.gradient_accumulation_steps) == 0: if is_main_process(): - print("Step:{} Average Loss = {} Step Loss = {} LR {}".format(global_step, average_loss / ( - args.log_freq * divisor), - loss.item() * args.gradient_accumulation_steps / divisor, - optimizer.param_groups[0][ - 'lr'])) + print( + "Step:{} Average Loss = {} Step Loss = {} LR {}".format( + global_step, + average_loss / (args.log_freq * divisor), + loss.item() * args.gradient_accumulation_steps / divisor, + optimizer.param_groups[0]["lr"], + ) + ) average_loss = 0 - if global_step >= args.max_steps or training_steps % ( - args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0: + if ( + global_step >= args.max_steps + or training_steps % (args.num_steps_per_checkpoint * args.gradient_accumulation_steps) == 0 + ): if is_main_process(): # Save a trained model logger.info("** ** * Saving fine - tuned model ** ** * ") - model_to_save = model.module if hasattr(model, - 'module') else model # Only save the model it-self + model_to_save = ( + model.module if hasattr(model, "module") else model + ) # Only save the model it-self if args.resume_step < 0 or not args.phase2: output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step)) else: - output_save_file = os.path.join(args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step)) + output_save_file = os.path.join( + args.output_dir, "ckpt_{}.pt".format(global_step + args.phase1_end_step) + ) if args.do_train: - torch.save({'model': model_to_save.state_dict(), - 'optimizer': optimizer.state_dict(), - 'master params': list(amp.master_params(optimizer)), - 'files': [f_id] + files}, output_save_file) + torch.save( + { + "model": model_to_save.state_dict(), + "optimizer": optimizer.state_dict(), + "master params": list(amp.master_params(optimizer)), + "files": [f_id] + files, + }, + output_save_file, + ) most_recent_ckpts_paths.append(output_save_file) if len(most_recent_ckpts_paths) > 3: @@ -612,7 +689,6 @@ def main(): # thread.join() return args - # torch.cuda.synchronize() # iter_end = time.time() diff --git a/orttraining/tools/scripts/opset12_model_transform.py b/orttraining/tools/scripts/opset12_model_transform.py index 3a56f63d43f2c..b3cd82e470d66 100644 --- a/orttraining/tools/scripts/opset12_model_transform.py +++ b/orttraining/tools/scripts/opset12_model_transform.py @@ -1,15 +1,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. # -# This converter is an internal util to upgrade existing bert/gpt-2 models, -# which were previously transformed/optimized from orginal model, to Opset 12 -# version as well as replacing deprecated node, i.e., TrainableDropout with -# the "Dropout" node matching the Opset 12 Spec. Typically, a model to -# be run by this scripts would have "_optimized" substring in its model name, -# and the graph should have one or more "TrainableDropout" nodes in its graph. -# Example usage: +# This converter is an internal util to upgrade existing bert/gpt-2 models, +# which were previously transformed/optimized from orginal model, to Opset 12 +# version as well as replacing deprecated node, i.e., TrainableDropout with +# the "Dropout" node matching the Opset 12 Spec. Typically, a model to +# be run by this scripts would have "_optimized" substring in its model name, +# and the graph should have one or more "TrainableDropout" nodes in its graph. +# Example usage: # python opset12_model_transform.py bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm.onnx -# Output: +# Output: # bert-base-uncased_L_12_H_768_A_12_V_30528_S_512_Dp_0.1_optimized_layer_norm_opset12.onnx import sys @@ -24,18 +24,19 @@ exit(1) input_model_name = sys.argv[1] -output_model_name = input_model_name[:-5] + '_opset12.onnx' +output_model_name = input_model_name[:-5] + "_opset12.onnx" model = onnx.load(input_model_name) # for a given node input, look thru the graph nodes and find the node -# whose output is matching the input +# whose output is matching the input def find_input_node(model, arg): result = [] for node in model.graph.node: for output in node.output: if output == arg: result.append(node) - return result[0] if len(result)== 1 else None + return result[0] if len(result) == 1 else None + def get_node_index(model, node): for i, graph_node in enumerate(model.graph.node): @@ -43,13 +44,14 @@ def get_node_index(model, node): return i return None -def add_const(model, name, output, t_value = None, f_value = None): + +def add_const(model, name, output, t_value=None, f_value=None): const_node = model.graph.node.add() - const_node.op_type = 'Constant' + const_node.op_type = "Constant" const_node.name = name const_node.output.extend([output]) attr = const_node.attribute.add() - attr.name = 'value' + attr.name = "value" if t_value is not None: attr.type = 4 attr.t.CopyFrom(t_value) @@ -58,34 +60,39 @@ def add_const(model, name, output, t_value = None, f_value = None): attr.f = f_value return const_node + def process_trainabledropout(model): - delete_nodes = [] + delete_nodes = [] index = 0 for node in model.graph.node: - if node.op_type == 'TrainableDropout': + if node.op_type == "TrainableDropout": new_dropout = model.graph.node.add() - new_dropout.op_type = 'Dropout' - new_dropout.name = 'Dropout_%d' % index + new_dropout.op_type = "Dropout" + new_dropout.name = "Dropout_%d" % index # add seed attribute attr = new_dropout.attribute.add() - attr.name = 'seed' + attr.name = "seed" attr.type = 2 # find old ratio node ratio_node = find_input_node(model, node.input[1]) - assert ratio_node.op_type == 'Constant' + assert ratio_node.op_type == "Constant" delete_nodes.append(get_node_index(model, ratio_node)) - # make ratio scalar node + # make ratio scalar node ratio_attr = ratio_node.attribute ratio_data = numpy_helper.to_array(ratio_attr[0].t) ratio_scalar = ratio_data.astype(np.float32).reshape(()) ratio_value = numpy_helper.from_array(ratio_scalar, "ratio") - new_ratio_node = add_const(model, 'dropout_ratio_node_%d' % index, 'dropout_ratio_%d' % index, t_value=ratio_value) - index+=1 + new_ratio_node = add_const( + model, "dropout_ratio_node_%d" % index, "dropout_ratio_%d" % index, t_value=ratio_value + ) + index += 1 # add training_mode output mode_scalar = np.asarray([True]).astype(np.bool).reshape(()) mode_value = numpy_helper.from_array(mode_scalar, "training_mode") - training_mode_node = add_const(model, 'dropout_training_mode_node_%d' % index, 'dropout_training_mode_%d' % index, t_value=mode_value) - index+=1 + training_mode_node = add_const( + model, "dropout_training_mode_node_%d" % index, "dropout_training_mode_%d" % index, t_value=mode_value + ) + index += 1 new_dropout.input.extend([node.input[0], new_ratio_node.output[0], training_mode_node.output[0]]) new_dropout.output.extend(node.output) @@ -96,23 +103,24 @@ def process_trainabledropout(model): for d in delete_nodes: del model.graph.node[d] + def align_attention_mask_dim(model): for model_input in model.graph.input: if model_input.name == "attention_mask": model_input.type.tensor_type.shape.dim[0].dim_param = "batch" -#replace TrainableDropout with Dropout +# replace TrainableDropout with Dropout process_trainabledropout(model) # some gpt-2 models (large ones) still don't have this input corrected align_attention_mask_dim(model) -#set opset version to 12 +# set opset version to 12 model.opset_import[0].version = 12 -with open (output_model_name, "wb") as f: +with open(output_model_name, "wb") as f: f.write(model.SerializeToString()) # -# To verify the converted model in case of bert, refer to the code at the end of model_transform.py +# To verify the converted model in case of bert, refer to the code at the end of model_transform.py # diff --git a/orttraining/tools/scripts/performance_investigation.py b/orttraining/tools/scripts/performance_investigation.py index 295eebe2d58ae..b064b13fa6d34 100644 --- a/orttraining/tools/scripts/performance_investigation.py +++ b/orttraining/tools/scripts/performance_investigation.py @@ -1,12 +1,11 @@ import argparse import onnx -parser = argparse.ArgumentParser(description='ONNX file analyzer for performance investigation.') -parser.add_argument('onnx_file', type=str, help='ONNX file to analyze') +parser = argparse.ArgumentParser(description="ONNX file analyzer for performance investigation.") +parser.add_argument("onnx_file", type=str, help="ONNX file to analyze") args = parser.parse_args() - def process_file(onnx_file): model = onnx.load(onnx_file) @@ -41,38 +40,42 @@ def process_file(onnx_file): if node.op_type == "Dropout" and len(node.input) == 1: prev = output_to_node[node.input[0]] if prev.op_type == "Add": - msgs.append(f"Examine whether {node.name} should be fused with the leading {prev.name} op into BiasDropout node.") + msgs.append( + f"Examine whether {node.name} should be fused with the leading {prev.name} op into BiasDropout node." + ) # Look for stand-alone Softmax node in *_execution_model_.onnx graph. # Examine whether it should be fused with the leading Add ops into BiasSoftmax node. if node.op_type == "Softmax" and len(node.input) == 1: prev = output_to_node[node.input[0]] if prev.op_type == "Add": - msgs.append(f"Examine whether {node.name} should be fused with the leading {prev.name} op into BiasSoftmax node.") + msgs.append( + f"Examine whether {node.name} should be fused with the leading {prev.name} op into BiasSoftmax node." + ) if aten_ops: print("ATen op found:") for line in aten_ops: print(line) - print(10 * '-') + print(10 * "-") if python_ops: print("PythonOp found:") for line in python_ops: print(line) - print(10 * '-') + print(10 * "-") if memcpu_ops: print("Memcpu ops found:") for line in memcpu_ops: print(line) - print(10 * '-') + print(10 * "-") if cast_ops: print("Cast ops found:") for line in cast_ops: print(line) - print(10 * '-') + print(10 * "-") for line in msgs: print(line) @@ -81,5 +84,6 @@ def process_file(onnx_file): def main(): process_file(args.onnx_file) + if __name__ == "__main__": main() diff --git a/orttraining/tools/scripts/pipeline_model_split.py b/orttraining/tools/scripts/pipeline_model_split.py index 008e626e3257d..b95bbe49003ec 100644 --- a/orttraining/tools/scripts/pipeline_model_split.py +++ b/orttraining/tools/scripts/pipeline_model_split.py @@ -4,6 +4,7 @@ from onnx import helper from onnx import TensorProto from onnx import OperatorSetIdProto + # Edge that needs to be cut for the split. # If the edge is feeding into more than one nodes, and not all the nodes belong to the same cut, # specify those consuming nodes that need to be cut @@ -20,11 +21,12 @@ def add_expand_type(model, name, type): expand_edge.name = name expand_edge.type.CopyFrom(type) + # Add wait/record/send/recv nodes and split the graph into disconnected subgraphs def split_graph(model, split_edge_groups): - ms_domain = 'com.microsoft' + ms_domain = "com.microsoft" new_send_nodes = [] new_recv_nodes = [] @@ -50,67 +52,67 @@ def split_graph(model, split_edge_groups): if info.name == id: output_shapes.append(info.type) - send_input_signal_name = 'send_input_signal' + str(cut_index) + send_input_signal_name = "send_input_signal" + str(cut_index) send_signal = model.graph.input.add() - send_signal.CopyFrom(helper.make_tensor_value_info( - send_input_signal_name, onnx.TensorProto.BOOL, None)) - send_signal = helper.make_tensor( - send_input_signal_name, TensorProto.BOOL, (), (True,)) + send_signal.CopyFrom(helper.make_tensor_value_info(send_input_signal_name, onnx.TensorProto.BOOL, None)) + send_signal = helper.make_tensor(send_input_signal_name, TensorProto.BOOL, (), (True,)) model.graph.initializer.extend([send_signal]) - recv_input_signal_name = 'recv_input_signal' + str(cut_index) + recv_input_signal_name = "recv_input_signal" + str(cut_index) recv_signal = model.graph.input.add() - recv_signal.CopyFrom(helper.make_tensor_value_info( - recv_input_signal_name, onnx.TensorProto.BOOL, None)) - recv_signal = helper.make_tensor( - recv_input_signal_name, TensorProto.BOOL, (), (True,)) + recv_signal.CopyFrom(helper.make_tensor_value_info(recv_input_signal_name, onnx.TensorProto.BOOL, None)) + recv_signal = helper.make_tensor(recv_input_signal_name, TensorProto.BOOL, (), (True,)) model.graph.initializer.extend([recv_signal]) - send_dst_rank_name = 'send_dst_rank' + str(cut_index) + send_dst_rank_name = "send_dst_rank" + str(cut_index) send_dst_rank = model.graph.input.add() - send_dst_rank.CopyFrom(helper.make_tensor_value_info( - send_dst_rank_name, onnx.TensorProto.INT64, None)) - send_dst_rank = helper.make_tensor( - send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,)) + send_dst_rank.CopyFrom(helper.make_tensor_value_info(send_dst_rank_name, onnx.TensorProto.INT64, None)) + send_dst_rank = helper.make_tensor(send_dst_rank_name, TensorProto.INT64, (), (cut_index + 1,)) model.graph.initializer.extend([send_dst_rank]) - recv_src_rank_name = 'recv_src_rank' + str(cut_index) + recv_src_rank_name = "recv_src_rank" + str(cut_index) recv_src_rank = model.graph.input.add() - recv_src_rank.CopyFrom(helper.make_tensor_value_info( - recv_src_rank_name, onnx.TensorProto.INT64, None)) - recv_src_rank = helper.make_tensor( - recv_src_rank_name, TensorProto.INT64, (), (cut_index,)) + recv_src_rank.CopyFrom(helper.make_tensor_value_info(recv_src_rank_name, onnx.TensorProto.INT64, None)) + recv_src_rank = helper.make_tensor(recv_src_rank_name, TensorProto.INT64, (), (cut_index,)) model.graph.initializer.extend([recv_src_rank]) # output signal from send after cut send_output_signal = model.graph.output.add() - send_output_signal.CopyFrom(helper.make_tensor_value_info( - 'send_output_signal' + str(cut_index), onnx.TensorProto.BOOL, None)) + send_output_signal.CopyFrom( + helper.make_tensor_value_info("send_output_signal" + str(cut_index), onnx.TensorProto.BOOL, None) + ) # output signal from receive after cut receive_output_signal = model.graph.output.add() - receive_output_signal.CopyFrom(helper.make_tensor_value_info( - 'receive_output_signal' + str(cut_index), onnx.TensorProto.BOOL, None)) + receive_output_signal.CopyFrom( + helper.make_tensor_value_info("receive_output_signal" + str(cut_index), onnx.TensorProto.BOOL, None) + ) new_send = model.graph.node.add() - new_send.CopyFrom(helper.make_node( - 'Send', - inputs=[send_input_signal_name, send_dst_rank_name], - outputs=['send_output_signal' + str(cut_index)], - tag=0, - domain=ms_domain, - element_types=element_types, - name='send')) + new_send.CopyFrom( + helper.make_node( + "Send", + inputs=[send_input_signal_name, send_dst_rank_name], + outputs=["send_output_signal" + str(cut_index)], + tag=0, + domain=ms_domain, + element_types=element_types, + name="send", + ) + ) new_receive = model.graph.node.add() - new_receive.CopyFrom(helper.make_node( - 'Recv', - inputs=[recv_input_signal_name, recv_src_rank_name], - outputs=['receive_output_signal' + str(cut_index)], - tag=0, - domain=ms_domain, - element_types=element_types, - name='receive')) + new_receive.CopyFrom( + helper.make_node( + "Recv", + inputs=[recv_input_signal_name, recv_src_rank_name], + outputs=["receive_output_signal" + str(cut_index)], + tag=0, + domain=ms_domain, + element_types=element_types, + name="receive", + ) + ) for i in range(len(upstream_nodes)): n = upstream_nodes[i] @@ -118,15 +120,13 @@ def split_graph(model, split_edge_groups): output_type = output_shapes[i] output_edge_name = n.output[idx] - output_nodes = find_all_output_nodes_by_edge( - model, output_edge_name) + output_nodes = find_all_output_nodes_by_edge(model, output_edge_name) # deal with shape inference for newly added edge - new_send_input_name = output_edge_name + '_send' + str(cut_index) + new_send_input_name = output_edge_name + "_send" + str(cut_index) add_expand_type(model, new_send_input_name, output_type) - new_receive_output_name = output_edge_name + \ - '_recv' + str(cut_index) + new_receive_output_name = output_edge_name + "_recv" + str(cut_index) add_expand_type(model, new_receive_output_name, output_type) # the order of data flow is: node-output -> record -> send -> recv -> wait -> node-input @@ -165,8 +165,7 @@ def find_all_output_nodes(model, node): if node: for outputId in node.output: nodes.extend([n for n in model.graph.node if outputId in n.input]) - outputs.extend( - [n for n in model.graph.output if outputId in n.name]) + outputs.extend([n for n in model.graph.output if outputId in n.name]) return nodes, outputs @@ -174,6 +173,7 @@ def find_all_output_nodes_by_edge(model, arg): result = [n for n in model.graph.node if arg in n.input] return result + # Insert identity nodes to separate same output edge which feeds into different sub-graph. @@ -190,7 +190,7 @@ def add_identity(model, cuttingEdge, newEdgeIdName): assert output_nodes, "no output node" new_identity = model.graph.node.add() - new_identity.op_type = 'Identity' + new_identity.op_type = "Identity" new_identity.input.extend([edgeId]) new_identity.output.extend([newEdgeIdName]) @@ -221,9 +221,8 @@ def insert_identity(model, all_cut_inputs): if i.edgeId in updated_edges: i.edgeId = updated_edges[i.edgeId] - new_edge_name = 'identity_output_' + str(count) - new_added_identity.append( - add_identity(model, i, new_edge_name)) + new_edge_name = "identity_output_" + str(count) + new_added_identity.append(add_identity(model, i, new_edge_name)) count += 1 split_edges.append(new_edge_name) updated_edges[i.edgeId] = new_edge_name @@ -233,19 +232,20 @@ def insert_identity(model, all_cut_inputs): split_edge_groups.append(split_edges) return split_edge_groups, new_added_identity, need_shape_inference + # after the graph is split, remove the added identity node because identity op is not registered in gradient builder. def remove_identity(model, new_added_identity): for node in new_added_identity: - assert node.op_type == 'Identity' - output_nodes = [ - n for n in model.graph.node if node.output[0] in n.input] + assert node.op_type == "Identity" + output_nodes = [n for n in model.graph.node if node.output[0] in n.input] for output_node in output_nodes: for i in range(len(output_node.input)): if output_node.input[i] == node.output[0]: output_node.input[i] = node.input[0] + def find_all_connected_nodes(model, node): nodes0, inputs = find_all_input_nodes(model, node) nodes1, outputs = find_all_output_nodes(model, node) @@ -258,14 +258,16 @@ def get_index(node_list, node): found = [i for i, n in enumerate(node_list) if n == node] return found[0] if found else None + def get_identity_index_for_deleting(node_list, node): for i, n in enumerate(node_list): # The node's input name has been changed during send/recv insertion, # but it is sufficient to just compare the type and outputs. - if (n.op_type == 'Identity' and n.output == node.output): + if n.op_type == "Identity" and n.output == node.output: return i return None + # traverse the graph, group connected nodes and generate subgraph @@ -303,8 +305,7 @@ def generate_subgraph(model, start_nodes, identity_node_list): tranversed_node += 1 visited0.append(node) all_visited_nodes.append(node) - connected_nodes, inputs, outputs = find_all_connected_nodes( - main_graph, node) + connected_nodes, inputs, outputs = find_all_connected_nodes(main_graph, node) stack0 = stack0 + connected_nodes inputs0 = inputs0 + inputs @@ -328,8 +329,7 @@ def generate_subgraph(model, start_nodes, identity_node_list): # gather visited outputs visited_outputs = [] for n in outputs0: - visited_outputs.append( - get_index(main_graph.graph.output, n)) + visited_outputs.append(get_index(main_graph.graph.output, n)) visited_outputs.sort(reverse=True) for i in reversed(range(len(main_graph.graph.node))): @@ -374,11 +374,11 @@ def generate_subgraph(model, start_nodes, identity_node_list): def main(): # temporary hard coded the cutting edge structure # TODO: move this info to a file (json?) and load the data from there. - input_model_name = 'bert-tiny-uncased_L_3_H_128_A_2_V_30528_S_512_Dp_0.1.onnx' + input_model_name = "bert-tiny-uncased_L_3_H_128_A_2_V_30528_S_512_Dp_0.1.onnx" stage_count = 3 - cut0_input = {CutEdge('186'), CutEdge('71', {'273', '395'})} - cut1_input = {CutEdge('308'), CutEdge('71', {'395'})} + cut0_input = {CutEdge("186"), CutEdge("71", {"273", "395"})} + cut1_input = {CutEdge("308"), CutEdge("71", {"395"})} all_cut_inputs = [cut0_input, cut1_input] model = onnx.load(input_model_name) @@ -387,12 +387,10 @@ def main(): print("original model length ", len(model.graph.node)) - output_model_names = [os.path.splitext(input_model_name)[0] + '_' + - str(i) + '.onnx' for i in range(stage_count)] + output_model_names = [os.path.splitext(input_model_name)[0] + "_" + str(i) + ".onnx" for i in range(stage_count)] split_edge_groups, new_identity, need_shape_inference = insert_identity(model, all_cut_inputs) - # new edge is being added, need to re-inference shape if need_shape_inference: model = onnx.shape_inference.infer_shapes(model) diff --git a/orttraining/tools/scripts/sqldb_to_tensors.py b/orttraining/tools/scripts/sqldb_to_tensors.py index 0a512a52c3683..cf24e0c294450 100644 --- a/orttraining/tools/scripts/sqldb_to_tensors.py +++ b/orttraining/tools/scripts/sqldb_to_tensors.py @@ -5,15 +5,18 @@ import onnx from onnx import numpy_helper -connection = sqlite3.connect('', detect_types=sqlite3.PARSE_DECLTYPES) +connection = sqlite3.connect("", detect_types=sqlite3.PARSE_DECLTYPES) + def convert_tensor_proto_to_numpy_array(blob): tensor_proto = onnx.TensorProto() tensor_proto.ParseFromString(blob) return numpy_helper.to_array(tensor_proto) + sqlite3.register_converter("TensorProto", convert_tensor_proto_to_numpy_array) for step, name, value, device, producer, consumers in connection.execute( - 'Select Step, Name, Value, DeviceType, TracedProducer, TracedConsumers from Tensors'): + "Select Step, Name, Value, DeviceType, TracedProducer, TracedConsumers from Tensors" +): print(step, name, value.shape, consumers) diff --git a/orttraining/tools/scripts/train.py b/orttraining/tools/scripts/train.py index b0bb6e91c3c6e..0bba0945fa02d 100644 --- a/orttraining/tools/scripts/train.py +++ b/orttraining/tools/scripts/train.py @@ -1,6 +1,6 @@ import os import sys -cmd = '/workspace/onnxruntime_training_bert {}'.format(' '.join(sys.argv[1:])) +cmd = "/workspace/onnxruntime_training_bert {}".format(" ".join(sys.argv[1:])) print(cmd) os.system(cmd) diff --git a/orttraining/tools/scripts/watch_experiment.py b/orttraining/tools/scripts/watch_experiment.py index eab92d8c55c77..33bb73f8dc9b9 100644 --- a/orttraining/tools/scripts/watch_experiment.py +++ b/orttraining/tools/scripts/watch_experiment.py @@ -10,20 +10,26 @@ from azureml._run_impl.run_watcher import RunWatcher parser = argparse.ArgumentParser() -parser.add_argument('--subscription', type=str, default='ea482afa-3a32-437c-aa10-7de928a9e793') # AI Platform GPU - MLPerf -parser.add_argument('--resource_group', type=str, default='onnx_training', help='Azure resource group containing the AzureML Workspace') -parser.add_argument('--workspace', type=str, default='ort_training_dev', help='AzureML Workspace to run the Experiment in') -parser.add_argument('--experiment', type=str, default='BERT-ONNX', help='Name of the AzureML Experiment') -parser.add_argument('--run', type=str, default=None, help='The Experiment run to watch (defaults to the latest run)') +parser.add_argument( + "--subscription", type=str, default="ea482afa-3a32-437c-aa10-7de928a9e793" +) # AI Platform GPU - MLPerf +parser.add_argument( + "--resource_group", type=str, default="onnx_training", help="Azure resource group containing the AzureML Workspace" +) +parser.add_argument( + "--workspace", type=str, default="ort_training_dev", help="AzureML Workspace to run the Experiment in" +) +parser.add_argument("--experiment", type=str, default="BERT-ONNX", help="Name of the AzureML Experiment") +parser.add_argument("--run", type=str, default=None, help="The Experiment run to watch (defaults to the latest run)") -parser.add_argument('--remote_dir', type=str, default=None, help='Specify a remote directory to sync (read) from') -parser.add_argument('--local_dir', type=str, default=None, help='Specify a local directory to sync (write) to') +parser.add_argument("--remote_dir", type=str, default=None, help="Specify a remote directory to sync (read) from") +parser.add_argument("--local_dir", type=str, default=None, help="Specify a local directory to sync (write) to") args = parser.parse_args() # Validate if (args.remote_dir and not args.local_dir) or (not args.remote_dir and args.local_dir): - print('Must specify both remote_dir and local_dir to sync files from Experiment') - sys.exit() + print("Must specify both remote_dir and local_dir to sync files from Experiment") + sys.exit() # Get the AzureML Workspace the Experiment is running in ws = Workspace.get(name=args.workspace, subscription_id=args.subscription, resource_group=args.resource_group) @@ -35,36 +41,46 @@ runs = [r for r in experiment.get_runs()] if len(runs) == 0: - print("No runs found in Experiment '{}'".format(args.experiment)) - sys.exit() + print("No runs found in Experiment '{}'".format(args.experiment)) + sys.exit() run = runs[0] if args.run is not None: - try: - run = next(r for r in runs if r.id == args.run) - except StopIteration: - print("Run id '{}' not found in Experiment '{}'".format(args.run, args.experiment)) - sys.exit() + try: + run = next(r for r in runs if r.id == args.run) + except StopIteration: + print("Run id '{}' not found in Experiment '{}'".format(args.run, args.experiment)) + sys.exit() # Optionally start synchronizing files from Run if args.remote_dir and args.local_dir: - local_root = os.path.normpath(args.local_dir) - remote_root = args.remote_dir + local_root = os.path.normpath(args.local_dir) + remote_root = args.remote_dir - if run.get_status() in ['Completed', 'Failed', 'Canceled']: - print("Downloading Experiment files from remote directory: '{}' to local directory: '{}'".format(remote_root, local_root)) - files = [f for f in run.get_file_names() if f.startswith(remote_root)] - for remote_path in files: - local_path = os.path.join(local_root, os.path.basename(remote_path)) - run.download_file(remote_path, local_path) - else: - executor = ThreadPoolExecutor() - event = Event() - session = Session() + if run.get_status() in ["Completed", "Failed", "Canceled"]: + print( + "Downloading Experiment files from remote directory: '{}' to local directory: '{}'".format( + remote_root, local_root + ) + ) + files = [f for f in run.get_file_names() if f.startswith(remote_root)] + for remote_path in files: + local_path = os.path.join(local_root, os.path.basename(remote_path)) + run.download_file(remote_path, local_path) + else: + executor = ThreadPoolExecutor() + event = Event() + session = Session() - print("Streaming Experiment files from remote directory: '{}' to local directory: '{}'".format(remote_root, local_root)) - watcher = RunWatcher(run, local_root=local_root, remote_root=remote_root, executor=executor, event=event, session=session) - executor.submit(watcher.refresh_requeue) + print( + "Streaming Experiment files from remote directory: '{}' to local directory: '{}'".format( + remote_root, local_root + ) + ) + watcher = RunWatcher( + run, local_root=local_root, remote_root=remote_root, executor=executor, event=event, session=session + ) + executor.submit(watcher.refresh_requeue) # Block until run completes, to keep updating the files (if streaming) run.wait_for_completion(show_output=True) diff --git a/pyproject.toml b/pyproject.toml index ab34de0c9ce5c..9a104f0fbf2e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,16 @@ [tool.black] line-length = 120 +# extend-exclude needs to be a regular expression +extend-exclude = "cmake|onnxruntime/core/flatbuffers/" [tool.isort] profile = "black" line_length = 120 +extend_skip_glob = [ + "cmake/*", + "orttraining/*", + "onnxruntime/core/flatbuffers/*", +] [tool.pydocstyle] convention = "google" diff --git a/samples/python/training/orttrainer/mnist/ort_mnist.py b/samples/python/training/orttrainer/mnist/ort_mnist.py index 14cfc42351dcd..5b28f90a025ff 100644 --- a/samples/python/training/orttrainer/mnist/ort_mnist.py +++ b/samples/python/training/orttrainer/mnist/ort_mnist.py @@ -3,6 +3,7 @@ import argparse import os + import torch import torch.nn as nn import torch.nn.functional as F @@ -29,10 +30,10 @@ def forward(self, input1): # ONNX Runtime training def mnist_model_description(): - return {'inputs': [('input1', ['batch', 784]), - ('label', ['batch'])], - 'outputs': [('loss', [], True), - ('probability', ['batch', 10])]} + return { + "inputs": [("input1", ["batch", 784]), ("label", ["batch"])], + "outputs": [("loss", [], True), ("probability", ["batch", 10])], + } def my_loss(x, target): @@ -54,9 +55,11 @@ def train(log_interval, trainer, device, train_loader, epoch, train_steps): # Stats if batch_idx % log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss)) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, batch_idx * len(data), len(train_loader.dataset), 100.0 * batch_idx / len(train_loader), loss + ) + ) def test(trainer, device, test_loader): @@ -68,43 +71,52 @@ def test(trainer, device, test_loader): data = data.reshape(data.shape[0], -1) # Using fetches around without eval_step to not pass 'target' as input - trainer._train_step_info.fetches = ['probability'] + trainer._train_step_info.fetches = ["probability"] output = F.log_softmax(trainer.eval_step(data), dim=1) trainer._train_step_info.fetches = [] # Stats - test_loss += F.nll_loss(output, target, reduction='sum').item() + test_loss += F.nll_loss(output, target, reduction="sum").item() pred = output.argmax(dim=1, keepdim=True) correct += pred.eq(target.view_as(pred)).sum().item() test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) def main(): # Training settings - parser = argparse.ArgumentParser(description='ONNX Runtime MNIST Example') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--batch-size', type=int, default=20, metavar='N', - help='input batch size for training (default: 20)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=1, metavar='N', - help='number of epochs to train (default: 1)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-path', type=str, default='', - help='Path for Saving the current Model state') + parser = argparse.ArgumentParser(description="ONNX Runtime MNIST Example") + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument( + "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model state") # Basic setup args = parser.parse_args() @@ -117,31 +129,35 @@ def main(): # Data loader train_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True) + datasets.MNIST( + "./data", + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.batch_size, + shuffle=True, + ) if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True) + datasets.MNIST( + "./data", + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) # Modeling model = NeuralNet(784, 500, 10) model_desc = mnist_model_description() optim_config = optim.SGDConfig(lr=args.lr) - opts = {'device': {'id': device}} + opts = {"device": {"id": device}} opts = ORTTrainerOptions(opts) - trainer = ORTTrainer(model, - model_desc, - optim_config, - loss_fn=my_loss, - options=opts) + trainer = ORTTrainer(model, model_desc, optim_config, loss_fn=my_loss, options=opts) # Train loop for epoch in range(1, args.epochs + 1): @@ -154,5 +170,5 @@ def main(): torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/samples/python/training/orttrainer/mnist/pytorch_mnist.py b/samples/python/training/orttrainer/mnist/pytorch_mnist.py index 6e52a80dd2bd0..0f62b56d35221 100644 --- a/samples/python/training/orttrainer/mnist/pytorch_mnist.py +++ b/samples/python/training/orttrainer/mnist/pytorch_mnist.py @@ -1,5 +1,6 @@ import argparse import os + import torch import torch.nn as nn import torch.nn.functional as F @@ -26,7 +27,7 @@ def my_loss(x, target, is_train=True): if is_train: return F.nll_loss(F.log_softmax(x, dim=1), target) else: - return F.nll_loss(F.log_softmax(x, dim=1), target, reduction='sum') + return F.nll_loss(F.log_softmax(x, dim=1), target, reduction="sum") # Helpers @@ -43,9 +44,15 @@ def train(args, model, device, train_loader, optimizer, epoch): loss.backward() optimizer.step() if batch_idx % args.log_interval == 0: - print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format( - epoch, batch_idx * len(data), len(train_loader.dataset), - 100. * batch_idx / len(train_loader), loss.item())) + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(data), + len(train_loader.dataset), + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) def test(model, device, test_loader): @@ -64,32 +71,41 @@ def test(model, device, test_loader): test_loss /= len(test_loader.dataset) - print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( - test_loss, correct, len(test_loader.dataset), - 100. * correct / len(test_loader.dataset))) + print( + "\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n".format( + test_loss, correct, len(test_loader.dataset), 100.0 * correct / len(test_loader.dataset) + ) + ) def main(): # Training settings - parser = argparse.ArgumentParser(description='PyTorch MNIST Example') - parser.add_argument('--train-steps', type=int, default=-1, metavar='N', - help='number of steps to train. Set -1 to run through whole dataset (default: -1)') - parser.add_argument('--batch-size', type=int, default=20, metavar='N', - help='input batch size for training (default: 20)') - parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') - parser.add_argument('--epochs', type=int, default=1, metavar='N', - help='number of epochs to train (default: 1)') - parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save-path', type=str, default='', - help='Path for Saving the current Model') + parser = argparse.ArgumentParser(description="PyTorch MNIST Example") + parser.add_argument( + "--train-steps", + type=int, + default=-1, + metavar="N", + help="number of steps to train. Set -1 to run through whole dataset (default: -1)", + ) + parser.add_argument( + "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=1000, metavar="N", help="input batch size for testing (default: 1000)" + ) + parser.add_argument("--epochs", type=int, default=1, metavar="N", help="number of epochs to train (default: 1)") + parser.add_argument("--lr", type=float, default=0.01, metavar="LR", help="learning rate (default: 0.01)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=10, + metavar="N", + help="how many batches to wait before logging training status", + ) + parser.add_argument("--save-path", type=str, default="", help="Path for Saving the current Model") # Basic setup args = parser.parse_args() @@ -101,18 +117,26 @@ def main(): # Data loader train_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])), - batch_size=args.batch_size, shuffle=True) + datasets.MNIST( + "./data", + train=True, + download=True, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.batch_size, + shuffle=True, + ) if args.test_batch_size > 0: test_loader = torch.utils.data.DataLoader( - datasets.MNIST('./data', train=False, transform=transforms.Compose([ - transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])), - batch_size=args.test_batch_size, shuffle=True) + datasets.MNIST( + "./data", + train=False, + transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]), + ), + batch_size=args.test_batch_size, + shuffle=True, + ) # Modeling model = NeuralNet(784, 500, 10).to(device) @@ -129,5 +153,5 @@ def main(): torch.save(model.state_dict(), os.path.join(args.save_path, "mnist_cnn.pt")) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py b/samples/python/training/orttrainer/pytorch_transformer/ort_train.py index baf2d19b205e4..fdca0f20a9385 100644 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_train.py +++ b/samples/python/training/orttrainer/pytorch_transformer/ort_train.py @@ -1,14 +1,15 @@ import argparse -import torch -import onnxruntime -from utils import prepare_data, get_batch +import torch from ort_utils import my_loss, transformer_model_description_dynamic_axes from pt_model import TransformerModel +from utils import get_batch, prepare_data + +import onnxruntime def train(trainer, data_source, device, epoch, args, bptt=35): - total_loss = 0. + total_loss = 0.0 for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): data, targets = get_batch(data_source, i) @@ -16,15 +17,16 @@ def train(trainer, data_source, device, epoch, args, bptt=35): total_loss += loss.item() if batch % args.log_interval == 0 and batch > 0: cur_loss = total_loss / args.log_interval - print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(epoch, - batch, - len(data_source) // bptt, - cur_loss)) + print( + "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( + epoch, batch, len(data_source) // bptt, cur_loss + ) + ) total_loss = 0 def evaluate(trainer, data_source, bptt=35): - total_loss = 0. + total_loss = 0.0 with torch.no_grad(): for i in range(0, data_source.size(0) - 1, bptt): data, targets = get_batch(data_source, i) @@ -35,21 +37,24 @@ def evaluate(trainer, data_source, bptt=35): if __name__ == "__main__": # Training settings - parser = argparse.ArgumentParser(description='PyTorch TransformerModel example') - parser.add_argument('--batch-size', type=int, default=20, metavar='N', - help='input batch size for training (default: 20)') - parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', - help='input batch size for testing (default: 20)') - parser.add_argument('--epochs', type=int, default=2, metavar='N', - help='number of epochs to train (default: 2)') - parser.add_argument('--lr', type=float, default=0.001, metavar='LR', - help='learning rate (default: 0.001)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=200, metavar='N', - help='how many batches to wait before logging training status (default: 200)') + parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") + parser.add_argument( + "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" + ) + parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") + parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=200, + metavar="N", + help="how many batches to wait before logging training status (default: 200)", + ) # Basic setup args = parser.parse_args() @@ -73,12 +78,12 @@ def evaluate(trainer, data_source, bptt=35): for epoch in range(1, args.epochs + 1): train(trainer, train_data, device, epoch, args) val_loss = evaluate(trainer, val_data) - print('-' * 89) - print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss)) - print('-' * 89) + print("-" * 89) + print("| end of epoch {:3d} | valid loss {:5.2f} | ".format(epoch, val_loss)) + print("-" * 89) # Evaluate test_loss = evaluate(trainer, test_data) - print('=' * 89) - print('| End of training | test loss {:5.2f}'.format(test_loss)) - print('=' * 89) + print("=" * 89) + print("| End of training | test loss {:5.2f}".format(test_loss)) + print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py b/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py index 61c419964333d..73992f5596f5f 100644 --- a/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py +++ b/samples/python/training/orttrainer/pytorch_transformer/ort_utils.py @@ -1,7 +1,7 @@ import torch -from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription,\ - ModelDescription as Legacy_ModelDescription +from onnxruntime.capi.ort_trainer import IODescription as Legacy_IODescription +from onnxruntime.capi.ort_trainer import ModelDescription as Legacy_ModelDescription def my_loss(x, target): @@ -10,34 +10,38 @@ def my_loss(x, target): def transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - model_desc = {'inputs': [('input1', [bptt, batch_size]), - ('label', [bptt * batch_size])], - 'outputs': [('loss', [], True), - ('predictions', [bptt, batch_size, ntokens])]} + model_desc = { + "inputs": [("input1", [bptt, batch_size]), ("label", [bptt * batch_size])], + "outputs": [("loss", [], True), ("predictions", [bptt, batch_size, ntokens])], + } return model_desc def transformer_model_description_dynamic_axes(ntokens=28785): - model_desc = {'inputs': [('input1', ['bptt', 'batch_size']), - ('label', ['bptt_x_batch_size'])], - 'outputs': [('loss', [], True), - ('predictions', ['bptt', 'batch_size', ntokens])]} + model_desc = { + "inputs": [("input1", ["bptt", "batch_size"]), ("label", ["bptt_x_batch_size"])], + "outputs": [("loss", [], True), ("predictions", ["bptt", "batch_size", ntokens])], + } return model_desc def legacy_transformer_model_description(bptt=35, batch_size=20, ntokens=28785): - input_desc = Legacy_IODescription('input1', [bptt, batch_size]) - label_desc = Legacy_IODescription('label', [bptt * batch_size]) - loss_desc = Legacy_IODescription('loss', []) - predictions_desc = Legacy_IODescription('predictions', [bptt, batch_size, ntokens]) - return (Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription('__learning_rate', [1])) + input_desc = Legacy_IODescription("input1", [bptt, batch_size]) + label_desc = Legacy_IODescription("label", [bptt * batch_size]) + loss_desc = Legacy_IODescription("loss", []) + predictions_desc = Legacy_IODescription("predictions", [bptt, batch_size, ntokens]) + return ( + Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), + Legacy_IODescription("__learning_rate", [1]), + ) def legacy_transformer_model_description_dynamic_axes(ntokens=28785): - input_desc = Legacy_IODescription('input1', ['bptt', 'batch_size']) - label_desc = Legacy_IODescription('label', ['bptt_x_batch_size']) - loss_desc = Legacy_IODescription('loss', []) - predictions_desc = Legacy_IODescription('predictions', ['bptt', 'batch_size', ntokens]) - return (Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), - Legacy_IODescription('__learning_rate', [1])) + input_desc = Legacy_IODescription("input1", ["bptt", "batch_size"]) + label_desc = Legacy_IODescription("label", ["bptt_x_batch_size"]) + loss_desc = Legacy_IODescription("loss", []) + predictions_desc = Legacy_IODescription("predictions", ["bptt", "batch_size", ntokens]) + return ( + Legacy_ModelDescription([input_desc, label_desc], [loss_desc, predictions_desc]), + Legacy_IODescription("__learning_rate", [1]), + ) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py b/samples/python/training/orttrainer/pytorch_transformer/pt_model.py index 87a938cd4b757..e125124d14718 100644 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_model.py +++ b/samples/python/training/orttrainer/pytorch_transformer/pt_model.py @@ -1,14 +1,15 @@ import math + import torch import torch.nn as nn class TransformerModel(nn.Module): - def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): super(TransformerModel, self).__init__() from torch.nn import TransformerEncoder, TransformerEncoderLayer - self.model_type = 'Transformer' + + self.model_type = "Transformer" self.input1_mask = None self.pos_encoder = PositionalEncoding(ninp, dropout) encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout) @@ -21,7 +22,7 @@ def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5): def _generate_square_subsequent_mask(self, sz): mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1) - mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0)) return mask def init_weights(self): @@ -44,7 +45,6 @@ def forward(self, input1): class PositionalEncoding(nn.Module): - def __init__(self, d_model, dropout=0.1, max_len=5000): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) @@ -55,8 +55,8 @@ def __init__(self, d_model, dropout=0.1, max_len=5000): pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) pe = pe.unsqueeze(0).transpose(0, 1) - self.register_buffer('pe', pe) + self.register_buffer("pe", pe) def forward(self, x): - x = x + self.pe[:x.size(0), :] + x = x + self.pe[: x.size(0), :] return self.dropout(x) diff --git a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py b/samples/python/training/orttrainer/pytorch_transformer/pt_train.py index 68c18842d4931..6dc6032a4acef 100644 --- a/samples/python/training/orttrainer/pytorch_transformer/pt_train.py +++ b/samples/python/training/orttrainer/pytorch_transformer/pt_train.py @@ -1,13 +1,13 @@ import argparse + import torch import torch.nn as nn - -from utils import prepare_data, get_batch from pt_model import TransformerModel +from utils import get_batch, prepare_data def train(model, data_source, device, epoch, args, bptt=35): - total_loss = 0. + total_loss = 0.0 model.train() for batch, i in enumerate(range(0, data_source.size(0) - 1, bptt)): data, targets = get_batch(data_source, i) @@ -21,15 +21,16 @@ def train(model, data_source, device, epoch, args, bptt=35): total_loss += loss.item() if batch % args.log_interval == 0 and batch > 0: cur_loss = total_loss / args.log_interval - print('epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}'.format(epoch, - batch, - len(data_source) // bptt, - cur_loss)) + print( + "epoch {:3d} | {:5d}/{:5d} batches | loss {:5.2f}".format( + epoch, batch, len(data_source) // bptt, cur_loss + ) + ) total_loss = 0 def evaluate(model, data_source, criterion, bptt=35): - total_loss = 0. + total_loss = 0.0 model.eval() with torch.no_grad(): for i in range(0, data_source.size(0) - 1, bptt): @@ -42,21 +43,24 @@ def evaluate(model, data_source, criterion, bptt=35): if __name__ == "__main__": # Training settings - parser = argparse.ArgumentParser(description='PyTorch TransformerModel example') - parser.add_argument('--batch-size', type=int, default=20, metavar='N', - help='input batch size for training (default: 20)') - parser.add_argument('--test-batch-size', type=int, default=20, metavar='N', - help='input batch size for testing (default: 20)') - parser.add_argument('--epochs', type=int, default=2, metavar='N', - help='number of epochs to train (default: 2)') - parser.add_argument('--lr', type=float, default=0.001, metavar='LR', - help='learning rate (default: 0.001)') - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') - parser.add_argument('--log-interval', type=int, default=200, metavar='N', - help='how many batches to wait before logging training status (default: 200)') + parser = argparse.ArgumentParser(description="PyTorch TransformerModel example") + parser.add_argument( + "--batch-size", type=int, default=20, metavar="N", help="input batch size for training (default: 20)" + ) + parser.add_argument( + "--test-batch-size", type=int, default=20, metavar="N", help="input batch size for testing (default: 20)" + ) + parser.add_argument("--epochs", type=int, default=2, metavar="N", help="number of epochs to train (default: 2)") + parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (default: 0.001)") + parser.add_argument("--no-cuda", action="store_true", default=False, help="disables CUDA training") + parser.add_argument("--seed", type=int, default=1, metavar="S", help="random seed (default: 1)") + parser.add_argument( + "--log-interval", + type=int, + default=200, + metavar="N", + help="how many batches to wait before logging training status (default: 200)", + ) # Basic setup args = parser.parse_args() @@ -79,12 +83,12 @@ def evaluate(model, data_source, criterion, bptt=35): for epoch in range(1, args.epochs + 1): train(model, train_data, device, epoch, args) val_loss = evaluate(model, val_data, criterion) - print('-' * 89) - print('| end of epoch {:3d} | valid loss {:5.2f} | '.format(epoch, val_loss)) - print('-' * 89) + print("-" * 89) + print("| end of epoch {:3d} | valid loss {:5.2f} | ".format(epoch, val_loss)) + print("-" * 89) # Evaluate test_loss = evaluate(model, test_data, criterion) - print('=' * 89) - print('| End of training | test loss {:5.2f}'.format(test_loss)) - print('=' * 89) + print("=" * 89) + print("| End of training | test loss {:5.2f}".format(test_loss)) + print("=" * 89) diff --git a/samples/python/training/orttrainer/pytorch_transformer/utils.py b/samples/python/training/orttrainer/pytorch_transformer/utils.py index 0d35fa98035a4..489c0b5acf350 100644 --- a/samples/python/training/orttrainer/pytorch_transformer/utils.py +++ b/samples/python/training/orttrainer/pytorch_transformer/utils.py @@ -1,8 +1,9 @@ import io import os + import torch -from torchtext.utils import download_from_url, extract_archive from torchtext.data.utils import get_tokenizer +from torchtext.utils import download_from_url, extract_archive from torchtext.vocab import build_vocab_from_iterator @@ -18,34 +19,32 @@ def batchify(data, bsz, device): def get_batch(source, i, bptt=35): seq_len = min(bptt, len(source) - 1 - i) - data = source[i:i+seq_len] - target = source[i+1:i+1+seq_len].view(-1) + data = source[i : i + seq_len] + target = source[i + 1 : i + 1 + seq_len].view(-1) return data, target -def prepare_data(device='cpu', train_batch_size=20, eval_batch_size=20, data_dir=None): - url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip' +def prepare_data(device="cpu", train_batch_size=20, eval_batch_size=20, data_dir=None): + url = "https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip" - download_path = '.data_wikitext_2_v1' + download_path = ".data_wikitext_2_v1" extract_path = None if data_dir: - download_path = os.path.join(data_dir, 'download') + download_path = os.path.join(data_dir, "download") os.makedirs(download_path, exist_ok=True) - download_path = os.path.join(download_path, 'wikitext-2-v1.zip') + download_path = os.path.join(download_path, "wikitext-2-v1.zip") - extract_path = os.path.join(data_dir, 'extracted') + extract_path = os.path.join(data_dir, "extracted") os.makedirs(extract_path, exist_ok=True) - test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url, root=download_path), - to_path=extract_path) - tokenizer = get_tokenizer('basic_english') - vocab = build_vocab_from_iterator(map(tokenizer, - iter(io.open(train_filepath, - encoding="utf8")))) + test_filepath, valid_filepath, train_filepath = extract_archive( + download_from_url(url, root=download_path), to_path=extract_path + ) + tokenizer = get_tokenizer("basic_english") + vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding="utf8")))) def data_process(raw_text_iter): - data = [torch.tensor([vocab[token] for token in tokenizer(item)], - dtype=torch.long) for item in raw_text_iter] + data = [torch.tensor([vocab[token] for token in tokenizer(item)], dtype=torch.long) for item in raw_text_iter] return torch.cat(tuple(filter(lambda t: t.numel() > 0, data))) train_data = data_process(iter(io.open(train_filepath, encoding="utf8"))) diff --git a/server/test/integration_tests/function_tests.py b/server/test/integration_tests/function_tests.py index 451e323dceb63..23dcd89c6031d 100644 --- a/server/test/integration_tests/function_tests.py +++ b/server/test/integration_tests/function_tests.py @@ -1,443 +1,488 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import unittest +import json +import os import subprocess import time -import os -import requests -import json -import numpy +import unittest -import test_util +import grpc +import numpy import onnx_ml_pb2 import predict_pb2 import prediction_service_pb2_grpc -import grpc +import requests +import test_util + class HttpJsonPayloadTests(unittest.TestCase): - server_ip = '127.0.0.1' + server_ip = "127.0.0.1" server_port = 54321 - url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict' - server_app_path = '' - test_data_path = '' - model_path = '' - log_level = 'verbose' + url_pattern = "http://{0}:{1}/v1/models/{2}/versions/{3}:predict" + server_app_path = "" + test_data_path = "" + model_path = "" + log_level = "verbose" server_app_proc = None wait_server_ready_in_seconds = 1 @classmethod def setUpClass(cls): - onnx_model = os.path.join(cls.model_path, 'model.onnx') + onnx_model = os.path.join(cls.model_path, "model.onnx") test_util.prepare_mnist_model(onnx_model) - cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level] - test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd))) + cmd = [ + cls.server_app_path, + "--http_port", + str(cls.server_port), + "--model_path", + onnx_model, + "--log_level", + cls.log_level, + ] + test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd))) cls.server_app_proc = subprocess.Popen(cmd) - test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid)) - test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds)) + test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid)) + test_util.test_log( + "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds) + ) time.sleep(cls.wait_server_ready_in_seconds) - @classmethod def tearDownClass(cls): - test_util.test_log('Shutdown server app') + test_util.test_log("Shutdown server app") cls.server_app_proc.kill() - test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid))) - + test_util.test_log( + "PID {0} has been killed: {1}".format( + cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid) + ) + ) def test_mnist_happy_path(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json') - output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.json") + output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.json") - with open(input_data_file, 'r') as f: + with open(input_data_file, "r") as f: request_payload = f.read() - with open(output_data_file, 'r') as f: + with open(output_data_file, "r") as f: expected_response_json = f.read() expected_response = json.loads(expected_response_json) request_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'x-ms-client-request-id': 'This~is~my~id' + "Content-Type": "application/json", + "Accept": "application/json", + "x-ms-client-request-id": "This~is~my~id", } - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), 'application/json') - self.assertTrue(r.headers.get('x-ms-request-id')) - self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id') + self.assertEqual(r.headers.get("Content-Type"), "application/json") + self.assertTrue(r.headers.get("x-ms-request-id")) + self.assertEqual(r.headers.get("x-ms-client-request-id"), "This~is~my~id") - actual_response = json.loads(r.content.decode('utf-8')) + actual_response = json.loads(r.content.decode("utf-8")) # Note: # The 'dims' field is defined as "repeated int64" in protobuf. # When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string # Reference: https://developers.google.com/protocol-buffers/docs/proto3#json - self.assertTrue(actual_response['outputs']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims']) - self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType']) - self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData']) - actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f') - expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f') + self.assertTrue(actual_response["outputs"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["dims"]) + self.assertEqual(actual_response["outputs"]["Plus214_Output_0"]["dims"], ["1", "10"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["dataType"]) + self.assertEqual(actual_response["outputs"]["Plus214_Output_0"]["dataType"], 1) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["rawData"]) + actual_data = test_util.decode_base64_string(actual_response["outputs"]["Plus214_Output_0"]["rawData"], "10f") + expected_data = test_util.decode_base64_string( + expected_response["outputs"]["Plus214_Output_0"]["rawData"], "10f" + ) for i in range(0, 10): self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i])) - def test_mnist_invalid_url(self): - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', -1) + url = self.url_pattern.format(self.server_ip, self.server_port, "default", -1) test_util.test_log(url) - request_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } + request_headers = {"Content-Type": "application/json", "Accept": "application/json"} - r = requests.post(url, headers=request_headers, data={'foo': 'bar'}) + r = requests.post(url, headers=request_headers, data={"foo": "bar"}) self.assertEqual(r.status_code, 404) - self.assertEqual(r.headers.get('Content-Type'), 'application/json') - self.assertTrue(r.headers.get('x-ms-request-id')) - + self.assertEqual(r.headers.get("Content-Type"), "application/json") + self.assertTrue(r.headers.get("x-ms-request-id")) def test_mnist_invalid_content_type(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json') - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.json") + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) request_headers = { - 'Content-Type': 'application/abc', - 'Accept': 'application/json', - 'x-ms-client-request-id': 'This~is~my~id' + "Content-Type": "application/abc", + "Accept": "application/json", + "x-ms-client-request-id": "This~is~my~id", } - with open(input_data_file, 'r') as f: + with open(input_data_file, "r") as f: request_payload = f.read() r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 400) - self.assertEqual(r.headers.get('Content-Type'), 'application/json') - self.assertTrue(r.headers.get('x-ms-request-id')) - self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id') - self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n') - + self.assertEqual(r.headers.get("Content-Type"), "application/json") + self.assertTrue(r.headers.get("x-ms-request-id")) + self.assertEqual(r.headers.get("x-ms-client-request-id"), "This~is~my~id") + self.assertEqual( + r.content.decode("utf-8"), + '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n', + ) def test_mnist_missing_content_type(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json') - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.json") + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) - request_headers = { - 'Accept': 'application/json' - } + request_headers = {"Accept": "application/json"} - with open(input_data_file, 'r') as f: + with open(input_data_file, "r") as f: request_payload = f.read() r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 400) - self.assertEqual(r.headers.get('Content-Type'), 'application/json') - self.assertTrue(r.headers.get('x-ms-request-id')) - self.assertEqual(r.content.decode('utf-8'), '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n') - + self.assertEqual(r.headers.get("Content-Type"), "application/json") + self.assertTrue(r.headers.get("x-ms-request-id")) + self.assertEqual( + r.content.decode("utf-8"), + '{"error_code": 400, "error_message": "Missing or unknown \'Content-Type\' header field in the request"}\n', + ) def test_single_model_shortcut(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json') - output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.json") + output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.json") - with open(input_data_file, 'r') as f: + with open(input_data_file, "r") as f: request_payload = f.read() - with open(output_data_file, 'r') as f: + with open(output_data_file, "r") as f: expected_response_json = f.read() expected_response = json.loads(expected_response_json) request_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'x-ms-client-request-id': 'This~is~my~id' + "Content-Type": "application/json", + "Accept": "application/json", + "x-ms-client-request-id": "This~is~my~id", } url = "http://{0}:{1}/score".format(self.server_ip, self.server_port) test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), 'application/json') - self.assertTrue(r.headers.get('x-ms-request-id')) - self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id') + self.assertEqual(r.headers.get("Content-Type"), "application/json") + self.assertTrue(r.headers.get("x-ms-request-id")) + self.assertEqual(r.headers.get("x-ms-client-request-id"), "This~is~my~id") - actual_response = json.loads(r.content.decode('utf-8')) + actual_response = json.loads(r.content.decode("utf-8")) # Note: # The 'dims' field is defined as "repeated int64" in protobuf. # When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string # Reference: https://developers.google.com/protocol-buffers/docs/proto3#json - self.assertTrue(actual_response['outputs']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims']) - self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType']) - self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData']) - actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f') - expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f') + self.assertTrue(actual_response["outputs"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["dims"]) + self.assertEqual(actual_response["outputs"]["Plus214_Output_0"]["dims"], ["1", "10"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["dataType"]) + self.assertEqual(actual_response["outputs"]["Plus214_Output_0"]["dataType"], 1) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["rawData"]) + actual_data = test_util.decode_base64_string(actual_response["outputs"]["Plus214_Output_0"]["rawData"], "10f") + expected_data = test_util.decode_base64_string( + expected_response["outputs"]["Plus214_Output_0"]["rawData"], "10f" + ) for i in range(0, 10): self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i])) def test_single_version_shortcut(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.json') - output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.json') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.json") + output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.json") - with open(input_data_file, 'r') as f: + with open(input_data_file, "r") as f: request_payload = f.read() - with open(output_data_file, 'r') as f: + with open(output_data_file, "r") as f: expected_response_json = f.read() expected_response = json.loads(expected_response_json) request_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json', - 'x-ms-client-request-id': 'This~is~my~id' + "Content-Type": "application/json", + "Accept": "application/json", + "x-ms-client-request-id": "This~is~my~id", } - url = "http://{0}:{1}/v1/models/{2}:predict".format(self.server_ip, self.server_port, 'default') + url = "http://{0}:{1}/v1/models/{2}:predict".format(self.server_ip, self.server_port, "default") test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), 'application/json') - self.assertTrue(r.headers.get('x-ms-request-id')) - self.assertEqual(r.headers.get('x-ms-client-request-id'), 'This~is~my~id') + self.assertEqual(r.headers.get("Content-Type"), "application/json") + self.assertTrue(r.headers.get("x-ms-request-id")) + self.assertEqual(r.headers.get("x-ms-client-request-id"), "This~is~my~id") - actual_response = json.loads(r.content.decode('utf-8')) + actual_response = json.loads(r.content.decode("utf-8")) # Note: # The 'dims' field is defined as "repeated int64" in protobuf. # When it is serialized to JSON, all int64/fixed64/uint64 numbers are converted to string # Reference: https://developers.google.com/protocol-buffers/docs/proto3#json - self.assertTrue(actual_response['outputs']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dims']) - self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dims'], ['1', '10']) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['dataType']) - self.assertEqual(actual_response['outputs']['Plus214_Output_0']['dataType'], 1) - self.assertTrue(actual_response['outputs']['Plus214_Output_0']['rawData']) - actual_data = test_util.decode_base64_string(actual_response['outputs']['Plus214_Output_0']['rawData'], '10f') - expected_data = test_util.decode_base64_string(expected_response['outputs']['Plus214_Output_0']['rawData'], '10f') + self.assertTrue(actual_response["outputs"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["dims"]) + self.assertEqual(actual_response["outputs"]["Plus214_Output_0"]["dims"], ["1", "10"]) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["dataType"]) + self.assertEqual(actual_response["outputs"]["Plus214_Output_0"]["dataType"], 1) + self.assertTrue(actual_response["outputs"]["Plus214_Output_0"]["rawData"]) + actual_data = test_util.decode_base64_string(actual_response["outputs"]["Plus214_Output_0"]["rawData"], "10f") + expected_data = test_util.decode_base64_string( + expected_response["outputs"]["Plus214_Output_0"]["rawData"], "10f" + ) for i in range(0, 10): self.assertTrue(test_util.compare_floats(actual_data[i], expected_data[i])) + class HttpProtobufPayloadTests(unittest.TestCase): - server_ip = '127.0.0.1' + server_ip = "127.0.0.1" server_port = 54321 - url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict' - server_app_path = '' - test_data_path = '' - model_path = '' - log_level = 'verbose' + url_pattern = "http://{0}:{1}/v1/models/{2}/versions/{3}:predict" + server_app_path = "" + test_data_path = "" + model_path = "" + log_level = "verbose" server_app_proc = None wait_server_ready_in_seconds = 1 @classmethod def setUpClass(cls): - onnx_model = os.path.join(cls.model_path, 'model.onnx') + onnx_model = os.path.join(cls.model_path, "model.onnx") test_util.prepare_mnist_model(onnx_model) - cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level] - test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd))) + cmd = [ + cls.server_app_path, + "--http_port", + str(cls.server_port), + "--model_path", + onnx_model, + "--log_level", + cls.log_level, + ] + test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd))) cls.server_app_proc = subprocess.Popen(cmd) - test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid)) - test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds)) + test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid)) + test_util.test_log( + "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds) + ) time.sleep(cls.wait_server_ready_in_seconds) - @classmethod def tearDownClass(cls): - test_util.test_log('Shutdown server app') + test_util.test_log("Shutdown server app") cls.server_app_proc.kill() - test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid))) - + test_util.test_log( + "PID {0} has been killed: {1}".format( + cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid) + ) + ) def test_mnist_happy_path(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb') - output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.pb') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb") + output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.pb") - with open(input_data_file, 'rb') as f: + with open(input_data_file, "rb") as f: request_payload = f.read() - content_type_headers = ['application/x-protobuf', 'application/octet-stream', 'application/vnd.google.protobuf'] + content_type_headers = ["application/x-protobuf", "application/octet-stream", "application/vnd.google.protobuf"] for h in content_type_headers: - request_headers = { - 'Content-Type': h, - 'Accept': 'application/x-protobuf' - } + request_headers = {"Content-Type": h, "Accept": "application/x-protobuf"} - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), 'application/x-protobuf') - self.assertTrue(r.headers.get('x-ms-request-id')) + self.assertEqual(r.headers.get("Content-Type"), "application/x-protobuf") + self.assertTrue(r.headers.get("x-ms-request-id")) actual_result = predict_pb2.PredictResponse() actual_result.ParseFromString(r.content) expected_result = predict_pb2.PredictResponse() - with open(output_data_file, 'rb') as f: + with open(output_data_file, "rb") as f: expected_result.ParseFromString(f.read()) for k in expected_result.outputs.keys(): self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type) count = 1 - for i in range(0, len(expected_result.outputs['Plus214_Output_0'].dims)): - self.assertEqual(actual_result.outputs['Plus214_Output_0'].dims[i], expected_result.outputs['Plus214_Output_0'].dims[i]) - count = count * int(actual_result.outputs['Plus214_Output_0'].dims[i]) - - actual_array = numpy.frombuffer(actual_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32) - expected_array = numpy.frombuffer(expected_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32) + for i in range(0, len(expected_result.outputs["Plus214_Output_0"].dims)): + self.assertEqual( + actual_result.outputs["Plus214_Output_0"].dims[i], + expected_result.outputs["Plus214_Output_0"].dims[i], + ) + count = count * int(actual_result.outputs["Plus214_Output_0"].dims[i]) + + actual_array = numpy.frombuffer(actual_result.outputs["Plus214_Output_0"].raw_data, dtype=numpy.float32) + expected_array = numpy.frombuffer(expected_result.outputs["Plus214_Output_0"].raw_data, dtype=numpy.float32) self.assertEqual(len(actual_array), len(expected_array)) self.assertEqual(len(actual_array), count) for i in range(0, count): self.assertTrue(test_util.compare_floats(actual_array[i], expected_array[i], rel_tol=0.001)) - def test_respect_accept_header(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb") - with open(input_data_file, 'rb') as f: + with open(input_data_file, "rb") as f: request_payload = f.read() - accept_headers = ['application/x-protobuf', 'application/octet-stream', 'application/vnd.google.protobuf'] + accept_headers = ["application/x-protobuf", "application/octet-stream", "application/vnd.google.protobuf"] for h in accept_headers: - request_headers = { - 'Content-Type': 'application/x-protobuf', - 'Accept': h - } + request_headers = {"Content-Type": "application/x-protobuf", "Accept": h} - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), h) - + self.assertEqual(r.headers.get("Content-Type"), h) def test_missing_accept_header(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb") - with open(input_data_file, 'rb') as f: + with open(input_data_file, "rb") as f: request_payload = f.read() request_headers = { - 'Content-Type': 'application/x-protobuf', + "Content-Type": "application/x-protobuf", } - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream') - + self.assertEqual(r.headers.get("Content-Type"), "application/octet-stream") def test_any_accept_header(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb") - with open(input_data_file, 'rb') as f: + with open(input_data_file, "rb") as f: request_payload = f.read() - request_headers = { - 'Content-Type': 'application/x-protobuf', - 'Accept': '*/*' - } + request_headers = {"Content-Type": "application/x-protobuf", "Accept": "*/*"} - url = self.url_pattern.format(self.server_ip, self.server_port, 'default', 1) + url = self.url_pattern.format(self.server_ip, self.server_port, "default", 1) test_util.test_log(url) r = requests.post(url, headers=request_headers, data=request_payload) self.assertEqual(r.status_code, 200) - self.assertEqual(r.headers.get('Content-Type'), 'application/octet-stream') + self.assertEqual(r.headers.get("Content-Type"), "application/octet-stream") class HttpEndpointTests(unittest.TestCase): - server_ip = '127.0.0.1' + server_ip = "127.0.0.1" server_port = 54321 - server_app_path = '' - test_data_path = '' - model_path = '' - log_level = 'verbose' + server_app_path = "" + test_data_path = "" + model_path = "" + log_level = "verbose" server_app_proc = None wait_server_ready_in_seconds = 1 @classmethod def setUpClass(cls): - onnx_model = os.path.join(cls.model_path, 'model.onnx') + onnx_model = os.path.join(cls.model_path, "model.onnx") test_util.prepare_mnist_model(onnx_model) - cmd = [cls.server_app_path, '--http_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level] - test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd))) + cmd = [ + cls.server_app_path, + "--http_port", + str(cls.server_port), + "--model_path", + onnx_model, + "--log_level", + cls.log_level, + ] + test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd))) cls.server_app_proc = subprocess.Popen(cmd) - test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid)) - test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds)) + test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid)) + test_util.test_log( + "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds) + ) time.sleep(cls.wait_server_ready_in_seconds) - @classmethod def tearDownClass(cls): - test_util.test_log('Shutdown server app') + test_util.test_log("Shutdown server app") cls.server_app_proc.kill() - test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid))) - + test_util.test_log( + "PID {0} has been killed: {1}".format( + cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid) + ) + ) def test_health_endpoint(self): url = url = "http://{0}:{1}/".format(self.server_ip, self.server_port) test_util.test_log(url) r = requests.get(url) self.assertEqual(r.status_code, 200) - self.assertEqual(r.content.decode('utf-8'), 'Healthy') + self.assertEqual(r.content.decode("utf-8"), "Healthy") + class GRPCTests(unittest.TestCase): - server_ip = '127.0.0.1' + server_ip = "127.0.0.1" server_port = 54321 - server_app_path = '' - test_data_path = '' - model_path = '' - log_level = 'verbose' + server_app_path = "" + test_data_path = "" + model_path = "" + log_level = "verbose" server_app_proc = None wait_server_ready_in_seconds = 1 @classmethod def setUpClass(cls): - onnx_model = os.path.join(cls.model_path, 'model.onnx') + onnx_model = os.path.join(cls.model_path, "model.onnx") test_util.prepare_mnist_model(onnx_model) - cmd = [cls.server_app_path, '--grpc_port', str(cls.server_port), '--model_path', onnx_model, '--log_level', cls.log_level] - test_util.test_log('Launching server app: [{0}]'.format(' '.join(cmd))) + cmd = [ + cls.server_app_path, + "--grpc_port", + str(cls.server_port), + "--model_path", + onnx_model, + "--log_level", + cls.log_level, + ] + test_util.test_log("Launching server app: [{0}]".format(" ".join(cmd))) cls.server_app_proc = subprocess.Popen(cmd) - test_util.test_log('Server app PID: {0}'.format(cls.server_app_proc.pid)) - test_util.test_log('Sleep {0} second(s) to wait for server initialization'.format(cls.wait_server_ready_in_seconds)) + test_util.test_log("Server app PID: {0}".format(cls.server_app_proc.pid)) + test_util.test_log( + "Sleep {0} second(s) to wait for server initialization".format(cls.wait_server_ready_in_seconds) + ) time.sleep(cls.wait_server_ready_in_seconds) - @classmethod def tearDownClass(cls): - test_util.test_log('Shutdown server app') + test_util.test_log("Shutdown server app") cls.server_app_proc.kill() - test_util.test_log('PID {0} has been killed: {1}'.format(cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid))) - + test_util.test_log( + "PID {0} has been killed: {1}".format( + cls.server_app_proc.pid, test_util.is_process_killed(cls.server_app_proc.pid) + ) + ) def test_mnist_happy_path(self): - input_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_input.pb') - output_data_file = os.path.join(self.test_data_path, 'mnist_test_data_set_0_output.pb') + input_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_input.pb") + output_data_file = os.path.join(self.test_data_path, "mnist_test_data_set_0_output.pb") - with open(input_data_file, 'rb') as f: + with open(input_data_file, "rb") as f: request_payload = f.read() request = predict_pb2.PredictRequest() @@ -449,24 +494,26 @@ def test_mnist_happy_path(self): actual_result = stub.Predict(request) expected_result = predict_pb2.PredictResponse() - with open(output_data_file, 'rb') as f: + with open(output_data_file, "rb") as f: expected_result.ParseFromString(f.read()) for k in expected_result.outputs.keys(): self.assertEqual(actual_result.outputs[k].data_type, expected_result.outputs[k].data_type) count = 1 - for i in range(0, len(expected_result.outputs['Plus214_Output_0'].dims)): - self.assertEqual(actual_result.outputs['Plus214_Output_0'].dims[i], expected_result.outputs['Plus214_Output_0'].dims[i]) - count = count * int(actual_result.outputs['Plus214_Output_0'].dims[i]) - - actual_array = numpy.frombuffer(actual_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32) - expected_array = numpy.frombuffer(expected_result.outputs['Plus214_Output_0'].raw_data, dtype=numpy.float32) + for i in range(0, len(expected_result.outputs["Plus214_Output_0"].dims)): + self.assertEqual( + actual_result.outputs["Plus214_Output_0"].dims[i], expected_result.outputs["Plus214_Output_0"].dims[i] + ) + count = count * int(actual_result.outputs["Plus214_Output_0"].dims[i]) + + actual_array = numpy.frombuffer(actual_result.outputs["Plus214_Output_0"].raw_data, dtype=numpy.float32) + expected_array = numpy.frombuffer(expected_result.outputs["Plus214_Output_0"].raw_data, dtype=numpy.float32) self.assertEqual(len(actual_array), len(expected_array)) self.assertEqual(len(actual_array), count) for i in range(0, count): self.assertTrue(test_util.compare_floats(actual_array[i], expected_array[i], rel_tol=0.001)) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/server/test/integration_tests/model_zoo_data_prep.py b/server/test/integration_tests/model_zoo_data_prep.py index f5767d9748a8f..cad0498b24621 100644 --- a/server/test/integration_tests/model_zoo_data_prep.py +++ b/server/test/integration_tests/model_zoo_data_prep.py @@ -1,146 +1,151 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import json import os -import sys import shutil -import json +import sys from google.protobuf.json_format import MessageToJson + # Current models only have one input and one output def get_io_name(model_file_name): - sess = onnxruntime.InferenceSession(model_file_name) - return sess.get_inputs()[0].name, sess.get_outputs()[0].name + sess = onnxruntime.InferenceSession(model_file_name) + return sess.get_inputs()[0].name, sess.get_outputs()[0].name def gen_input_pb(pb_full_path, input_name, output_name, request_file_path): - t = onnx_ml_pb2.TensorProto() - with open(pb_full_path, 'rb') as fin: - t.ParseFromString(fin.read()) - predict_request = predict_pb2.PredictRequest() - predict_request.inputs[input_name].CopyFrom(t) - predict_request.output_filter.append(output_name) + t = onnx_ml_pb2.TensorProto() + with open(pb_full_path, "rb") as fin: + t.ParseFromString(fin.read()) + predict_request = predict_pb2.PredictRequest() + predict_request.inputs[input_name].CopyFrom(t) + predict_request.output_filter.append(output_name) - with open(request_file_path, "wb") as fout: - fout.write(predict_request.SerializeToString()) + with open(request_file_path, "wb") as fout: + fout.write(predict_request.SerializeToString()) def gen_output_pb(pb_full_path, output_name, response_file_path): - t = onnx_ml_pb2.TensorProto() - with open(pb_full_path, 'rb') as fin: - t.ParseFromString(fin.read()) - predict_response = predict_pb2.PredictResponse() - predict_response.outputs[output_name].CopyFrom(t) + t = onnx_ml_pb2.TensorProto() + with open(pb_full_path, "rb") as fin: + t.ParseFromString(fin.read()) + predict_response = predict_pb2.PredictResponse() + predict_response.outputs[output_name].CopyFrom(t) - with open(response_file_path, "wb") as fout: - fout.write(predict_response.SerializeToString()) + with open(response_file_path, "wb") as fout: + fout.write(predict_response.SerializeToString()) def tensor2dict(full_path): - t = onnx_ml_pb2.TensorProto() - with open(full_path, 'rb') as f: - t.ParseFromString(f.read()) + t = onnx_ml_pb2.TensorProto() + with open(full_path, "rb") as f: + t.ParseFromString(f.read()) - jsonStr = MessageToJson(t, use_integers_for_enums=True) - data = json.loads(jsonStr) + jsonStr = MessageToJson(t, use_integers_for_enums=True) + data = json.loads(jsonStr) - return data + return data def gen_input_json(pb_full_path, input_name, output_name, json_file_path): - data = tensor2dict(pb_full_path) + data = tensor2dict(pb_full_path) - inputs = {} - inputs[input_name] = data - output_filters = [ output_name ] + inputs = {} + inputs[input_name] = data + output_filters = [output_name] - req = {} - req["inputs"] = inputs - req["outputFilter"] = output_filters + req = {} + req["inputs"] = inputs + req["outputFilter"] = output_filters - with open(json_file_path, 'w') as outfile: - json.dump(req, outfile) + with open(json_file_path, "w") as outfile: + json.dump(req, outfile) def gen_output_json(pb_full_path, output_name, json_file_path): - data = tensor2dict(pb_full_path) + data = tensor2dict(pb_full_path) - output = {} - output[output_name] = data + output = {} + output[output_name] = data - resp = {} - resp["outputs"] = output + resp = {} + resp["outputs"] = output - with open(json_file_path, 'w') as outfile: - json.dump(resp, outfile) + with open(json_file_path, "w") as outfile: + json.dump(resp, outfile) def gen_req_resp(model_zoo, test_data, copy_model=False): - skip_list = [ - ('opset8', 'mxnet_arcface') # REASON: Known issue - ] - - opsets = [name for name in os.listdir(model_zoo) if os.path.isdir(os.path.join(model_zoo, name))] - for opset in opsets: - os.makedirs(os.path.join(test_data, opset), exist_ok=True) - - current_model_folder = os.path.join(model_zoo, opset) - current_data_folder = os.path.join(test_data, opset) - - models = [name for name in os.listdir(current_model_folder) if os.path.isdir(os.path.join(current_model_folder, name))] - for model in models: - print("Working on Opset: {0}, Model: {1}".format(opset, model)) - if (opset, model) in skip_list: - print(" SKIP!!") - continue - - os.makedirs(os.path.join(current_data_folder, model), exist_ok=True) - - src_folder = os.path.join(current_model_folder, model) - dst_folder = os.path.join(current_data_folder, model) - - onnx_file_path = '' - for fname in os.listdir(src_folder): - if not fname.startswith(".") and fname.endswith(".onnx") and os.path.isfile(os.path.join(src_folder, fname)): - onnx_file_path = os.path.join(src_folder, fname) - break - - if onnx_file_path == '': - raise FileNotFoundError('Could not find any *.onnx file in {0}'.format(src_folder)) - - if copy_model: - # Copy model file - target_file_path = os.path.join(dst_folder, "model.onnx") - shutil.copy2(onnx_file_path, target_file_path) - - for fname in os.listdir(src_folder): - if not fname.endswith(".onnx") and os.path.isfile(os.path.join(src_folder, fname)): - shutil.copy2(os.path.join(src_folder, fname), dst_folder) - - iname, oname = get_io_name(onnx_file_path) - model_test_data = [name for name in os.listdir(src_folder) if os.path.isdir(os.path.join(src_folder, name))] - for test in model_test_data: - src = os.path.join(src_folder, test) - dst = os.path.join(dst_folder, test) - os.makedirs(dst, exist_ok=True) - gen_input_json(os.path.join(src, 'input_0.pb'), iname, oname, os.path.join(dst, 'request.json')) - gen_output_json(os.path.join(src, 'output_0.pb'), oname, os.path.join(dst, 'response.json')) - gen_input_pb(os.path.join(src, 'input_0.pb'), iname, oname, os.path.join(dst, 'request.pb')) - gen_output_pb(os.path.join(src, 'output_0.pb'), oname, os.path.join(dst, 'response.pb')) - - -if __name__ == '__main__': - model_zoo = os.path.realpath(sys.argv[1]) - test_data = os.path.realpath(sys.argv[2]) - - sys.path.append(os.path.realpath(sys.argv[3])) - sys.path.append(os.path.realpath(sys.argv[4])) - - import onnxruntime - import predict_pb2 - import onnx_ml_pb2 - - os.makedirs(test_data, exist_ok=True) - gen_req_resp(model_zoo, test_data) - \ No newline at end of file + skip_list = [("opset8", "mxnet_arcface")] # REASON: Known issue + + opsets = [name for name in os.listdir(model_zoo) if os.path.isdir(os.path.join(model_zoo, name))] + for opset in opsets: + os.makedirs(os.path.join(test_data, opset), exist_ok=True) + + current_model_folder = os.path.join(model_zoo, opset) + current_data_folder = os.path.join(test_data, opset) + + models = [ + name for name in os.listdir(current_model_folder) if os.path.isdir(os.path.join(current_model_folder, name)) + ] + for model in models: + print("Working on Opset: {0}, Model: {1}".format(opset, model)) + if (opset, model) in skip_list: + print(" SKIP!!") + continue + + os.makedirs(os.path.join(current_data_folder, model), exist_ok=True) + + src_folder = os.path.join(current_model_folder, model) + dst_folder = os.path.join(current_data_folder, model) + + onnx_file_path = "" + for fname in os.listdir(src_folder): + if ( + not fname.startswith(".") + and fname.endswith(".onnx") + and os.path.isfile(os.path.join(src_folder, fname)) + ): + onnx_file_path = os.path.join(src_folder, fname) + break + + if onnx_file_path == "": + raise FileNotFoundError("Could not find any *.onnx file in {0}".format(src_folder)) + + if copy_model: + # Copy model file + target_file_path = os.path.join(dst_folder, "model.onnx") + shutil.copy2(onnx_file_path, target_file_path) + + for fname in os.listdir(src_folder): + if not fname.endswith(".onnx") and os.path.isfile(os.path.join(src_folder, fname)): + shutil.copy2(os.path.join(src_folder, fname), dst_folder) + + iname, oname = get_io_name(onnx_file_path) + model_test_data = [name for name in os.listdir(src_folder) if os.path.isdir(os.path.join(src_folder, name))] + for test in model_test_data: + src = os.path.join(src_folder, test) + dst = os.path.join(dst_folder, test) + os.makedirs(dst, exist_ok=True) + gen_input_json(os.path.join(src, "input_0.pb"), iname, oname, os.path.join(dst, "request.json")) + gen_output_json(os.path.join(src, "output_0.pb"), oname, os.path.join(dst, "response.json")) + gen_input_pb(os.path.join(src, "input_0.pb"), iname, oname, os.path.join(dst, "request.pb")) + gen_output_pb(os.path.join(src, "output_0.pb"), oname, os.path.join(dst, "response.pb")) + + +if __name__ == "__main__": + model_zoo = os.path.realpath(sys.argv[1]) + test_data = os.path.realpath(sys.argv[2]) + + sys.path.append(os.path.realpath(sys.argv[3])) + sys.path.append(os.path.realpath(sys.argv[4])) + + import onnx_ml_pb2 + import predict_pb2 + + import onnxruntime + + os.makedirs(test_data, exist_ok=True) + gen_req_resp(model_zoo, test_data) diff --git a/server/test/integration_tests/model_zoo_tests.py b/server/test/integration_tests/model_zoo_tests.py index 8f1e79d5aba72..128625a2aab64 100644 --- a/server/test/integration_tests/model_zoo_tests.py +++ b/server/test/integration_tests/model_zoo_tests.py @@ -1,86 +1,94 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import unittest -import random import os +import random import sys +import unittest class ModelZooTests(unittest.TestCase): - server_ip = '127.0.0.1' + server_ip = "127.0.0.1" server_port = 54321 grpc_port = 56789 - url_pattern = 'http://{0}:{1}/v1/models/{2}/versions/{3}:predict' - server_app_path = '' # Required - log_level = 'verbose' + url_pattern = "http://{0}:{1}/v1/models/{2}/versions/{3}:predict" + server_app_path = "" # Required + log_level = "verbose" server_ready_in_seconds = 10 server_off_in_seconds = 100 need_data_preparation = False need_data_cleanup = False - model_zoo_model_path = '' # Required - model_zoo_test_data_path = '' # Required - supported_opsets = ['opset7', 'opset8', 'opset9', 'opset_7', 'opset_8', 'opset_9'] + model_zoo_model_path = "" # Required + model_zoo_test_data_path = "" # Required + supported_opsets = ["opset7", "opset8", "opset9", "opset_7", "opset_8", "opset_9"] skipped_models = [ - ('opset7', 'tf_inception_v2'), # Known issue + ("opset7", "tf_inception_v2"), # Known issue ] def __test_model(self, model_path, data_paths): - json_request_headers = { - 'Content-Type': 'application/json', - 'Accept': 'application/json' - } - pb_request_headers = { - 'Content-Type': 'application/octet-stream', - 'Accept': 'application/octet-stream' - } + json_request_headers = {"Content-Type": "application/json", "Accept": "application/json"} + pb_request_headers = {"Content-Type": "application/octet-stream", "Accept": "application/octet-stream"} server_app_proc = None try: - onnx_file_path = '' + onnx_file_path = "" for fname in os.listdir(model_path): - if not fname.startswith(".") and fname.endswith(".onnx") and os.path.isfile(os.path.join(model_path, fname)): + if ( + not fname.startswith(".") + and fname.endswith(".onnx") + and os.path.isfile(os.path.join(model_path, fname)) + ): onnx_file_path = os.path.join(model_path, fname) break - if onnx_file_path == '': - raise FileNotFoundError('Could not find any *.onnx file in {0}'.format(model_path)) - - cmd = [self.server_app_path, '--http_port', str(self.server_port), '--model_path', onnx_file_path, '--log_level', self.log_level, '--grpc_port', str(self.grpc_port)] + if onnx_file_path == "": + raise FileNotFoundError("Could not find any *.onnx file in {0}".format(model_path)) + + cmd = [ + self.server_app_path, + "--http_port", + str(self.server_port), + "--model_path", + onnx_file_path, + "--log_level", + self.log_level, + "--grpc_port", + str(self.grpc_port), + ] test_util.test_log(cmd) - server_app_proc = test_util.launch_server_app(cmd, self.server_ip, self.server_port, - self.server_ready_in_seconds) + server_app_proc = test_util.launch_server_app( + cmd, self.server_ip, self.server_port, self.server_ready_in_seconds + ) - test_util.test_log('[{0}] Run tests...'.format(model_path)) + test_util.test_log("[{0}] Run tests...".format(model_path)) for test in data_paths: - test_util.test_log('[{0}] Current: {0}'.format(model_path, test)) + test_util.test_log("[{0}] Current: {0}".format(model_path, test)) - test_util.test_log('[{0}] JSON payload testing ....'.format(model_path)) - url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 12345) - with open(os.path.join(test, 'request.json')) as f: + test_util.test_log("[{0}] JSON payload testing ....".format(model_path)) + url = self.url_pattern.format(self.server_ip, self.server_port, "default_model", 12345) + with open(os.path.join(test, "request.json")) as f: request_payload = f.read() resp = test_util.make_http_request(url, json_request_headers, request_payload) - test_util.json_response_validation(self, resp, os.path.join(test, 'response.json')) + test_util.json_response_validation(self, resp, os.path.join(test, "response.json")) - test_util.test_log('[{0}] Protobuf payload testing ....'.format(model_path)) - url = self.url_pattern.format(self.server_ip, self.server_port, 'default_model', 54321) - with open(os.path.join(test, 'request.pb'), 'rb') as f: + test_util.test_log("[{0}] Protobuf payload testing ....".format(model_path)) + url = self.url_pattern.format(self.server_ip, self.server_port, "default_model", 54321) + with open(os.path.join(test, "request.pb"), "rb") as f: request_payload = f.read() resp = test_util.make_http_request(url, pb_request_headers, request_payload) - test_util.pb_response_validation(self, resp, os.path.join(test, 'response.pb')) + test_util.pb_response_validation(self, resp, os.path.join(test, "response.pb")) - test_util.test_log('[{0}] GRPC testing ....'.format(model_path)) - uri = ("{}:{}".format(self.server_ip, self.grpc_port)) - with open(os.path.join(test, 'request.pb'), 'rb') as f: + test_util.test_log("[{0}] GRPC testing ....".format(model_path)) + uri = "{}:{}".format(self.server_ip, self.grpc_port) + with open(os.path.join(test, "request.pb"), "rb") as f: request_payload = f.read() with grpc.insecure_channel(uri) as channel: stub = prediction_service_pb2_grpc.PredictionServiceStub(channel) resp = stub.Predict(request_payload) - test_util.pb_response_validation(self, resp, os.path.join(test, 'response.pb')) + test_util.pb_response_validation(self, resp, os.path.join(test, "response.pb")) finally: test_util.shutdown_server_app(server_app_proc, self.server_off_in_seconds) - def test_models_from_model_zoo(self): model_data_map = {} for opset in self.supported_opsets: @@ -95,14 +103,18 @@ def test_models_from_model_zoo(self): if os.path.isdir(os.path.join(test_data_folder, name)): current_dir = os.path.join(test_data_folder, name) - model_data_map[os.path.join(model_file_folder, name)] = [os.path.join(current_dir, name) for name in os.listdir(current_dir) if os.path.isdir(os.path.join(current_dir, name))] + model_data_map[os.path.join(model_file_folder, name)] = [ + os.path.join(current_dir, name) + for name in os.listdir(current_dir) + if os.path.isdir(os.path.join(current_dir, name)) + ] - test_util.test_log('Planned models and test data:') + test_util.test_log("Planned models and test data:") for model_data, data_paths in model_data_map.items(): test_util.test_log(model_data) for data in data_paths: - test_util.test_log('\t\t{0}'.format(data)) - test_util.test_log('-----------------------') + test_util.test_log("\t\t{0}".format(data)) + test_util.test_log("-----------------------") self.server_port = random.randint(30000, 40000) self.grpc_port = self.server_port + 1 @@ -110,7 +122,7 @@ def test_models_from_model_zoo(self): self.__test_model(model_path, data_paths) -if __name__ == '__main__': +if __name__ == "__main__": sys.path.append(sys.argv[4]) sys.path.append(sys.argv[5]) diff --git a/server/test/integration_tests/test_main.py b/server/test/integration_tests/test_main.py index c6f3e82c12c83..edb8121b65b92 100644 --- a/server/test/integration_tests/test_main.py +++ b/server/test/integration_tests/test_main.py @@ -1,11 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import sys import random +import sys import unittest -if __name__ == '__main__': +if __name__ == "__main__": sys.path.append(sys.argv[4]) sys.path.append(sys.argv[5]) @@ -13,7 +13,12 @@ loader = unittest.TestLoader() - test_classes = [function_tests.HttpJsonPayloadTests, function_tests.HttpProtobufPayloadTests, function_tests.HttpEndpointTests, function_tests.GRPCTests] + test_classes = [ + function_tests.HttpJsonPayloadTests, + function_tests.HttpProtobufPayloadTests, + function_tests.HttpEndpointTests, + function_tests.GRPCTests, + ] test_suites = [] for tests in test_classes: @@ -28,9 +33,8 @@ runner = unittest.TextTestRunner(verbosity=2) results = runner.run(suites) - + if results.wasSuccessful(): exit(0) else: exit(1) - \ No newline at end of file diff --git a/server/test/integration_tests/test_util.py b/server/test/integration_tests/test_util.py index 71d104c30cb76..62276e2679dc4 100644 --- a/server/test/integration_tests/test_util.py +++ b/server/test/integration_tests/test_util.py @@ -1,35 +1,36 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os import base64 -import struct -import math -import subprocess -import time -import requests -import json import datetime -import socket import errno +import json +import math +import os +import socket +import struct +import subprocess import sys +import time import urllib.request -import predict_pb2 -import onnx_ml_pb2 import numpy +import onnx_ml_pb2 +import predict_pb2 +import requests + def test_log(str): - print('[Test Log][{0}] {1}'.format(datetime.datetime.now(), str)) + print("[Test Log][{0}] {1}".format(datetime.datetime.now(), str)) def is_process_killed(pid): if sys.platform.startswith("win"): - process_name = 'onnxruntime_host.exe' - call = 'TASKLIST', '/FI', 'imagename eq {0}'.format(process_name) - output = subprocess.check_output(call).decode('utf-8') + process_name = "onnxruntime_host.exe" + call = "TASKLIST", "/FI", "imagename eq {0}".format(process_name) + output = subprocess.check_output(call).decode("utf-8") print(output) - last_line = output.strip().split('\r\n')[-1] + last_line = output.strip().split("\r\n")[-1] return not last_line.lower().startswith(process_name) else: try: @@ -39,13 +40,18 @@ def is_process_killed(pid): else: return True + def prepare_mnist_model(target_path): if not os.path.isfile(target_path): # Backup path: in case the mnist model is missing, we need to download it from Internet. - test_log('Downloading model from blob storage: https://ortsrvdev.blob.core.windows.net/test-data/model.onnx to {0}'.format(target_path)) - urllib.request.urlretrieve('https://ortsrvdev.blob.core.windows.net/test-data/model.onnx', target_path) + test_log( + "Downloading model from blob storage: https://ortsrvdev.blob.core.windows.net/test-data/model.onnx to {0}".format( + target_path + ) + ) + urllib.request.urlretrieve("https://ortsrvdev.blob.core.windows.net/test-data/model.onnx", target_path) else: - test_log('Found mnist model at {0}'.format(target_path)) + test_log("Found mnist model at {0}".format(target_path)) def decode_base64_string(s, count_and_type): @@ -57,7 +63,11 @@ def decode_base64_string(s, count_and_type): def compare_floats(a, b, rel_tol=0.0001, abs_tol=0.0001): if not math.isclose(a, b, rel_tol=rel_tol, abs_tol=abs_tol): - test_log('Not match with relative tolerance {0} and absolute tolerance {1}: {2} and {3}'.format(rel_tol, abs_tol, a, b)) + test_log( + "Not match with relative tolerance {0} and absolute tolerance {1}: {2} and {3}".format( + rel_tol, abs_tol, a, b + ) + ) return False return True @@ -75,8 +85,8 @@ def wait_service_up(server, port, timeout=1): if next_timeout < 0: return False else: - s.settimeout(next_timeout) - + s.settimeout(next_timeout) + s.connect((server, port)) except socket.timeout as err: if timeout: @@ -89,10 +99,10 @@ def wait_service_up(server, port, timeout=1): def launch_server_app(cmd, server_ip, server_port, wait_server_ready_in_seconds): - test_log('Launching server app: [{0}]'.format(' '.join(cmd))) + test_log("Launching server app: [{0}]".format(" ".join(cmd))) server_app_proc = subprocess.Popen(cmd) - test_log('Server app PID: {0}'.format(server_app_proc.pid)) - test_log('Wait up to {0} second(s) for server initialization'.format(wait_server_ready_in_seconds)) + test_log("Server app PID: {0}".format(server_app_proc.pid)) + test_log("Wait up to {0} second(s) for server initialization".format(wait_server_ready_in_seconds)) wait_service_up(server_ip, server_port, wait_server_ready_in_seconds) # Additional sleep to make sure the server is ready. @@ -103,12 +113,12 @@ def launch_server_app(cmd, server_ip, server_port, wait_server_ready_in_seconds) def shutdown_server_app(server_app_proc, wait_for_server_off_in_seconds): if server_app_proc is not None: - test_log('Shutdown server app') + test_log("Shutdown server app") server_app_proc.kill() while not is_process_killed(server_app_proc.pid): server_app_proc.wait(timeout=wait_for_server_off_in_seconds) - test_log('PID {0} has been killed: {1}'.format(server_app_proc.pid, is_process_killed(server_app_proc.pid))) + test_log("PID {0} has been killed: {1}".format(server_app_proc.pid, is_process_killed(server_app_proc.pid))) # Additional sleep to make sure the resource has been freed. time.sleep(1) @@ -117,62 +127,65 @@ def shutdown_server_app(server_app_proc, wait_for_server_off_in_seconds): def make_http_request(url, request_headers, payload): - test_log('POST Request Started') + test_log("POST Request Started") resp = requests.post(url, headers=request_headers, data=payload) - test_log('POST Request Done') + test_log("POST Request Done") return resp def json_response_validation(cls, resp, expected_resp_json_file): cls.assertEqual(resp.status_code, 200) - cls.assertTrue(resp.headers.get('x-ms-request-id')) - cls.assertEqual(resp.headers.get('Content-Type'), 'application/json') + cls.assertTrue(resp.headers.get("x-ms-request-id")) + cls.assertEqual(resp.headers.get("Content-Type"), "application/json") with open(expected_resp_json_file) as f: expected_result = json.loads(f.read()) - actual_response = json.loads(resp.content.decode('utf-8')) - cls.assertTrue(actual_response['outputs']) + actual_response = json.loads(resp.content.decode("utf-8")) + cls.assertTrue(actual_response["outputs"]) - for output in expected_result['outputs'].keys(): - cls.assertTrue(actual_response['outputs'][output]) - cls.assertTrue(actual_response['outputs'][output]['dataType']) - cls.assertEqual(actual_response['outputs'][output]['dataType'], expected_result['outputs'][output]['dataType']) - cls.assertTrue(actual_response['outputs'][output]['dims']) - cls.assertEqual(actual_response['outputs'][output]['dims'], expected_result['outputs'][output]['dims']) - cls.assertTrue(actual_response['outputs'][output]['rawData']) + for output in expected_result["outputs"].keys(): + cls.assertTrue(actual_response["outputs"][output]) + cls.assertTrue(actual_response["outputs"][output]["dataType"]) + cls.assertEqual(actual_response["outputs"][output]["dataType"], expected_result["outputs"][output]["dataType"]) + cls.assertTrue(actual_response["outputs"][output]["dims"]) + cls.assertEqual(actual_response["outputs"][output]["dims"], expected_result["outputs"][output]["dims"]) + cls.assertTrue(actual_response["outputs"][output]["rawData"]) count = 1 - for x in actual_response['outputs'][output]['dims']: + for x in actual_response["outputs"][output]["dims"]: count = count * int(x) - if actual_response['outputs'][output]['dataType'] == 10 or actual_response['outputs'][output]['dataType'] == 16: - actual_array = numpy.frombuffer(base64.b64decode(actual_response['outputs'][output]['rawData']), dtype=numpy.float16) - expected_array = numpy.frombuffer(base64.b64decode(expected_result['outputs'][output]['rawData']), dtype=numpy.float16) + if actual_response["outputs"][output]["dataType"] == 10 or actual_response["outputs"][output]["dataType"] == 16: + actual_array = numpy.frombuffer( + base64.b64decode(actual_response["outputs"][output]["rawData"]), dtype=numpy.float16 + ) + expected_array = numpy.frombuffer( + base64.b64decode(expected_result["outputs"][output]["rawData"]), dtype=numpy.float16 + ) cls.assertEqual(len(actual_array), len(expected_array)) cls.assertEqual(len(actual_array), count) for i in range(0, count): cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.05, abs_tol=0.05)) - elif actual_response['outputs'][output]['dataType'] == 1: - actual_array = decode_base64_string(actual_response['outputs'][output]['rawData'], '{0}f'.format(count)) - expected_array = decode_base64_string(expected_result['outputs'][output]['rawData'], '{0}f'.format(count)) + elif actual_response["outputs"][output]["dataType"] == 1: + actual_array = decode_base64_string(actual_response["outputs"][output]["rawData"], "{0}f".format(count)) + expected_array = decode_base64_string(expected_result["outputs"][output]["rawData"], "{0}f".format(count)) cls.assertEqual(len(actual_array), len(expected_array)) cls.assertEqual(len(actual_array), count) for i in range(0, count): cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001)) - def pb_response_validation(cls, resp, expected_resp_pb_file): cls.assertEqual(resp.status_code, 200) - cls.assertTrue(resp.headers.get('x-ms-request-id')) - cls.assertEqual(resp.headers.get('Content-Type'), 'application/octet-stream') + cls.assertTrue(resp.headers.get("x-ms-request-id")) + cls.assertEqual(resp.headers.get("Content-Type"), "application/octet-stream") actual_result = predict_pb2.PredictResponse() actual_result.ParseFromString(resp.content) expected_result = predict_pb2.PredictResponse() - with open(expected_resp_pb_file, 'rb') as f: + with open(expected_resp_pb_file, "rb") as f: expected_result.ParseFromString(f.read()) for k in expected_result.outputs.keys(): @@ -196,4 +209,4 @@ def pb_response_validation(cls, resp, expected_resp_pb_file): cls.assertEqual(len(actual_array), len(expected_array)) cls.assertEqual(len(actual_array), count) for i in range(0, count): - cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001)) \ No newline at end of file + cls.assertTrue(compare_floats(actual_array[i], expected_array[i], rel_tol=0.001)) diff --git a/setup.py b/setup.py index 262b7e27ef0fb..3e25d08d100c0 100644 --- a/setup.py +++ b/setup.py @@ -3,20 +3,21 @@ # Licensed under the MIT License. # ------------------------------------------------------------------------ -from setuptools import setup, Extension +import datetime +import platform +import subprocess +import sys from distutils import log as logger from distutils.command.build_ext import build_ext as _build_ext from glob import glob, iglob -from os import path, getcwd, environ, remove +from os import environ, getcwd, path, remove +from pathlib import Path from shutil import copyfile -import platform -import subprocess -import sys -import datetime -from pathlib import Path +from setuptools import Extension, setup + nightly_build = False -package_name = 'onnxruntime' +package_name = "onnxruntime" wheel_name_suffix = None @@ -33,7 +34,7 @@ def parse_arg_remove_string(argv, arg_name_equal): arg_value = None for arg in sys.argv[1:]: if arg.startswith(arg_name_equal): - arg_value = arg[len(arg_name_equal):] + arg_value = arg[len(arg_name_equal) :] sys.argv.remove(arg) break @@ -42,37 +43,37 @@ def parse_arg_remove_string(argv, arg_name_equal): # Any combination of the following arguments can be applied -if parse_arg_remove_boolean(sys.argv, '--nightly_build'): - package_name = 'ort-nightly' +if parse_arg_remove_boolean(sys.argv, "--nightly_build"): + package_name = "ort-nightly" nightly_build = True -wheel_name_suffix = parse_arg_remove_string(sys.argv, '--wheel_name_suffix=') +wheel_name_suffix = parse_arg_remove_string(sys.argv, "--wheel_name_suffix=") cuda_version = None rocm_version = None is_rocm = False # The following arguments are mutually exclusive -if wheel_name_suffix == 'gpu': +if wheel_name_suffix == "gpu": # TODO: how to support multiple CUDA versions? - cuda_version = parse_arg_remove_string(sys.argv, '--cuda_version=') -elif parse_arg_remove_boolean(sys.argv, '--use_rocm'): + cuda_version = parse_arg_remove_string(sys.argv, "--cuda_version=") +elif parse_arg_remove_boolean(sys.argv, "--use_rocm"): is_rocm = True - package_name = 'onnxruntime-rocm' if not nightly_build else 'ort-rocm-nightly' - rocm_version = parse_arg_remove_string(sys.argv, '--rocm_version=') -elif parse_arg_remove_boolean(sys.argv, '--use_openvino'): - package_name = 'onnxruntime-openvino' -elif parse_arg_remove_boolean(sys.argv, '--use_dnnl'): - package_name = 'onnxruntime-dnnl' -elif parse_arg_remove_boolean(sys.argv, '--use_nuphar'): - package_name = 'onnxruntime-nuphar' -elif parse_arg_remove_boolean(sys.argv, '--use_tvm'): - package_name = 'onnxruntime-tvm' -elif parse_arg_remove_boolean(sys.argv, '--use_vitisai'): - package_name = 'onnxruntime-vitisai' -elif parse_arg_remove_boolean(sys.argv, '--use_acl'): - package_name = 'onnxruntime-acl' -elif parse_arg_remove_boolean(sys.argv, '--use_armnn'): - package_name = 'onnxruntime-armnn' + package_name = "onnxruntime-rocm" if not nightly_build else "ort-rocm-nightly" + rocm_version = parse_arg_remove_string(sys.argv, "--rocm_version=") +elif parse_arg_remove_boolean(sys.argv, "--use_openvino"): + package_name = "onnxruntime-openvino" +elif parse_arg_remove_boolean(sys.argv, "--use_dnnl"): + package_name = "onnxruntime-dnnl" +elif parse_arg_remove_boolean(sys.argv, "--use_nuphar"): + package_name = "onnxruntime-nuphar" +elif parse_arg_remove_boolean(sys.argv, "--use_tvm"): + package_name = "onnxruntime-tvm" +elif parse_arg_remove_boolean(sys.argv, "--use_vitisai"): + package_name = "onnxruntime-vitisai" +elif parse_arg_remove_boolean(sys.argv, "--use_acl"): + package_name = "onnxruntime-acl" +elif parse_arg_remove_boolean(sys.argv, "--use_armnn"): + package_name = "onnxruntime-armnn" # PEP 513 defined manylinux1_x86_64 and manylinux1_i686 # PEP 571 defined manylinux2010_x86_64 and manylinux2010_i686 @@ -85,25 +86,25 @@ def parse_arg_remove_string(argv, arg_name_equal): # manylinux2014_ppc64le # manylinux2014_s390x manylinux_tags = [ - 'manylinux1_x86_64', - 'manylinux1_i686', - 'manylinux2010_x86_64', - 'manylinux2010_i686', - 'manylinux2014_x86_64', - 'manylinux2014_i686', - 'manylinux2014_aarch64', - 'manylinux2014_armv7l', - 'manylinux2014_ppc64', - 'manylinux2014_ppc64le', - 'manylinux2014_s390x', + "manylinux1_x86_64", + "manylinux1_i686", + "manylinux2010_x86_64", + "manylinux2010_i686", + "manylinux2014_x86_64", + "manylinux2014_i686", + "manylinux2014_aarch64", + "manylinux2014_armv7l", + "manylinux2014_ppc64", + "manylinux2014_ppc64le", + "manylinux2014_s390x", ] -is_manylinux = environ.get('AUDITWHEEL_PLAT', None) in manylinux_tags +is_manylinux = environ.get("AUDITWHEEL_PLAT", None) in manylinux_tags class build_ext(_build_ext): def build_extension(self, ext): dest_file = self.get_ext_fullpath(ext.name) - logger.info('copying %s -> %s', ext.sources[0], dest_file) + logger.info("copying %s -> %s", ext.sources[0], dest_file) copyfile(ext.sources[0], dest_file) @@ -117,89 +118,118 @@ def finalize_options(self): self.root_is_pure = False def _rewrite_ld_preload(self, to_preload): - with open('onnxruntime/capi/_ld_preload.py', 'a') as f: + with open("onnxruntime/capi/_ld_preload.py", "a") as f: if len(to_preload) > 0: - f.write('from ctypes import CDLL, RTLD_GLOBAL\n') + f.write("from ctypes import CDLL, RTLD_GLOBAL\n") for library in to_preload: - f.write('_{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split('.')[0], library)) + f.write('_{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split(".")[0], library)) def _rewrite_ld_preload_cuda(self, to_preload): - with open('onnxruntime/capi/_ld_preload.py', 'a') as f: + with open("onnxruntime/capi/_ld_preload.py", "a") as f: if len(to_preload) > 0: - f.write('from ctypes import CDLL, RTLD_GLOBAL\n') - f.write('try:\n') + f.write("from ctypes import CDLL, RTLD_GLOBAL\n") + f.write("try:\n") for library in to_preload: - f.write(' _{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split('.')[0], library)) - f.write('except OSError:\n') - f.write(' import os\n') + f.write(' _{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split(".")[0], library)) + f.write("except OSError:\n") + f.write(" import os\n") f.write(' os.environ["ORT_CUDA_UNAVAILABLE"] = "1"\n') def _rewrite_ld_preload_tensorrt(self, to_preload): - with open('onnxruntime/capi/_ld_preload.py', 'a') as f: + with open("onnxruntime/capi/_ld_preload.py", "a") as f: if len(to_preload) > 0: - f.write('from ctypes import CDLL, RTLD_GLOBAL\n') - f.write('try:\n') + f.write("from ctypes import CDLL, RTLD_GLOBAL\n") + f.write("try:\n") for library in to_preload: - f.write(' _{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split('.')[0], library)) - f.write('except OSError:\n') - f.write(' import os\n') + f.write(' _{} = CDLL("{}", mode=RTLD_GLOBAL)\n'.format(library.split(".")[0], library)) + f.write("except OSError:\n") + f.write(" import os\n") f.write(' os.environ["ORT_TENSORRT_UNAVAILABLE"] = "1"\n') def run(self): if is_manylinux: - source = 'onnxruntime/capi/onnxruntime_pybind11_state.so' - dest = 'onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so' - logger.info('copying %s -> %s', source, dest) + source = "onnxruntime/capi/onnxruntime_pybind11_state.so" + dest = "onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so" + logger.info("copying %s -> %s", source, dest) copyfile(source, dest) - result = subprocess.run(['patchelf', '--print-needed', dest], - check=True, stdout=subprocess.PIPE, universal_newlines=True) - dependencies = ['librccl.so', 'libamdhip64.so', 'librocblas.so', 'libMIOpen.so', - 'libhsa-runtime64.so', 'libhsakmt.so'] + result = subprocess.run( + ["patchelf", "--print-needed", dest], check=True, stdout=subprocess.PIPE, universal_newlines=True + ) + dependencies = [ + "librccl.so", + "libamdhip64.so", + "librocblas.so", + "libMIOpen.so", + "libhsa-runtime64.so", + "libhsakmt.so", + ] to_preload = [] to_preload_cuda = [] to_preload_tensorrt = [] cuda_dependencies = [] - args = ['patchelf', '--debug'] - for line in result.stdout.split('\n'): + args = ["patchelf", "--debug"] + for line in result.stdout.split("\n"): for dependency in dependencies: if dependency in line: to_preload.append(line) - args.extend(['--remove-needed', line]) + args.extend(["--remove-needed", line]) args.append(dest) if len(args) > 3: subprocess.run(args, check=True, stdout=subprocess.PIPE) - dest = 'onnxruntime/capi/libonnxruntime_providers_' + ('rocm.so' if is_rocm else 'cuda.so') + dest = "onnxruntime/capi/libonnxruntime_providers_" + ("rocm.so" if is_rocm else "cuda.so") if path.isfile(dest): - result = subprocess.run(['patchelf', '--print-needed', dest], - check=True, stdout=subprocess.PIPE, universal_newlines=True) - cuda_dependencies = ['libcublas.so', 'libcublasLt.so', 'libcudnn.so', 'libcudart.so', - 'libcurand.so', 'libcufft.so', 'libnvToolsExt.so', 'libcupti.so'] - rocm_dependencies = ['librccl.so', 'libamdhip64.so', 'librocblas.so', 'libMIOpen.so', - 'libhsa-runtime64.so', 'libhsakmt.so'] - args = ['patchelf', '--debug'] - for line in result.stdout.split('\n'): - for dependency in (cuda_dependencies + rocm_dependencies): + result = subprocess.run( + ["patchelf", "--print-needed", dest], + check=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + cuda_dependencies = [ + "libcublas.so", + "libcublasLt.so", + "libcudnn.so", + "libcudart.so", + "libcurand.so", + "libcufft.so", + "libnvToolsExt.so", + "libcupti.so", + ] + rocm_dependencies = [ + "librccl.so", + "libamdhip64.so", + "librocblas.so", + "libMIOpen.so", + "libhsa-runtime64.so", + "libhsakmt.so", + ] + args = ["patchelf", "--debug"] + for line in result.stdout.split("\n"): + for dependency in cuda_dependencies + rocm_dependencies: if dependency in line: if dependency not in to_preload: to_preload_cuda.append(line) - args.extend(['--remove-needed', line]) + args.extend(["--remove-needed", line]) args.append(dest) if len(args) > 3: subprocess.run(args, check=True, stdout=subprocess.PIPE) - dest = 'onnxruntime/capi/libonnxruntime_providers_' + ('migraphx.so' if is_rocm else 'tensorrt.so') + dest = "onnxruntime/capi/libonnxruntime_providers_" + ("migraphx.so" if is_rocm else "tensorrt.so") if path.isfile(dest): - result = subprocess.run(['patchelf', '--print-needed', dest], - check=True, stdout=subprocess.PIPE, universal_newlines=True) - tensorrt_dependencies = ['libnvinfer.so', 'libnvinfer_plugin.so', 'libnvonnxparser.so'] - args = ['patchelf', '--debug'] - for line in result.stdout.split('\n'): - for dependency in (cuda_dependencies + tensorrt_dependencies): + result = subprocess.run( + ["patchelf", "--print-needed", dest], + check=True, + stdout=subprocess.PIPE, + universal_newlines=True, + ) + tensorrt_dependencies = ["libnvinfer.so", "libnvinfer_plugin.so", "libnvonnxparser.so"] + args = ["patchelf", "--debug"] + for line in result.stdout.split("\n"): + for dependency in cuda_dependencies + tensorrt_dependencies: if dependency in line: if dependency not in (to_preload + to_preload_cuda): to_preload_tensorrt.append(line) - args.extend(['--remove-needed', line]) + args.extend(["--remove-needed", line]) args.append(dest) if len(args) > 3: subprocess.run(args, check=True, stdout=subprocess.PIPE) @@ -209,13 +239,14 @@ def run(self): self._rewrite_ld_preload_tensorrt(to_preload_tensorrt) _bdist_wheel.run(self) if is_manylinux and not disable_auditwheel_repair: - file = glob(path.join(self.dist_dir, '*linux*.whl'))[0] - logger.info('repairing %s for manylinux1', file) + file = glob(path.join(self.dist_dir, "*linux*.whl"))[0] + logger.info("repairing %s for manylinux1", file) try: - subprocess.run(['auditwheel', 'repair', '-w', self.dist_dir, file], - check=True, stdout=subprocess.PIPE) + subprocess.run( + ["auditwheel", "repair", "-w", self.dist_dir, file], check=True, stdout=subprocess.PIPE + ) finally: - logger.info('removing %s', file) + logger.info("removing %s", file) remove(file) except ImportError as error: @@ -223,65 +254,71 @@ def run(self): print(error) bdist_wheel = None -providers_cuda_or_rocm = 'libonnxruntime_providers_' + ('rocm.so' if is_rocm else 'cuda.so') -providers_tensorrt_or_migraphx = 'libonnxruntime_providers_' + ('migraphx.so' if is_rocm else 'tensorrt.so') +providers_cuda_or_rocm = "libonnxruntime_providers_" + ("rocm.so" if is_rocm else "cuda.so") +providers_tensorrt_or_migraphx = "libonnxruntime_providers_" + ("migraphx.so" if is_rocm else "tensorrt.so") # Additional binaries -if platform.system() == 'Linux': - libs = ['onnxruntime_pybind11_state.so', 'libdnnl.so.2', 'libmklml_intel.so', 'libmklml_gnu.so', 'libiomp5.so', - 'mimalloc.so'] - dl_libs = ['libonnxruntime_providers_shared.so'] +if platform.system() == "Linux": + libs = [ + "onnxruntime_pybind11_state.so", + "libdnnl.so.2", + "libmklml_intel.so", + "libmklml_gnu.so", + "libiomp5.so", + "mimalloc.so", + ] + dl_libs = ["libonnxruntime_providers_shared.so"] dl_libs.append(providers_cuda_or_rocm) dl_libs.append(providers_tensorrt_or_migraphx) # DNNL, TensorRT & OpenVINO EPs are built as shared libs - libs.extend(['libonnxruntime_providers_shared.so']) - libs.extend(['libonnxruntime_providers_dnnl.so']) - libs.extend(['libonnxruntime_providers_openvino.so']) + libs.extend(["libonnxruntime_providers_shared.so"]) + libs.extend(["libonnxruntime_providers_dnnl.so"]) + libs.extend(["libonnxruntime_providers_openvino.so"]) libs.append(providers_cuda_or_rocm) libs.append(providers_tensorrt_or_migraphx) # Nuphar Libs - libs.extend(['libtvm.so.0.5.1']) + libs.extend(["libtvm.so.0.5.1"]) if nightly_build: - libs.extend(['libonnxruntime_pywrapper.so']) + libs.extend(["libonnxruntime_pywrapper.so"]) elif platform.system() == "Darwin": - libs = ['onnxruntime_pybind11_state.so', 'libdnnl.2.dylib', 'mimalloc.so'] # TODO add libmklml and libiomp5 later. + libs = ["onnxruntime_pybind11_state.so", "libdnnl.2.dylib", "mimalloc.so"] # TODO add libmklml and libiomp5 later. # DNNL & TensorRT EPs are built as shared libs - libs.extend(['libonnxruntime_providers_shared.dylib']) - libs.extend(['libonnxruntime_providers_dnnl.dylib']) - libs.extend(['libonnxruntime_providers_tensorrt.dylib']) - libs.extend(['libonnxruntime_providers_cuda.dylib']) + libs.extend(["libonnxruntime_providers_shared.dylib"]) + libs.extend(["libonnxruntime_providers_dnnl.dylib"]) + libs.extend(["libonnxruntime_providers_tensorrt.dylib"]) + libs.extend(["libonnxruntime_providers_cuda.dylib"]) if nightly_build: - libs.extend(['libonnxruntime_pywrapper.dylib']) + libs.extend(["libonnxruntime_pywrapper.dylib"]) else: - libs = ['onnxruntime_pybind11_state.pyd', 'dnnl.dll', 'mklml.dll', 'libiomp5md.dll'] + libs = ["onnxruntime_pybind11_state.pyd", "dnnl.dll", "mklml.dll", "libiomp5md.dll"] # DNNL, TensorRT & OpenVINO EPs are built as shared libs - libs.extend(['onnxruntime_providers_shared.dll']) - libs.extend(['onnxruntime_providers_dnnl.dll']) - libs.extend(['onnxruntime_providers_tensorrt.dll']) - libs.extend(['onnxruntime_providers_openvino.dll']) - libs.extend(['onnxruntime_providers_cuda.dll']) + libs.extend(["onnxruntime_providers_shared.dll"]) + libs.extend(["onnxruntime_providers_dnnl.dll"]) + libs.extend(["onnxruntime_providers_tensorrt.dll"]) + libs.extend(["onnxruntime_providers_openvino.dll"]) + libs.extend(["onnxruntime_providers_cuda.dll"]) # DirectML Libs - libs.extend(['DirectML.dll']) + libs.extend(["DirectML.dll"]) # Nuphar Libs - libs.extend(['tvm.dll']) + libs.extend(["tvm.dll"]) if nightly_build: - libs.extend(['onnxruntime_pywrapper.dll']) + libs.extend(["onnxruntime_pywrapper.dll"]) if is_manylinux: - data = ['capi/libonnxruntime_pywrapper.so'] if nightly_build else [] - data += [path.join('capi', x) for x in dl_libs if path.isfile(path.join('onnxruntime', 'capi', x))] + data = ["capi/libonnxruntime_pywrapper.so"] if nightly_build else [] + data += [path.join("capi", x) for x in dl_libs if path.isfile(path.join("onnxruntime", "capi", x))] ext_modules = [ Extension( - 'onnxruntime.capi.onnxruntime_pybind11_state', - ['onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so'], + "onnxruntime.capi.onnxruntime_pybind11_state", + ["onnxruntime/capi/onnxruntime_pybind11_state_manylinux1.so"], ), ] else: - data = [path.join('capi', x) for x in libs if path.isfile(path.join('onnxruntime', 'capi', x))] + data = [path.join("capi", x) for x in libs if path.isfile(path.join("onnxruntime", "capi", x))] ext_modules = [] # Additional examples examples_names = ["mul_1.onnx", "logreg_iris.onnx", "sigmoid.onnx"] -examples = [path.join('datasets', x) for x in examples_names] +examples = [path.join("datasets", x) for x in examples_names] # Extra files such as EULA and ThirdPartyNotices extra = ["LICENSE", "ThirdPartyNotices.txt", "Privacy.md"] @@ -300,86 +337,97 @@ def run(self): # line option is specified. # If the options is not specified this following condition fails as onnxruntime/external folder is not created in the # build flow under the build binary directory. -if (path.isdir(path.join("onnxruntime", "external"))): +if path.isdir(path.join("onnxruntime", "external")): # Gather all files under onnxruntime/external directory. - extra.extend(list(str(Path(*Path(x).parts[1:])) for x in list(iglob( - path.join(path.join("onnxruntime", "external"), '**/*.*'), recursive=True)))) + extra.extend( + list( + str(Path(*Path(x).parts[1:])) + for x in list(iglob(path.join(path.join("onnxruntime", "external"), "**/*.*"), recursive=True)) + ) + ) packages = [ - 'onnxruntime', - 'onnxruntime.backend', - 'onnxruntime.capi', - 'onnxruntime.capi.training', - 'onnxruntime.datasets', - 'onnxruntime.tools', - 'onnxruntime.tools.mobile_helpers', - 'onnxruntime.tools.ort_format_model', - 'onnxruntime.tools.ort_format_model.ort_flatbuffers_py', - 'onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs', - 'onnxruntime.tools.qdq_helpers', - 'onnxruntime.quantization', - 'onnxruntime.quantization.operators', - 'onnxruntime.quantization.CalTableFlatBuffers', - 'onnxruntime.transformers', - 'onnxruntime.transformers.models.gpt2', - 'onnxruntime.transformers.models.longformer', - 'onnxruntime.transformers.models.t5', + "onnxruntime", + "onnxruntime.backend", + "onnxruntime.capi", + "onnxruntime.capi.training", + "onnxruntime.datasets", + "onnxruntime.tools", + "onnxruntime.tools.mobile_helpers", + "onnxruntime.tools.ort_format_model", + "onnxruntime.tools.ort_format_model.ort_flatbuffers_py", + "onnxruntime.tools.ort_format_model.ort_flatbuffers_py.fbs", + "onnxruntime.tools.qdq_helpers", + "onnxruntime.quantization", + "onnxruntime.quantization.operators", + "onnxruntime.quantization.CalTableFlatBuffers", + "onnxruntime.transformers", + "onnxruntime.transformers.models.gpt2", + "onnxruntime.transformers.models.longformer", + "onnxruntime.transformers.models.t5", ] requirements_file = "requirements.txt" local_version = None -enable_training = parse_arg_remove_boolean(sys.argv, '--enable_training') -disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, '--disable_auditwheel_repair') -default_training_package_device = parse_arg_remove_boolean(sys.argv, '--default_training_package_device') +enable_training = parse_arg_remove_boolean(sys.argv, "--enable_training") +disable_auditwheel_repair = parse_arg_remove_boolean(sys.argv, "--disable_auditwheel_repair") +default_training_package_device = parse_arg_remove_boolean(sys.argv, "--default_training_package_device") package_data = {} data_files = [] classifiers = [ - 'Development Status :: 5 - Production/Stable', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: MIT License', - 'Operating System :: POSIX :: Linux', - 'Topic :: Scientific/Engineering', - 'Topic :: Scientific/Engineering :: Mathematics', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', - 'Topic :: Software Development', - 'Topic :: Software Development :: Libraries', - 'Topic :: Software Development :: Libraries :: Python Modules', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3 :: Only', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9'] + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: POSIX :: Linux", + "Topic :: Scientific/Engineering", + "Topic :: Scientific/Engineering :: Mathematics", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: Python", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", +] if not enable_training: - classifiers.extend([ - 'Operating System :: Microsoft :: Windows', - 'Operating System :: MacOS']) + classifiers.extend(["Operating System :: Microsoft :: Windows", "Operating System :: MacOS"]) if enable_training: - packages.extend(['onnxruntime.training', - 'onnxruntime.training.amp', - 'onnxruntime.training.experimental', - 'onnxruntime.training.experimental.gradient_graph', - 'onnxruntime.training.optim', - 'onnxruntime.training.ortmodule', - 'onnxruntime.training.ortmodule.experimental', - 'onnxruntime.training.ortmodule.experimental.json_config', - 'onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule', - 'onnxruntime.training.ortmodule.torch_cpp_extensions', - 'onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor', - 'onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils', - 'onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator', - 'onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops', - 'onnxruntime.training.utils.data']) - package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor'] = ['*.cc'] - package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils'] = ['*.cc'] - package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator'] = ['*.cc'] - package_data['onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops'] = \ - ['*.cpp', '*.cu', '*.cuh', '*.h'] + packages.extend( + [ + "onnxruntime.training", + "onnxruntime.training.amp", + "onnxruntime.training.experimental", + "onnxruntime.training.experimental.gradient_graph", + "onnxruntime.training.optim", + "onnxruntime.training.ortmodule", + "onnxruntime.training.ortmodule.experimental", + "onnxruntime.training.ortmodule.experimental.json_config", + "onnxruntime.training.ortmodule.experimental.hierarchical_ortmodule", + "onnxruntime.training.ortmodule.torch_cpp_extensions", + "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor", + "onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils", + "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator", + "onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops", + "onnxruntime.training.utils.data", + ] + ) + package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.aten_op_executor"] = ["*.cc"] + package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cpu.torch_interop_utils"] = ["*.cc"] + package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.torch_gpu_allocator"] = ["*.cc"] + package_data["onnxruntime.training.ortmodule.torch_cpp_extensions.cuda.fused_ops"] = [ + "*.cpp", + "*.cu", + "*.cuh", + "*.h", + ] requirements_file = "requirements-training.txt" # with training, we want to follow this naming convention: # stable: @@ -388,40 +436,40 @@ def run(self): # onnxruntime-training-1.7.0.dev20210408+cu111-cp36-cp36m-linux_x86_64.whl # this is needed immediately by pytorch/ort so that the user is able to # install an onnxruntime training package with matching torch cuda version. - package_name = 'onnxruntime-training' + package_name = "onnxruntime-training" # we want put default training packages to pypi. pypi does not accept package with a local version. if not default_training_package_device or nightly_build: if cuda_version: # removing '.' to make Cuda version number in the same form as Pytorch. - local_version = '+cu' + cuda_version.replace('.', '') + local_version = "+cu" + cuda_version.replace(".", "") elif rocm_version: # removing '.' to make Rocm version number in the same form as Pytorch. - local_version = '+rocm' + rocm_version.replace('.', '') + local_version = "+rocm" + rocm_version.replace(".", "") else: # cpu version for documentation - local_version = '+cpu' + local_version = "+cpu" -if package_name == 'onnxruntime-nuphar': +if package_name == "onnxruntime-nuphar": packages += ["onnxruntime.nuphar"] - extra += [path.join('nuphar', 'NUPHAR_CACHE_VERSION')] + extra += [path.join("nuphar", "NUPHAR_CACHE_VERSION")] -if package_name == 'onnxruntime-tvm': - packages += ['onnxruntime.providers.tvm'] +if package_name == "onnxruntime-tvm": + packages += ["onnxruntime.providers.tvm"] package_data["onnxruntime"] = data + examples + extra -version_number = '' -with open('VERSION_NUMBER') as f: +version_number = "" +with open("VERSION_NUMBER") as f: version_number = f.readline().strip() if nightly_build: # https://docs.microsoft.com/en-us/azure/devops/pipelines/build/variables - build_suffix = environ.get('BUILD_BUILDNUMBER') + build_suffix = environ.get("BUILD_BUILDNUMBER") if build_suffix is None: # The following line is only for local testing build_suffix = str(datetime.datetime.now().date().strftime("%Y%m%d")) else: - build_suffix = build_suffix.replace('.', '') + build_suffix = build_suffix.replace(".", "") if len(build_suffix) > 8 and len(build_suffix) < 12: # we want to format the build_suffix to avoid (the 12th run on 20210630 vs the first run on 20210701): @@ -438,9 +486,9 @@ def run(self): # --cuda_version 11.1 def check_date_format(date_str): try: - datetime.datetime.strptime(date_str, '%Y%m%d') + datetime.datetime.strptime(date_str, "%Y%m%d") return True - except: # noqa + except: # noqa return False def reformat_run_count(count_str): @@ -449,9 +497,9 @@ def reformat_run_count(count_str): if count >= 0 and count < 1000: return "{:03}".format(count) elif count >= 1000: - raise RuntimeError(f'Too many builds for the same day: {count}') + raise RuntimeError(f"Too many builds for the same day: {count}") return "" - except: # noqa + except: # noqa return "" build_suffix_is_date_format = check_date_format(build_suffix[:8]) @@ -464,6 +512,7 @@ def reformat_run_count(count_str): if enable_training: from packaging import version from packaging.version import Version + # with training package, we need to bump up version minor number so that # nightly releases take precedence over the latest release when --pre is used during pip install. # eventually this shall be the behavior of all onnxruntime releases. @@ -473,10 +522,9 @@ def reformat_run_count(count_str): # TODO: this is the last time we have to do this!!! # We shall bump up release number right after release cut. if ort_version.major == 1 and ort_version.minor == 8 and ort_version.micro == 0: - version_number = '{major}.{minor}.{macro}'.format( - major=ort_version.major, - minor=ort_version.minor + 1, - macro=ort_version.micro) + version_number = "{major}.{minor}.{macro}".format( + major=ort_version.major, minor=ort_version.minor + 1, macro=ort_version.micro + ) version_number = version_number + ".dev" + build_suffix @@ -484,14 +532,14 @@ def reformat_run_count(count_str): version_number = version_number + local_version if wheel_name_suffix: - if not (enable_training and wheel_name_suffix == 'gpu'): + if not (enable_training and wheel_name_suffix == "gpu"): # for training packages, local version is used to indicate device types package_name = "{}-{}".format(package_name, wheel_name_suffix) cmd_classes = {} if bdist_wheel is not None: - cmd_classes['bdist_wheel'] = bdist_wheel -cmd_classes['build_ext'] = build_ext + cmd_classes["bdist_wheel"] = bdist_wheel +cmd_classes["build_ext"] = build_ext requirements_path = path.join(getcwd(), requirements_file) if not path.exists(requirements_path): @@ -504,12 +552,13 @@ def reformat_run_count(count_str): if enable_training: + def save_build_and_package_info(package_name, version_number, cuda_version, rocm_version): - sys.path.append(path.join(path.dirname(__file__), 'onnxruntime', 'python')) + sys.path.append(path.join(path.dirname(__file__), "onnxruntime", "python")) from onnxruntime_collect_build_info import find_cudart_versions - version_path = path.join('onnxruntime', 'capi', 'build_and_package_info.py') - with open(version_path, 'w') as f: + version_path = path.join("onnxruntime", "capi", "build_and_package_info.py") + with open(version_path, "w") as f: f.write("package_name = '{}'\n".format(package_name)) f.write("__version__ = '{}'\n".format(version_number)) @@ -525,7 +574,8 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm "Error getting cudart version. ", "did not find any cudart library" if not cudart_versions or len(cudart_versions) == 0 - else "found multiple cudart libraries") + else "found multiple cudart libraries", + ) elif rocm_version: f.write("rocm_version = '{}'\n".format(rocm_version)) @@ -535,24 +585,24 @@ def save_build_and_package_info(package_name, version_number, cuda_version, rocm setup( name=package_name, version=version_number, - description='ONNX Runtime is a runtime accelerator for Machine Learning models', + description="ONNX Runtime is a runtime accelerator for Machine Learning models", long_description=long_description, - author='Microsoft Corporation', - author_email='onnxruntime@microsoft.com', + author="Microsoft Corporation", + author_email="onnxruntime@microsoft.com", cmdclass=cmd_classes, license="MIT License", packages=packages, ext_modules=ext_modules, package_data=package_data, url="https://onnxruntime.ai", - download_url='https://github.com/microsoft/onnxruntime/tags', + download_url="https://github.com/microsoft/onnxruntime/tags", data_files=data_files, install_requires=install_requires, - keywords='onnx machine learning', + keywords="onnx machine learning", entry_points={ - 'console_scripts': [ - 'onnxruntime_test = onnxruntime.tools.onnxruntime_test:main', + "console_scripts": [ + "onnxruntime_test = onnxruntime.tools.onnxruntime_test:main", ] }, classifiers=classifiers, - ) +) diff --git a/tools/android_custom_build/build_custom_android_package.py b/tools/android_custom_build/build_custom_android_package.py index abe913e1fcf99..cdbf0df5bdbca 100755 --- a/tools/android_custom_build/build_custom_android_package.py +++ b/tools/android_custom_build/build_custom_android_package.py @@ -19,41 +19,59 @@ def parse_args(): dependencies. Then, from a Docker container with that image, it calls the ONNX Runtime build scripts to build a custom Android package. The resulting package will be under /output/aar_out. See https://onnxruntime.ai/docs/build/custom.html for more - information about custom builds.""") - - parser.add_argument("working_dir", type=pathlib.Path, - help="The directory used to store intermediate and output files.") - - parser.add_argument("--onnxruntime_branch_or_tag", - help="The ONNX Runtime branch or tag to build. " - "Supports branches and tags starting from 1.11 (branch rel-1.11.0 or tag v1.11.0). " - "If unspecified, builds the latest.") - - parser.add_argument("--onnxruntime_repo_url", - help="The ONNX Runtime repo URL. If unspecified, uses the official repo.") - - parser.add_argument("--include_ops_by_config", type=pathlib.Path, - help="The configuration file specifying which ops to include. " - "Such a configuration file is generated during ONNX to ORT format model conversion. " - f"The default is {DEFAULT_OPS_CONFIG_RELATIVE_PATH} in the ONNX Runtime repo.") - - parser.add_argument("--build_settings", type=pathlib.Path, - help="The configuration file specifying the build.py options. " - f"The default is {DEFAULT_BUILD_SETTINGS_RELATIVE_PATH} in the ONNX Runtime repo.") + information about custom builds.""" + ) + + parser.add_argument( + "working_dir", type=pathlib.Path, help="The directory used to store intermediate and output files." + ) + + parser.add_argument( + "--onnxruntime_branch_or_tag", + help="The ONNX Runtime branch or tag to build. " + "Supports branches and tags starting from 1.11 (branch rel-1.11.0 or tag v1.11.0). " + "If unspecified, builds the latest.", + ) + + parser.add_argument( + "--onnxruntime_repo_url", help="The ONNX Runtime repo URL. If unspecified, uses the official repo." + ) + + parser.add_argument( + "--include_ops_by_config", + type=pathlib.Path, + help="The configuration file specifying which ops to include. " + "Such a configuration file is generated during ONNX to ORT format model conversion. " + f"The default is {DEFAULT_OPS_CONFIG_RELATIVE_PATH} in the ONNX Runtime repo.", + ) + + parser.add_argument( + "--build_settings", + type=pathlib.Path, + help="The configuration file specifying the build.py options. " + f"The default is {DEFAULT_BUILD_SETTINGS_RELATIVE_PATH} in the ONNX Runtime repo.", + ) default_config = "Release" - parser.add_argument("--config", choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], - default=default_config, - help="The build configuration. " - f"The default is {default_config}.") + parser.add_argument( + "--config", + choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], + default=default_config, + help="The build configuration. " f"The default is {default_config}.", + ) default_docker_image_tag = "onnxruntime-android-custom-build:latest" - parser.add_argument("--docker_image_tag", default=default_docker_image_tag, - help="The tag for the Docker image. " - f"The default is {default_docker_image_tag}.") - - parser.add_argument("--docker_path", default=shutil.which("docker"), - help="The path to docker. If unspecified, docker should be in PATH.") + parser.add_argument( + "--docker_image_tag", + default=default_docker_image_tag, + help="The tag for the Docker image. " f"The default is {default_docker_image_tag}.", + ) + + parser.add_argument( + "--docker_path", + default=shutil.which("docker"), + help="The path to docker. If unspecified, docker should be in PATH.", + ) args = parser.parse_args() @@ -72,10 +90,18 @@ def main(): if args.onnxruntime_repo_url: docker_build_args += ["--build-arg", f"ONNXRUNTIME_REPO={args.onnxruntime_repo_url}"] - docker_build_cmd = [args.docker_path, "build", - "--tag", args.docker_image_tag, - "--file", str(SCRIPT_DIR / "Dockerfile"), - ] + docker_build_args + [str(SCRIPT_DIR)] + docker_build_cmd = ( + [ + args.docker_path, + "build", + "--tag", + args.docker_image_tag, + "--file", + str(SCRIPT_DIR / "Dockerfile"), + ] + + docker_build_args + + [str(SCRIPT_DIR)] + ) subprocess.run(docker_build_cmd, check=True) @@ -94,25 +120,33 @@ def main(): output_dir = working_dir / "output" output_dir.mkdir(exist_ok=True) - container_ops_config_file = \ - f"/workspace/shared/input/{args.include_ops_by_config.name}" if args.include_ops_by_config \ + container_ops_config_file = ( + f"/workspace/shared/input/{args.include_ops_by_config.name}" + if args.include_ops_by_config else f"/workspace/onnxruntime/{DEFAULT_OPS_CONFIG_RELATIVE_PATH}" + ) - container_build_settings_file =\ - f"/workspace/shared/input/{args.build_settings.name}" if args.build_settings \ + container_build_settings_file = ( + f"/workspace/shared/input/{args.build_settings.name}" + if args.build_settings else f"/workspace/onnxruntime/{DEFAULT_BUILD_SETTINGS_RELATIVE_PATH}" - - docker_run_cmd = [args.docker_path, "run", - "--rm", "-it", - f"--volume={str(working_dir)}:/workspace/shared", - args.docker_image_tag, - "/usr/bin/env", "python3", - "/workspace/onnxruntime/tools/ci_build/github/android/build_aar_package.py", - "--build_dir=/workspace/shared/output", - f"--config={args.config}", - f"--include_ops_by_config={container_ops_config_file}", - container_build_settings_file, - ] + ) + + docker_run_cmd = [ + args.docker_path, + "run", + "--rm", + "-it", + f"--volume={str(working_dir)}:/workspace/shared", + args.docker_image_tag, + "/usr/bin/env", + "python3", + "/workspace/onnxruntime/tools/ci_build/github/android/build_aar_package.py", + "--build_dir=/workspace/shared/output", + f"--config={args.config}", + f"--include_ops_by_config={container_ops_config_file}", + container_build_settings_file, + ] subprocess.run(docker_run_cmd, check=True) diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py index 7e6c6e0f83d5e..6ede643f58b09 100644 --- a/tools/ci_build/amd_hipify.py +++ b/tools/ci_build/amd_hipify.py @@ -6,203 +6,204 @@ import os import shutil import subprocess + from logger import get_logger log = get_logger("amd_hipify") -contrib_ops_path = 'onnxruntime/contrib_ops' -providers_path = 'onnxruntime/core/providers' -training_ops_path = 'orttraining/orttraining/training_ops' +contrib_ops_path = "onnxruntime/contrib_ops" +providers_path = "onnxruntime/core/providers" +training_ops_path = "orttraining/orttraining/training_ops" contrib_ops_excluded_files = [ - 'bert/attention.cc', - 'bert/attention_impl.cu', - 'bert/attention_softmax.h', - 'bert/decoder_attention.h', - 'bert/decoder_attention.cc', - 'bert/embed_layer_norm.cc', - 'bert/embed_layer_norm.h', - 'bert/embed_layer_norm_impl.cu', - 'bert/embed_layer_norm_impl.h', - 'bert/fast_gelu_impl.cu', - # 'bert/layer_norm.cuh', - 'bert/longformer_attention.cc', - 'bert/longformer_attention.h', - 'bert/longformer_attention_softmax.cu', - 'bert/longformer_attention_softmax.h', - 'bert/longformer_attention_impl.cu', - 'bert/longformer_attention_impl.h', - 'bert/longformer_global_impl.cu', - 'bert/longformer_global_impl.h', - 'bert/transformer_cuda_common.h', - 'math/bias_softmax.cc', - 'math/bias_softmax.h', - 'math/bias_softmax_impl.cu', - 'math/complex_mul.cc', - 'math/complex_mul.h', - 'math/complex_mul_impl.cu', - 'math/complex_mul_impl.h', - 'math/cufft_plan_cache.h', - 'math/fft_ops.cc', - 'math/fft_ops.h', - 'math/fft_ops_impl.cu', - 'math/fft_ops_impl.h', - 'quantization/attention_quantization.cc', - 'quantization/attention_quantization.h', - 'quantization/attention_quantization_impl.cu', - 'quantization/attention_quantization_impl.cuh', - 'quantization/quantize_dequantize_linear.cc', - 'tensor/crop.cc', - 'tensor/crop.h', - 'tensor/crop_impl.cu', - 'tensor/crop_impl.h', - 'tensor/dynamicslice.cc', - 'tensor/image_scaler.cc', - 'tensor/image_scaler.h', - 'tensor/image_scaler_impl.cu', - 'tensor/image_scaler_impl.h', - 'transformers/beam_search.cc', - 'transformers/beam_search.h', - 'transformers/beam_search_device_helper.cc', - 'transformers/beam_search_device_helper.h', - 'transformers/beam_search_impl.cu', - 'transformers/beam_search_impl.h', - 'transformers/dump_cuda_tensor.cc', - 'transformers/dump_cuda_tensor.h', - 'conv_transpose_with_dynamic_pads.cc', - 'conv_transpose_with_dynamic_pads.h', - 'cuda_contrib_kernels.cc', - 'cuda_contrib_kernels.h', - 'inverse.cc', - 'fused_conv.cc' + "bert/attention.cc", + "bert/attention_impl.cu", + "bert/attention_softmax.h", + "bert/decoder_attention.h", + "bert/decoder_attention.cc", + "bert/embed_layer_norm.cc", + "bert/embed_layer_norm.h", + "bert/embed_layer_norm_impl.cu", + "bert/embed_layer_norm_impl.h", + "bert/fast_gelu_impl.cu", + # 'bert/layer_norm.cuh', + "bert/longformer_attention.cc", + "bert/longformer_attention.h", + "bert/longformer_attention_softmax.cu", + "bert/longformer_attention_softmax.h", + "bert/longformer_attention_impl.cu", + "bert/longformer_attention_impl.h", + "bert/longformer_global_impl.cu", + "bert/longformer_global_impl.h", + "bert/transformer_cuda_common.h", + "math/bias_softmax.cc", + "math/bias_softmax.h", + "math/bias_softmax_impl.cu", + "math/complex_mul.cc", + "math/complex_mul.h", + "math/complex_mul_impl.cu", + "math/complex_mul_impl.h", + "math/cufft_plan_cache.h", + "math/fft_ops.cc", + "math/fft_ops.h", + "math/fft_ops_impl.cu", + "math/fft_ops_impl.h", + "quantization/attention_quantization.cc", + "quantization/attention_quantization.h", + "quantization/attention_quantization_impl.cu", + "quantization/attention_quantization_impl.cuh", + "quantization/quantize_dequantize_linear.cc", + "tensor/crop.cc", + "tensor/crop.h", + "tensor/crop_impl.cu", + "tensor/crop_impl.h", + "tensor/dynamicslice.cc", + "tensor/image_scaler.cc", + "tensor/image_scaler.h", + "tensor/image_scaler_impl.cu", + "tensor/image_scaler_impl.h", + "transformers/beam_search.cc", + "transformers/beam_search.h", + "transformers/beam_search_device_helper.cc", + "transformers/beam_search_device_helper.h", + "transformers/beam_search_impl.cu", + "transformers/beam_search_impl.h", + "transformers/dump_cuda_tensor.cc", + "transformers/dump_cuda_tensor.h", + "conv_transpose_with_dynamic_pads.cc", + "conv_transpose_with_dynamic_pads.h", + "cuda_contrib_kernels.cc", + "cuda_contrib_kernels.h", + "inverse.cc", + "fused_conv.cc", ] provider_excluded_files = [ - 'atomic/common.cuh', - 'controlflow/if.cc', - 'controlflow/if.h', - 'controlflow/loop.cc', - 'controlflow/loop.h', - 'controlflow/scan.cc', - 'controlflow/scan.h', - 'cu_inc/common.cuh', - 'math/einsum_utils/einsum_auxiliary_ops.cc', - 'math/einsum_utils/einsum_auxiliary_ops.h', - 'math/einsum_utils/einsum_auxiliary_ops_diagonal.cu', - 'math/einsum_utils/einsum_auxiliary_ops_diagonal.h', - 'math/einsum.cc', - 'math/einsum.h', - 'math/gemm.cc', - 'math/matmul.cc', - 'math/matmul_integer.cc', - 'math/matmul_integer.cu', - 'math/matmul_integer.cuh', - 'math/matmul_integer.h', - 'math/softmax_impl.cu', - 'math/softmax_warpwise_impl.cuh', - 'math/softmax.cc', - 'nn/batch_norm.cc', - 'nn/batch_norm.h', - 'nn/conv.cc', - 'nn/conv.h', - 'nn/conv_transpose.cc', - 'nn/conv_transpose.h', - 'nn/instance_norm.cc', - 'nn/instance_norm.h', - 'nn/instance_norm_impl.cu', - 'nn/instance_norm_impl.h', - 'nn/lrn.cc', - 'nn/lrn.h', - 'nn/max_pool_with_index.cu', - 'nn/max_pool_with_index.h', - 'nn/pool.cc', - 'nn/pool.h', - 'reduction/reduction_ops.cc', - 'reduction/reduction_ops.h', - 'rnn/cudnn_rnn_base.cc', - 'rnn/cudnn_rnn_base.h', - 'rnn/gru.cc', - 'rnn/gru.h', - 'rnn/lstm.cc', - 'rnn/lstm.h', - 'rnn/rnn.cc', - 'rnn/rnn.h', - 'rnn/rnn_impl.cu', - 'rnn/rnn_impl.h', - 'shared_inc/cuda_call.h', - 'shared_inc/fpgeneric.h', - 'shared_inc/integer_gemm.h', - 'cuda_allocator.cc', - 'cuda_allocator.h', - 'cuda_call.cc', - 'cuda_common.cc', - 'cuda_common.h', - 'cuda_execution_provider_info.cc', - 'cuda_execution_provider_info.h', - 'cuda_execution_provider.cc', - 'cuda_execution_provider.h', - 'cuda_memory_check.cc', - 'cuda_memory_check.h', - 'cuda_fence.cc', - 'cuda_fence.h', - 'cuda_fwd.h', - 'cuda_kernel.h', - 'cuda_pch.cc', - 'cuda_pch.h', - 'cuda_provider_factory.cc', - 'cuda_provider_factory.h', - 'cuda_utils.cu', - 'cudnn_common.cc', - 'cudnn_common.h', - 'fpgeneric.cu', - 'gpu_data_transfer.cc', - 'gpu_data_transfer.h', - 'integer_gemm.cc', - 'symbols.txt', + "atomic/common.cuh", + "controlflow/if.cc", + "controlflow/if.h", + "controlflow/loop.cc", + "controlflow/loop.h", + "controlflow/scan.cc", + "controlflow/scan.h", + "cu_inc/common.cuh", + "math/einsum_utils/einsum_auxiliary_ops.cc", + "math/einsum_utils/einsum_auxiliary_ops.h", + "math/einsum_utils/einsum_auxiliary_ops_diagonal.cu", + "math/einsum_utils/einsum_auxiliary_ops_diagonal.h", + "math/einsum.cc", + "math/einsum.h", + "math/gemm.cc", + "math/matmul.cc", + "math/matmul_integer.cc", + "math/matmul_integer.cu", + "math/matmul_integer.cuh", + "math/matmul_integer.h", + "math/softmax_impl.cu", + "math/softmax_warpwise_impl.cuh", + "math/softmax.cc", + "nn/batch_norm.cc", + "nn/batch_norm.h", + "nn/conv.cc", + "nn/conv.h", + "nn/conv_transpose.cc", + "nn/conv_transpose.h", + "nn/instance_norm.cc", + "nn/instance_norm.h", + "nn/instance_norm_impl.cu", + "nn/instance_norm_impl.h", + "nn/lrn.cc", + "nn/lrn.h", + "nn/max_pool_with_index.cu", + "nn/max_pool_with_index.h", + "nn/pool.cc", + "nn/pool.h", + "reduction/reduction_ops.cc", + "reduction/reduction_ops.h", + "rnn/cudnn_rnn_base.cc", + "rnn/cudnn_rnn_base.h", + "rnn/gru.cc", + "rnn/gru.h", + "rnn/lstm.cc", + "rnn/lstm.h", + "rnn/rnn.cc", + "rnn/rnn.h", + "rnn/rnn_impl.cu", + "rnn/rnn_impl.h", + "shared_inc/cuda_call.h", + "shared_inc/fpgeneric.h", + "shared_inc/integer_gemm.h", + "cuda_allocator.cc", + "cuda_allocator.h", + "cuda_call.cc", + "cuda_common.cc", + "cuda_common.h", + "cuda_execution_provider_info.cc", + "cuda_execution_provider_info.h", + "cuda_execution_provider.cc", + "cuda_execution_provider.h", + "cuda_memory_check.cc", + "cuda_memory_check.h", + "cuda_fence.cc", + "cuda_fence.h", + "cuda_fwd.h", + "cuda_kernel.h", + "cuda_pch.cc", + "cuda_pch.h", + "cuda_provider_factory.cc", + "cuda_provider_factory.h", + "cuda_utils.cu", + "cudnn_common.cc", + "cudnn_common.h", + "fpgeneric.cu", + "gpu_data_transfer.cc", + "gpu_data_transfer.h", + "integer_gemm.cc", + "symbols.txt", ] training_ops_excluded_files = [ - 'activation/gelu_grad_impl_common.cuh', # uses custom tanh - 'collective/adasum_kernels.cc', - 'collective/adasum_kernels.h', - 'math/div_grad.cc', # miopen API differs from cudnn, no double type support - 'math/softmax_grad_impl.cu', # warp size differences - 'math/softmax_grad.cc', # miopen API differs from cudnn, no double type support - 'nn/batch_norm_grad.cc', # no double type support - 'nn/batch_norm_grad.h', # miopen API differs from cudnn - 'nn/batch_norm_internal.cc', # miopen API differs from cudnn, no double type support - 'nn/batch_norm_internal.h', # miopen API differs from cudnn, no double type support - 'nn/conv_grad.cc', - 'nn/conv_grad.h', - 'reduction/reduction_all.cc', # deterministic = true, ignore ctx setting - 'reduction/reduction_ops.cc', # no double type support - 'cuda_training_kernels.cc', - 'cuda_training_kernels.h', + "activation/gelu_grad_impl_common.cuh", # uses custom tanh + "collective/adasum_kernels.cc", + "collective/adasum_kernels.h", + "math/div_grad.cc", # miopen API differs from cudnn, no double type support + "math/softmax_grad_impl.cu", # warp size differences + "math/softmax_grad.cc", # miopen API differs from cudnn, no double type support + "nn/batch_norm_grad.cc", # no double type support + "nn/batch_norm_grad.h", # miopen API differs from cudnn + "nn/batch_norm_internal.cc", # miopen API differs from cudnn, no double type support + "nn/batch_norm_internal.h", # miopen API differs from cudnn, no double type support + "nn/conv_grad.cc", + "nn/conv_grad.h", + "reduction/reduction_all.cc", # deterministic = true, ignore ctx setting + "reduction/reduction_ops.cc", # no double type support + "cuda_training_kernels.cc", + "cuda_training_kernels.h", ] @functools.lru_cache(maxsize=1) def get_hipify_path(): # prefer the hipify-perl in PATH - HIPIFY_PERL = shutil.which('hipify-perl') + HIPIFY_PERL = shutil.which("hipify-perl") # if not found, attempt hard-coded location 1 if HIPIFY_PERL is None: - print('hipify-perl not found, trying default location 1') - hipify_path = '/opt/rocm/hip/bin/hipify-perl' + print("hipify-perl not found, trying default location 1") + hipify_path = "/opt/rocm/hip/bin/hipify-perl" HIPIFY_PERL = hipify_path if os.access(hipify_path, os.X_OK) else None # if not found, attempt hard-coded location 2 if HIPIFY_PERL is None: - print('hipify-perl not found, trying default location 2') - hipify_path = '/opt/rocm/bin/hipify-perl' + print("hipify-perl not found, trying default location 2") + hipify_path = "/opt/rocm/bin/hipify-perl" HIPIFY_PERL = hipify_path if os.access(hipify_path, os.X_OK) else None # fail if HIPIFY_PERL is None: - raise RuntimeError('Could not locate hipify-perl script') + raise RuntimeError("Could not locate hipify-perl script") return HIPIFY_PERL def hipify(src_file_path, dst_file_path): - dst_file_path = dst_file_path.replace('cuda', 'rocm') + dst_file_path = dst_file_path.replace("cuda", "rocm") dir_name = os.path.dirname(dst_file_path) if not os.path.exists(dir_name): os.makedirs(dir_name, exist_ok=True) @@ -211,69 +212,75 @@ def hipify(src_file_path, dst_file_path): # Additional exact-match replacements. # Order matters for all of the following replacements, reglardless of appearing in logical sections. - s = s.replace('kCudaExecutionProvider', 'kRocmExecutionProvider') - s = s.replace('CUDAStreamType', 'HIPStreamType') - s = s.replace('kCudaStreamDefault', 'kHipStreamDefault') - s = s.replace('kCudaStreamCopyIn', 'kHipStreamCopyIn') - s = s.replace('kCudaStreamCopyOut', 'kHipStreamCopyOut') - s = s.replace('kTotalCudaStreams', 'kTotalHipStreams') + s = s.replace("kCudaExecutionProvider", "kRocmExecutionProvider") + s = s.replace("CUDAStreamType", "HIPStreamType") + s = s.replace("kCudaStreamDefault", "kHipStreamDefault") + s = s.replace("kCudaStreamCopyIn", "kHipStreamCopyIn") + s = s.replace("kCudaStreamCopyOut", "kHipStreamCopyOut") + s = s.replace("kTotalCudaStreams", "kTotalHipStreams") # We want rocblas interfaces, not hipblas. Also force some hipify replacements back to rocblas from hipblas. - s = s.replace('CublasHandle', 'RocblasHandle') - s = s.replace('cublas_handle', 'rocblas_handle') - s = s.replace('hipblasHandle_t', 'rocblas_handle') - s = s.replace('hipblasDatatype_t', 'rocblas_datatype') - s = s.replace('HIPBLAS_STATUS_SUCCESS', 'rocblas_status_success') - s = s.replace('hipblasStatus_t', 'rocblas_status') - s = s.replace('hipblasCreate', 'rocblas_create_handle') - s = s.replace('hipblasDestroy', 'rocblas_destroy_handle') - s = s.replace('hipblasSetStream', 'rocblas_set_stream') - s = s.replace('HIPBLAS_OP_T', 'rocblas_operation_transpose') - - s = s.replace('RegisterCudaContribKernels', 'RegisterRocmContribKernels') - s = s.replace('cudaEvent', 'hipEvent') - s = s.replace('CreateCudaAllocator', 'CreateRocmAllocator') - s = s.replace('CudaErrString', 'RocmErrString') - s = s.replace('CudaAsyncBuffer', 'RocmAsyncBuffer') - s = s.replace('CudaKernel', 'RocmKernel') - s = s.replace('ToCudaType', 'ToHipType') - s = s.replace('CudaT', 'HipT') - s = s.replace('CUDA_LONG', 'HIP_LONG') - s = s.replace('CUDA_RETURN_IF_ERROR', 'HIP_RETURN_IF_ERROR') - s = s.replace('CUDA_KERNEL_ASSERT', 'HIP_KERNEL_ASSERT') - s = s.replace('CUDA_CALL', 'HIP_CALL') - s = s.replace('SliceCuda', 'SliceRocm') - s = s.replace('thrust::cuda', 'thrust::hip') - s = s.replace('CudaCall', 'RocmCall') - s = s.replace('cuda', 'rocm') + s = s.replace("CublasHandle", "RocblasHandle") + s = s.replace("cublas_handle", "rocblas_handle") + s = s.replace("hipblasHandle_t", "rocblas_handle") + s = s.replace("hipblasDatatype_t", "rocblas_datatype") + s = s.replace("HIPBLAS_STATUS_SUCCESS", "rocblas_status_success") + s = s.replace("hipblasStatus_t", "rocblas_status") + s = s.replace("hipblasCreate", "rocblas_create_handle") + s = s.replace("hipblasDestroy", "rocblas_destroy_handle") + s = s.replace("hipblasSetStream", "rocblas_set_stream") + s = s.replace("HIPBLAS_OP_T", "rocblas_operation_transpose") + + s = s.replace("RegisterCudaContribKernels", "RegisterRocmContribKernels") + s = s.replace("cudaEvent", "hipEvent") + s = s.replace("CreateCudaAllocator", "CreateRocmAllocator") + s = s.replace("CudaErrString", "RocmErrString") + s = s.replace("CudaAsyncBuffer", "RocmAsyncBuffer") + s = s.replace("CudaKernel", "RocmKernel") + s = s.replace("ToCudaType", "ToHipType") + s = s.replace("CudaT", "HipT") + s = s.replace("CUDA_LONG", "HIP_LONG") + s = s.replace("CUDA_RETURN_IF_ERROR", "HIP_RETURN_IF_ERROR") + s = s.replace("CUDA_KERNEL_ASSERT", "HIP_KERNEL_ASSERT") + s = s.replace("CUDA_CALL", "HIP_CALL") + s = s.replace("SliceCuda", "SliceRocm") + s = s.replace("thrust::cuda", "thrust::hip") + s = s.replace("CudaCall", "RocmCall") + s = s.replace("cuda", "rocm") # s = s.replace('Cuda', 'Rocm') - s = s.replace('CUDA', 'ROCM') - s = s.replace('GPU_WARP_SIZE = 32', 'GPU_WARP_SIZE = 64') - s = s.replace('std::exp', 'expf') - s = s.replace('std::log', 'logf') - s = s.replace('#include ', - '#include \n#include ') - s = s.replace('#include "cub/device/device_radix_sort.cuh"', - '#include \n#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include "cub/util_allocator.cuh"', - '#include ') - s = s.replace('#include ', - '#include ') - s = s.replace('#include "cub/util_type.cuh"', - '#include ') - s = s.replace('typedef half MappedType', 'typedef __half MappedType') + s = s.replace("CUDA", "ROCM") + s = s.replace("GPU_WARP_SIZE = 32", "GPU_WARP_SIZE = 64") + s = s.replace("std::exp", "expf") + s = s.replace("std::log", "logf") + s = s.replace( + "#include ", + "#include \n#include ", + ) + s = s.replace( + '#include "cub/device/device_radix_sort.cuh"', + "#include \n#include ", + ) + s = s.replace( + "#include ", "#include " + ) + s = s.replace( + "#include ", + "#include ", + ) + s = s.replace("#include ", "#include ") + s = s.replace( + "#include ", + "#include ", + ) + s = s.replace( + "#include ", + "#include ", + ) + s = s.replace("#include ", "#include ") + s = s.replace('#include "cub/util_allocator.cuh"', "#include ") + s = s.replace("#include ", "#include ") + s = s.replace('#include "cub/util_type.cuh"', "#include ") + s = s.replace("typedef half MappedType", "typedef __half MappedType") # CUBLAS -> HIPBLAS # Note: We do not use the hipblas marshalling interfaces; use rocblas instead. @@ -282,55 +289,55 @@ def hipify(src_file_path, dst_file_path): # s = s.replace('cublas', 'hipblas') # CUBLAS -> ROCBLAS - s = s.replace('CUBLAS', 'ROCBLAS') - s = s.replace('Cublas', 'Rocblas') - s = s.replace('cublas', 'rocblas') + s = s.replace("CUBLAS", "ROCBLAS") + s = s.replace("Cublas", "Rocblas") + s = s.replace("cublas", "rocblas") # CURAND -> HIPRAND - s = s.replace('CURAND', 'HIPRAND') - s = s.replace('Curand', 'Hiprand') - s = s.replace('curand', 'hiprand') + s = s.replace("CURAND", "HIPRAND") + s = s.replace("Curand", "Hiprand") + s = s.replace("curand", "hiprand") # NCCL -> RCCL # s = s.replace('NCCL_CALL', 'RCCL_CALL') - s = s.replace('#include ', '#include ') + s = s.replace("#include ", "#include ") # CUDNN -> MIOpen - s = s.replace('CUDNN', 'MIOPEN') - s = s.replace('Cudnn', 'Miopen') - s = s.replace('cudnn', 'miopen') + s = s.replace("CUDNN", "MIOPEN") + s = s.replace("Cudnn", "Miopen") + s = s.replace("cudnn", "miopen") # hipify seems to have a bug for MIOpen, cudnn.h -> hipDNN.h, cudnn -> hipdnn - s = s.replace('#include ', '#include ') - s = s.replace('hipdnn', 'miopen') - s = s.replace('HIPDNN_STATUS_SUCCESS', 'miopenStatusSuccess') - s = s.replace('HIPDNN', 'MIOPEN') + s = s.replace("#include ", "#include ") + s = s.replace("hipdnn", "miopen") + s = s.replace("HIPDNN_STATUS_SUCCESS", "miopenStatusSuccess") + s = s.replace("HIPDNN", "MIOPEN") # CUSPARSE -> HIPSPARSE - s = s.replace('CUSPARSE', 'HIPSPARSE') + s = s.replace("CUSPARSE", "HIPSPARSE") # CUFFT -> HIPFFT - s = s.replace('CUFFT', 'HIPFFT') + s = s.replace("CUFFT", "HIPFFT") # Undo where above hipify steps went too far. - s = s.replace('id, ROCM', 'id, CUDA') # cuda_execution_provider.cc - s = s.replace('ROCM error executing', 'HIP error executing') - s = s.replace('ROCM_PINNED', 'CUDA_PINNED') - s = s.replace('rocm_err', 'hip_err') - s = s.replace('RegisterHipTrainingKernels', 'RegisterRocmTrainingKernels') - s = s.replace('ROCM_VERSION', 'CUDA_VERSION') # semantically different meanings, cannot hipify - s = s.replace('__ROCM_ARCH__', '__CUDA_ARCH__') # semantically different meanings, cannot hipify + s = s.replace("id, ROCM", "id, CUDA") # cuda_execution_provider.cc + s = s.replace("ROCM error executing", "HIP error executing") + s = s.replace("ROCM_PINNED", "CUDA_PINNED") + s = s.replace("rocm_err", "hip_err") + s = s.replace("RegisterHipTrainingKernels", "RegisterRocmTrainingKernels") + s = s.replace("ROCM_VERSION", "CUDA_VERSION") # semantically different meanings, cannot hipify + s = s.replace("__ROCM_ARCH__", "__CUDA_ARCH__") # semantically different meanings, cannot hipify # "std::log" above incorrectly changed "std::logic_error" to "logfic_error" - s = s.replace('logfic_error', 'std::logic_error') + s = s.replace("logfic_error", "std::logic_error") # Deletions - s = s.replace('#include "device_atomic_functions.h"', '') # HIP atomics in main hip header already + s = s.replace('#include "device_atomic_functions.h"', "") # HIP atomics in main hip header already do_write = True if os.path.exists(dst_file_path): - with open(dst_file_path, 'r', encoding='utf-8') as fout_old: + with open(dst_file_path, "r", encoding="utf-8") as fout_old: do_write = fout_old.read() != s if do_write: - with open(dst_file_path, 'w') as f: + with open(dst_file_path, "w") as f: f.write(s) return 'Hipified: "{}" -> "{}"'.format(src_file_path, dst_file_path) else: @@ -349,25 +356,34 @@ def list_files(prefix, path): def amd_hipify(config_build_dir): # determine hipify script path now to avoid doing so concurrently in the thread pool - print('Using %s' % get_hipify_path()) + print("Using %s" % get_hipify_path()) with concurrent.futures.ThreadPoolExecutor() as executor: - cuda_path = os.path.join(contrib_ops_path, 'cuda') - rocm_path = os.path.join(config_build_dir, 'amdgpu', contrib_ops_path, 'rocm') - contrib_files = list_files(cuda_path, '') - contrib_results = [executor.submit(hipify, os.path.join(cuda_path, f), os.path.join(rocm_path, f)) - for f in contrib_files if f not in contrib_ops_excluded_files] - - cuda_path = os.path.join(providers_path, 'cuda') - rocm_path = os.path.join(config_build_dir, 'amdgpu', providers_path, 'rocm') - provider_files = list_files(cuda_path, '') - provider_results = [executor.submit(hipify, os.path.join(cuda_path, f), os.path.join(rocm_path, f)) - for f in provider_files if f not in provider_excluded_files] - - cuda_path = os.path.join(training_ops_path, 'cuda') - rocm_path = os.path.join(config_build_dir, 'amdgpu', training_ops_path, 'rocm') - training_files = list_files(cuda_path, '') - training_results = [executor.submit(hipify, os.path.join(cuda_path, f), os.path.join(rocm_path, f)) - for f in training_files if f not in training_ops_excluded_files] + cuda_path = os.path.join(contrib_ops_path, "cuda") + rocm_path = os.path.join(config_build_dir, "amdgpu", contrib_ops_path, "rocm") + contrib_files = list_files(cuda_path, "") + contrib_results = [ + executor.submit(hipify, os.path.join(cuda_path, f), os.path.join(rocm_path, f)) + for f in contrib_files + if f not in contrib_ops_excluded_files + ] + + cuda_path = os.path.join(providers_path, "cuda") + rocm_path = os.path.join(config_build_dir, "amdgpu", providers_path, "rocm") + provider_files = list_files(cuda_path, "") + provider_results = [ + executor.submit(hipify, os.path.join(cuda_path, f), os.path.join(rocm_path, f)) + for f in provider_files + if f not in provider_excluded_files + ] + + cuda_path = os.path.join(training_ops_path, "cuda") + rocm_path = os.path.join(config_build_dir, "amdgpu", training_ops_path, "rocm") + training_files = list_files(cuda_path, "") + training_results = [ + executor.submit(hipify, os.path.join(cuda_path, f), os.path.join(rocm_path, f)) + for f in training_files + if f not in training_ops_excluded_files + ] # explicitly wait so that hipify warnings finish printing before logging the hipify statements concurrent.futures.wait(contrib_results) concurrent.futures.wait(provider_results) @@ -380,6 +396,7 @@ def amd_hipify(config_build_dir): log.debug(result.result()) -if __name__ == '__main__': +if __name__ == "__main__": import sys + amd_hipify(sys.argv[1]) diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index b1a4feceb8951..d8e8514c52a35 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -5,34 +5,31 @@ import argparse import contextlib import os +import platform import re import shlex import shutil import subprocess import sys -import platform -from amd_hipify import amd_hipify from distutils.version import LooseVersion +from amd_hipify import amd_hipify + SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) sys.path.insert(0, os.path.join(REPO_DIR, "tools", "python")) -from util import ( # noqa: E402 - run, - is_windows, is_macOS, is_linux, - get_logger) - import util.android as android # noqa: E402 - +from util import get_logger, is_linux, is_macOS, is_windows, run # noqa: E402 log = get_logger("build") class BaseError(Exception): """Base class for errors originating from build.py.""" + pass @@ -55,20 +52,16 @@ def _check_python_version(): # Python 2 is definitely not supported and it should be safer to consider # it won't run with python 4: if sys.version_info[0] != 3: - raise BuildError( - "Bad python major version: expecting python 3, found version " - "'{}'".format(sys.version)) + raise BuildError("Bad python major version: expecting python 3, found version " "'{}'".format(sys.version)) if sys.version_info[1] < 6: - raise BuildError( - "Bad python minor version: expecting python 3.6+, found version " - "'{}'".format(sys.version)) + raise BuildError("Bad python minor version: expecting python 3.6+, found version " "'{}'".format(sys.version)) def _str_to_bool(s): """Convert string to bool (in argparse context).""" - if s.lower() not in ['true', 'false']: - raise ValueError('Need bool; got %r' % s) - return {'true': True, 'false': False}[s.lower()] + if s.lower() not in ["true", "false"]: + raise ValueError("Need bool; got %r" % s) + return {"true": True, "false": False}[s.lower()] _check_python_version() @@ -77,29 +70,35 @@ def _str_to_bool(s): def _openvino_verify_device_type(device_read): choices = ["CPU_FP32", "GPU_FP32", "GPU_FP16", "VAD-M_FP16", "MYRIAD_FP16", "VAD-F_FP32"] - choices1 = ["CPU_FP32_NO_PARTITION", "GPU_FP32_NO_PARTITION", "GPU_FP16_NO_PARTITION", - "VAD-M_FP16_NO_PARTITION", "MYRIAD_FP16_NO_PARTITION", "VAD-F_FP32_NO_PARTITION"] + choices1 = [ + "CPU_FP32_NO_PARTITION", + "GPU_FP32_NO_PARTITION", + "GPU_FP16_NO_PARTITION", + "VAD-M_FP16_NO_PARTITION", + "MYRIAD_FP16_NO_PARTITION", + "VAD-F_FP32_NO_PARTITION", + ] status_hetero = True res = False - if (device_read in choices): + if device_read in choices: res = True - elif (device_read in choices1): + elif device_read in choices1: res = True - elif (device_read.startswith("HETERO:") or device_read.startswith("MULTI:") or device_read.startswith("AUTO:")): + elif device_read.startswith("HETERO:") or device_read.startswith("MULTI:") or device_read.startswith("AUTO:"): res = True comma_separated_devices = device_read.split(":") - comma_separated_devices = comma_separated_devices[1].split(',') - if (len(comma_separated_devices) < 2): + comma_separated_devices = comma_separated_devices[1].split(",") + if len(comma_separated_devices) < 2: print("At least two devices required in Hetero/Multi/Auto Mode") status_hetero = False dev_options = ["CPU", "GPU", "MYRIAD", "FPGA", "HDDL"] for dev in comma_separated_devices: - if (dev not in dev_options): + if dev not in dev_options: status_hetero = False break def invalid_hetero_build(): - print("\n" + "If trying to build Hetero/Multi/Auto, specifiy the supported devices along with it." + + "\n") + print("\n" + "If trying to build Hetero/Multi/Auto, specifiy the supported devices along with it." + +"\n") print("specify the keyword HETERO or MULTI or AUTO followed by the devices ") print("in the order of priority you want to build" + "\n") print("The different hardware devices that can be added in HETERO or MULTI or AUTO") @@ -109,7 +108,7 @@ def invalid_hetero_build(): print("An example of how to specify the AUTO build type. Ex: AUTO:GPU,CPU" + "\n") sys.exit("Wrong Build Type selected") - if (res is False): + if res is False: print("\n" + "You have selcted wrong configuration for the build.") print("pick the build type for specific Hardware Device from following options: ", choices) print("(or) from the following options with graph partitioning disabled: ", choices1) @@ -118,7 +117,7 @@ def invalid_hetero_build(): invalid_hetero_build() sys.exit("Wrong Build Type selected") - if (status_hetero is False): + if status_hetero is False: invalid_hetero_build() return device_read @@ -144,460 +143,490 @@ def convert_arg_line_to_args(self, arg_line): """, # files containing arguments can be specified on the command line with "@" and the arguments within # will be included at that point - fromfile_prefix_chars="@") + fromfile_prefix_chars="@", + ) # Main arguments + parser.add_argument("--build_dir", required=True, help="Path to the build directory.") parser.add_argument( - "--build_dir", required=True, help="Path to the build directory.") - parser.add_argument( - "--config", nargs="+", default=["Debug"], + "--config", + nargs="+", + default=["Debug"], choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], - help="Configuration(s) to build.") - parser.add_argument( - "--update", action='store_true', help="Update makefiles.") - parser.add_argument("--build", action='store_true', help="Build.") + help="Configuration(s) to build.", + ) + parser.add_argument("--update", action="store_true", help="Update makefiles.") + parser.add_argument("--build", action="store_true", help="Build.") parser.add_argument( - "--clean", action='store_true', - help="Run 'cmake --build --target clean' for the selected config/s.") + "--clean", action="store_true", help="Run 'cmake --build --target clean' for the selected config/s." + ) parser.add_argument( - "--parallel", nargs='?', const='0', default='1', type=int, + "--parallel", + nargs="?", + const="0", + default="1", + type=int, help="Use parallel build. The optional value specifies the maximum number of parallel jobs. " - "If the optional value is 0 or unspecified, it is interpreted as the number of CPUs.") - parser.add_argument("--test", action='store_true', help="Run unit tests.") - parser.add_argument("--skip_tests", action='store_true', help="Skip all tests.") + "If the optional value is 0 or unspecified, it is interpreted as the number of CPUs.", + ) + parser.add_argument("--test", action="store_true", help="Run unit tests.") + parser.add_argument("--skip_tests", action="store_true", help="Skip all tests.") # Training options + parser.add_argument("--enable_nvtx_profile", action="store_true", help="Enable NVTX profile in ORT.") + parser.add_argument("--enable_memory_profile", action="store_true", help="Enable memory profile in ORT.") + parser.add_argument("--enable_training", action="store_true", help="Enable training in ORT.") + parser.add_argument("--enable_training_ops", action="store_true", help="Enable training ops in inference graph.") parser.add_argument( - "--enable_nvtx_profile", action='store_true', help="Enable NVTX profile in ORT.") - parser.add_argument( - "--enable_memory_profile", action='store_true', help="Enable memory profile in ORT.") - parser.add_argument( - "--enable_training", action='store_true', help="Enable training in ORT.") - parser.add_argument( - "--enable_training_ops", action='store_true', help="Enable training ops in inference graph.") - parser.add_argument( - "--enable_training_torch_interop", action='store_true', help="Enable training kernels interop with torch.") - parser.add_argument( - "--disable_nccl", action='store_true', help="Disable Nccl.") - parser.add_argument( - "--mpi_home", help="Path to MPI installation dir") - parser.add_argument( - "--nccl_home", help="Path to NCCL installation dir") - parser.add_argument( - "--use_mpi", nargs='?', default=True, const=True, type=_str_to_bool) + "--enable_training_torch_interop", action="store_true", help="Enable training kernels interop with torch." + ) + parser.add_argument("--disable_nccl", action="store_true", help="Disable Nccl.") + parser.add_argument("--mpi_home", help="Path to MPI installation dir") + parser.add_argument("--nccl_home", help="Path to NCCL installation dir") + parser.add_argument("--use_mpi", nargs="?", default=True, const=True, type=_str_to_bool) # enable ONNX tests parser.add_argument( - "--enable_onnx_tests", action='store_true', + "--enable_onnx_tests", + action="store_true", help="""When running the Test phase, run onnx_test_running against - available test data directories.""") + available test data directories.""", + ) parser.add_argument("--path_to_protoc_exe", help="Path to protoc exe.") + parser.add_argument("--fuzz_testing", action="store_true", help="Enable Fuzz testing of the onnxruntime.") parser.add_argument( - "--fuzz_testing", action='store_true', help="Enable Fuzz testing of the onnxruntime.") - parser.add_argument( - "--enable_symbolic_shape_infer_tests", action='store_true', + "--enable_symbolic_shape_infer_tests", + action="store_true", help="""When running the Test phase, run symbolic shape inference against - available test data directories.""") + available test data directories.""", + ) # generate documentation - parser.add_argument("--gen_doc", nargs='?', const='yes', type=str, - help="Generate documentation listing standard ONNX operators and types implemented by " - "various execution providers and contrib operator schemas. " - "Use `--gen_doc validate` to validate these match the current contents in /docs.") - parser.add_argument( - "--gen-api-doc", action='store_true', - help="Generate API documentation for PyTorch frontend") + "--gen_doc", + nargs="?", + const="yes", + type=str, + help="Generate documentation listing standard ONNX operators and types implemented by " + "various execution providers and contrib operator schemas. " + "Use `--gen_doc validate` to validate these match the current contents in /docs.", + ) + + parser.add_argument("--gen-api-doc", action="store_true", help="Generate API documentation for PyTorch frontend") # CUDA related - parser.add_argument("--use_cuda", action='store_true', help="Enable CUDA.") + parser.add_argument("--use_cuda", action="store_true", help="Enable CUDA.") parser.add_argument( - "--cuda_version", help="The version of CUDA toolkit to use. " - "Auto-detect if not specified. e.g. 9.0") + "--cuda_version", help="The version of CUDA toolkit to use. " "Auto-detect if not specified. e.g. 9.0" + ) parser.add_argument( - "--cuda_home", help="Path to CUDA home." + "--cuda_home", + help="Path to CUDA home." "Read from CUDA_HOME environment variable if --use_cuda is true and " - "--cuda_home is not specified.") + "--cuda_home is not specified.", + ) parser.add_argument( - "--cudnn_home", help="Path to CUDNN home. " + "--cudnn_home", + help="Path to CUDNN home. " "Read from CUDNN_HOME environment variable if --use_cuda is true and " - "--cudnn_home is not specified.") - parser.add_argument( - "--enable_cuda_line_info", action='store_true', help="Enable CUDA line info.") + "--cudnn_home is not specified.", + ) + parser.add_argument("--enable_cuda_line_info", action="store_true", help="Enable CUDA line info.") # Python bindings + parser.add_argument("--enable_pybind", action="store_true", help="Enable Python Bindings.") + parser.add_argument("--build_wheel", action="store_true", help="Build Python Wheel.") parser.add_argument( - "--enable_pybind", action='store_true', help="Enable Python Bindings.") - parser.add_argument( - "--build_wheel", action='store_true', help="Build Python Wheel.") - parser.add_argument( - "--wheel_name_suffix", help="Suffix to append to created wheel names. " - "This value is currently only used for nightly builds.") - parser.add_argument( - "--numpy_version", help="Installs a specific version of numpy " - "before building the python binding.") + "--wheel_name_suffix", + help="Suffix to append to created wheel names. " "This value is currently only used for nightly builds.", + ) parser.add_argument( - "--skip-keras-test", action='store_true', - help="Skip tests with Keras if keras is installed") + "--numpy_version", help="Installs a specific version of numpy " "before building the python binding." + ) + parser.add_argument("--skip-keras-test", action="store_true", help="Skip tests with Keras if keras is installed") # C-Sharp bindings parser.add_argument( - "--build_csharp", action='store_true', + "--build_csharp", + action="store_true", help="Build C#.Net DLL and NuGet package. This should be only used in CI pipelines. " - "For building C# bindings and packaging them into nuget package use --build_nuget arg.") + "For building C# bindings and packaging them into nuget package use --build_nuget arg.", + ) parser.add_argument( - "--build_nuget", action='store_true', + "--build_nuget", + action="store_true", help="Build C#.Net DLL and NuGet package on the local machine. " - "Currently only Windows and Linux platforms are supported.") + "Currently only Windows and Linux platforms are supported.", + ) # Java bindings - parser.add_argument( - "--build_java", action='store_true', help="Build Java bindings.") + parser.add_argument("--build_java", action="store_true", help="Build Java bindings.") # Node.js binding - parser.add_argument( - "--build_nodejs", action='store_true', - help="Build Node.js binding and NPM package.") + parser.add_argument("--build_nodejs", action="store_true", help="Build Node.js binding and NPM package.") # Objective-C binding - parser.add_argument( - "--build_objc", action='store_true', - help="Build Objective-C binding.") + parser.add_argument("--build_objc", action="store_true", help="Build Objective-C binding.") # Build a shared lib - parser.add_argument( - "--build_shared_lib", action='store_true', - help="Build a shared library for the ONNXRuntime.") + parser.add_argument("--build_shared_lib", action="store_true", help="Build a shared library for the ONNXRuntime.") # Build a shared lib parser.add_argument( - "--build_apple_framework", action='store_true', - help="Build a macOS/iOS framework for the ONNXRuntime.") + "--build_apple_framework", action="store_true", help="Build a macOS/iOS framework for the ONNXRuntime." + ) # Build options parser.add_argument( - "--cmake_extra_defines", nargs="+", action='append', + "--cmake_extra_defines", + nargs="+", + action="append", help="Extra definitions to pass to CMake during build system " - "generation. These are just CMake -D options without the leading -D.") - parser.add_argument( - "--target", - help="Build a specific target, e.g. winml_dll") + "generation. These are just CMake -D options without the leading -D.", + ) + parser.add_argument("--target", help="Build a specific target, e.g. winml_dll") # This flag is needed when : # 1. The OS is 64 bits Windows # 2. And the target binary is for 32 bits Windows # 3. And the python used for running this script is 64 bits. # But if you can get a 32 bits python, the build will run better and you won't need this flag. parser.add_argument( - "--x86", action='store_true', + "--x86", + action="store_true", help="[cross-compiling] Create Windows x86 makefiles. Requires --update and no existing cache " - "CMake setup. Delete CMakeCache.txt if needed") + "CMake setup. Delete CMakeCache.txt if needed", + ) parser.add_argument( - "--arm", action='store_true', + "--arm", + action="store_true", help="[cross-compiling] Create ARM makefiles. Requires --update and no existing cache " - "CMake setup. Delete CMakeCache.txt if needed") + "CMake setup. Delete CMakeCache.txt if needed", + ) parser.add_argument( - "--arm64", action='store_true', + "--arm64", + action="store_true", help="[cross-compiling] Create ARM64 makefiles. Requires --update and no existing cache " - "CMake setup. Delete CMakeCache.txt if needed") + "CMake setup. Delete CMakeCache.txt if needed", + ) parser.add_argument( - "--arm64ec", action='store_true', + "--arm64ec", + action="store_true", help="[cross-compiling] Create ARM64EC makefiles. Requires --update and no existing cache " - "CMake setup. Delete CMakeCache.txt if needed") - parser.add_argument( - "--msvc_toolset", help="MSVC toolset to use. e.g. 14.11") - parser.add_argument("--android", action='store_true', help='Build for Android') + "CMake setup. Delete CMakeCache.txt if needed", + ) + parser.add_argument("--msvc_toolset", help="MSVC toolset to use. e.g. 14.11") + parser.add_argument("--android", action="store_true", help="Build for Android") parser.add_argument( - "--android_abi", default="arm64-v8a", + "--android_abi", + default="arm64-v8a", choices=["armeabi-v7a", "arm64-v8a", "x86", "x86_64"], - help="Specify the target Android Application Binary Interface (ABI)") - parser.add_argument("--android_api", type=int, default=27, help='Android API Level, e.g. 21') - parser.add_argument( - "--android_sdk_path", type=str, default=os.environ.get("ANDROID_HOME", ""), - help="Path to the Android SDK") - parser.add_argument( - "--android_ndk_path", type=str, default=os.environ.get("ANDROID_NDK_HOME", ""), - help="Path to the Android NDK") - parser.add_argument("--android_cpp_shared", action="store_true", - help="Build with shared libc++ instead of the default static libc++.") - parser.add_argument("--android_run_emulator", action="store_true", - help="Start up an Android emulator if needed.") - - parser.add_argument("--use_gdk", action='store_true', help="Build with the GDK toolchain.") - parser.add_argument("--gdk_edition", default=os.path.normpath(os.environ.get("GameDKLatest", "")).split(os.sep)[-1], - help="Build with a specific GDK edition. Defaults to the latest installed.") + help="Specify the target Android Application Binary Interface (ABI)", + ) + parser.add_argument("--android_api", type=int, default=27, help="Android API Level, e.g. 21") + parser.add_argument( + "--android_sdk_path", type=str, default=os.environ.get("ANDROID_HOME", ""), help="Path to the Android SDK" + ) + parser.add_argument( + "--android_ndk_path", type=str, default=os.environ.get("ANDROID_NDK_HOME", ""), help="Path to the Android NDK" + ) + parser.add_argument( + "--android_cpp_shared", + action="store_true", + help="Build with shared libc++ instead of the default static libc++.", + ) + parser.add_argument("--android_run_emulator", action="store_true", help="Start up an Android emulator if needed.") + + parser.add_argument("--use_gdk", action="store_true", help="Build with the GDK toolchain.") + parser.add_argument( + "--gdk_edition", + default=os.path.normpath(os.environ.get("GameDKLatest", "")).split(os.sep)[-1], + help="Build with a specific GDK edition. Defaults to the latest installed.", + ) parser.add_argument("--gdk_platform", default="Scarlett", help="Sets the GDK target platform.") - parser.add_argument("--ios", action='store_true', help="build for ios") + parser.add_argument("--ios", action="store_true", help="build for ios") parser.add_argument( - "--ios_sysroot", default="", - help="Specify the location name of the macOS platform SDK to be used") + "--ios_sysroot", default="", help="Specify the location name of the macOS platform SDK to be used" + ) parser.add_argument( - "--ios_toolchain_file", default="", - help="Path to ios toolchain file, " - "or cmake/onnxruntime_ios.toolchain.cmake will be used") + "--ios_toolchain_file", + default="", + help="Path to ios toolchain file, " "or cmake/onnxruntime_ios.toolchain.cmake will be used", + ) parser.add_argument( - "--xcode_code_signing_team_id", default="", - help="The development team ID used for code signing in Xcode") + "--xcode_code_signing_team_id", default="", help="The development team ID used for code signing in Xcode" + ) parser.add_argument( - "--xcode_code_signing_identity", default="", - help="The development identity used for code signing in Xcode") + "--xcode_code_signing_identity", default="", help="The development identity used for code signing in Xcode" + ) parser.add_argument( - "--use_xcode", action='store_true', - help="Use Xcode as cmake generator, this is only supported on MacOS.") + "--use_xcode", action="store_true", help="Use Xcode as cmake generator, this is only supported on MacOS." + ) parser.add_argument( "--osx_arch", default="arm64" if platform.machine() == "arm64" else "x86_64", choices=["arm64", "arm64e", "x86_64"], - help="Specify the Target specific architectures for macOS and iOS, This is only supported on MacOS") + help="Specify the Target specific architectures for macOS and iOS, This is only supported on MacOS", + ) parser.add_argument( - "--apple_deploy_target", type=str, + "--apple_deploy_target", + type=str, help="Specify the minimum version of the target platform " "(e.g. macOS or iOS)" - "This is only supported on MacOS") + "This is only supported on MacOS", + ) parser.add_argument( - "--disable_memleak_checker", action='store_true', - help="Disable memory leak checker from Windows build") + "--disable_memleak_checker", action="store_true", help="Disable memory leak checker from Windows build" + ) # WebAssembly build - parser.add_argument("--build_wasm", action='store_true', help="Build for WebAssembly") - parser.add_argument("--build_wasm_static_lib", action='store_true', help="Build for WebAssembly static library") - parser.add_argument( - "--emsdk_version", default="3.1.3", help="Specify version of emsdk") + parser.add_argument("--build_wasm", action="store_true", help="Build for WebAssembly") + parser.add_argument("--build_wasm_static_lib", action="store_true", help="Build for WebAssembly static library") + parser.add_argument("--emsdk_version", default="3.1.3", help="Specify version of emsdk") - parser.add_argument("--enable_wasm_simd", action='store_true', help="Enable WebAssembly SIMD") - parser.add_argument( - "--enable_wasm_threads", action='store_true', - help="Enable WebAssembly multi-threads support") + parser.add_argument("--enable_wasm_simd", action="store_true", help="Enable WebAssembly SIMD") + parser.add_argument("--enable_wasm_threads", action="store_true", help="Enable WebAssembly multi-threads support") parser.add_argument( - "--disable_wasm_exception_catching", action='store_true', - help="Disable exception catching in WebAssembly.") + "--disable_wasm_exception_catching", action="store_true", help="Disable exception catching in WebAssembly." + ) parser.add_argument( - "--enable_wasm_exception_throwing_override", action='store_true', + "--enable_wasm_exception_throwing_override", + action="store_true", help="Enable exception throwing in WebAssembly, this will override default disabling exception throwing " - "behavior when disable exceptions.") + "behavior when disable exceptions.", + ) parser.add_argument( - "--enable_wasm_profiling", action='store_true', - help="Enable WebAsselby profiling and preserve function names") + "--enable_wasm_profiling", action="store_true", help="Enable WebAsselby profiling and preserve function names" + ) parser.add_argument( - "--enable_wasm_debug_info", action='store_true', - help="Build WebAssembly with DWARF format debug info") + "--enable_wasm_debug_info", action="store_true", help="Build WebAssembly with DWARF format debug info" + ) parser.add_argument("--wasm_malloc", help="Specify memory allocator for WebAssembly") parser.add_argument( - "--emscripten_settings", nargs="+", action='append', - help="Extra emscripten settings to pass to emcc using '-s =' during build.") + "--emscripten_settings", + nargs="+", + action="append", + help="Extra emscripten settings to pass to emcc using '-s =' during build.", + ) # Enable onnxruntime-extensions parser.add_argument( - "--use_extensions", action='store_true', + "--use_extensions", + action="store_true", help="Enable custom operators in onnxruntime-extensions, use git submodule onnxruntime-extensions " - "in path cmake/external/onnxruntime-extensions by default.") + "in path cmake/external/onnxruntime-extensions by default.", + ) parser.add_argument( - "--extensions_overridden_path", type=str, - help="Path to pre-pulled onnxruntime-extensions, will override default onnxruntime-extensions path.") + "--extensions_overridden_path", + type=str, + help="Path to pre-pulled onnxruntime-extensions, will override default onnxruntime-extensions path.", + ) # Arguments needed by CI + parser.add_argument("--cmake_path", default="cmake", help="Path to the CMake program.") parser.add_argument( - "--cmake_path", default="cmake", help="Path to the CMake program.") - parser.add_argument( - "--ctest_path", default="ctest", help="Path to the CTest program. It can be an empty string. If it is empty, " - "we will use this script driving the test programs directly.") - parser.add_argument( - "--skip_submodule_sync", action='store_true', help="Don't do a " - "'git submodule update'. Makes the Update phase faster.") - parser.add_argument( - "--use_vstest", action='store_true', - help="Use use_vstest for running unitests.") - parser.add_argument( - "--use_mimalloc", action='store_true', help="Use mimalloc allocator") + "--ctest_path", + default="ctest", + help="Path to the CTest program. It can be an empty string. If it is empty, " + "we will use this script driving the test programs directly.", + ) parser.add_argument( - "--use_dnnl", action='store_true', help="Build with DNNL.") + "--skip_submodule_sync", + action="store_true", + help="Don't do a " "'git submodule update'. Makes the Update phase faster.", + ) + parser.add_argument("--use_vstest", action="store_true", help="Use use_vstest for running unitests.") + parser.add_argument("--use_mimalloc", action="store_true", help="Use mimalloc allocator") + parser.add_argument("--use_dnnl", action="store_true", help="Build with DNNL.") parser.add_argument( - "--dnnl_gpu_runtime", action='store', default='', type=str.lower, - help="e.g. --dnnl_gpu_runtime ocl") + "--dnnl_gpu_runtime", action="store", default="", type=str.lower, help="e.g. --dnnl_gpu_runtime ocl" + ) parser.add_argument( - "--dnnl_opencl_root", action='store', default='', + "--dnnl_opencl_root", + action="store", + default="", help="Path to OpenCL SDK. " - "e.g. --dnnl_opencl_root \"C:/Program Files (x86)/IntelSWTools/sw_dev_tools/OpenCL/sdk\"") + 'e.g. --dnnl_opencl_root "C:/Program Files (x86)/IntelSWTools/sw_dev_tools/OpenCL/sdk"', + ) parser.add_argument( - "--use_openvino", nargs="?", const="CPU_FP32", + "--use_openvino", + nargs="?", + const="CPU_FP32", type=_openvino_verify_device_type, - help="Build with OpenVINO for specific hardware.") - parser.add_argument( - "--use_coreml", action='store_true', help="Build with CoreML support.") - parser.add_argument( - "--use_nnapi", action='store_true', help="Build with NNAPI support.") - parser.add_argument( - "--nnapi_min_api", type=int, - help="Minimum Android API level to enable NNAPI, should be no less than 27") - parser.add_argument( - "--use_rknpu", action='store_true', help="Build with RKNPU.") + help="Build with OpenVINO for specific hardware.", + ) + parser.add_argument("--use_coreml", action="store_true", help="Build with CoreML support.") + parser.add_argument("--use_nnapi", action="store_true", help="Build with NNAPI support.") parser.add_argument( - "--use_preinstalled_eigen", action='store_true', - help="Use pre-installed Eigen.") + "--nnapi_min_api", type=int, help="Minimum Android API level to enable NNAPI, should be no less than 27" + ) + parser.add_argument("--use_rknpu", action="store_true", help="Build with RKNPU.") + parser.add_argument("--use_preinstalled_eigen", action="store_true", help="Use pre-installed Eigen.") parser.add_argument("--eigen_path", help="Path to pre-installed Eigen.") - parser.add_argument( - "--use_openmp", action='store_true', help="Build with OpenMP") - parser.add_argument( - "--enable_msinternal", action="store_true", - help="Enable for Microsoft internal builds only.") + parser.add_argument("--use_openmp", action="store_true", help="Build with OpenMP") + parser.add_argument("--enable_msinternal", action="store_true", help="Enable for Microsoft internal builds only.") parser.add_argument("--llvm_path", help="Path to llvm dir") - parser.add_argument( - "--use_vitisai", action='store_true', help="Build with Vitis-AI") - parser.add_argument( - "--use_nuphar", action='store_true', help="Build with nuphar") - parser.add_argument( - "--use_tvm", action='store_true', help="Build with TVM") - parser.add_argument( - "--tvm_cuda_runtime", action='store_true', default=False, - help="Build TVM with CUDA support") - parser.add_argument( - "--use_tensorrt", action='store_true', help="Build with TensorRT") - parser.add_argument( - "--tensorrt_home", help="Path to TensorRT installation dir") - parser.add_argument( - "--use_migraphx", action='store_true', help="Build with MIGraphX") - parser.add_argument( - "--migraphx_home", help="Path to MIGraphX installation dir") - parser.add_argument( - "--use_full_protobuf", action='store_true', - help="Use the full protobuf library") - - parser.add_argument("--skip_onnx_tests", action='store_true', - help="Explicitly disable all onnx related tests. Note: Use --skip_tests to skip all tests.") - parser.add_argument("--skip_winml_tests", action='store_true', help="Explicitly disable all WinML related tests") - parser.add_argument("--skip_nodejs_tests", action='store_true', help="Explicitly disable all Node.js binding tests") + parser.add_argument("--use_vitisai", action="store_true", help="Build with Vitis-AI") + parser.add_argument("--use_nuphar", action="store_true", help="Build with nuphar") + parser.add_argument("--use_tvm", action="store_true", help="Build with TVM") + parser.add_argument("--tvm_cuda_runtime", action="store_true", default=False, help="Build TVM with CUDA support") + parser.add_argument("--use_tensorrt", action="store_true", help="Build with TensorRT") + parser.add_argument("--tensorrt_home", help="Path to TensorRT installation dir") + parser.add_argument("--use_migraphx", action="store_true", help="Build with MIGraphX") + parser.add_argument("--migraphx_home", help="Path to MIGraphX installation dir") + parser.add_argument("--use_full_protobuf", action="store_true", help="Use the full protobuf library") + + parser.add_argument( + "--skip_onnx_tests", + action="store_true", + help="Explicitly disable all onnx related tests. Note: Use --skip_tests to skip all tests.", + ) + parser.add_argument("--skip_winml_tests", action="store_true", help="Explicitly disable all WinML related tests") + parser.add_argument("--skip_nodejs_tests", action="store_true", help="Explicitly disable all Node.js binding tests") parser.add_argument( - "--enable_msvc_static_runtime", action='store_true', - help="Enable static linking of MSVC runtimes.") + "--enable_msvc_static_runtime", action="store_true", help="Enable static linking of MSVC runtimes." + ) parser.add_argument( - "--enable_language_interop_ops", action='store_true', - help="Enable operator implemented in language other than cpp") + "--enable_language_interop_ops", + action="store_true", + help="Enable operator implemented in language other than cpp", + ) parser.add_argument( "--cmake_generator", - choices=['Visual Studio 15 2017', 'Visual Studio 16 2019', 'Visual Studio 17 2022', 'Ninja'], - default='Visual Studio 16 2019' if is_windows() else None, - help="Specify the generator that CMake invokes. " - "This is only supported on Windows") - parser.add_argument( - "--enable_multi_device_test", action='store_true', - help="Test with multi-device. Mostly used for multi-device GPU") - parser.add_argument( - "--use_dml", action='store_true', help="Build with DirectML.") - parser.add_argument( - "--dml_path", type=str, default="", - help="Path to a custom DirectML installation (must have bin/, lib/, and include/ subdirectories).") - parser.add_argument( - "--use_winml", action='store_true', help="Build with WinML.") - parser.add_argument( - "--winml_root_namespace_override", type=str, - help="Specify the namespace that WinML builds into.") + choices=["Visual Studio 15 2017", "Visual Studio 16 2019", "Visual Studio 17 2022", "Ninja"], + default="Visual Studio 16 2019" if is_windows() else None, + help="Specify the generator that CMake invokes. " "This is only supported on Windows", + ) parser.add_argument( - "--use_telemetry", action='store_true', - help="Only official builds can set this flag to enable telemetry.") + "--enable_multi_device_test", + action="store_true", + help="Test with multi-device. Mostly used for multi-device GPU", + ) + parser.add_argument("--use_dml", action="store_true", help="Build with DirectML.") parser.add_argument( - "--enable_wcos", action='store_true', - help="Build for Windows Core OS.") + "--dml_path", + type=str, + default="", + help="Path to a custom DirectML installation (must have bin/, lib/, and include/ subdirectories).", + ) + parser.add_argument("--use_winml", action="store_true", help="Build with WinML.") parser.add_argument( - "--enable_lto", action='store_true', - help="Enable Link Time Optimization") + "--winml_root_namespace_override", type=str, help="Specify the namespace that WinML builds into." + ) parser.add_argument( - "--enable_transformers_tool_test", action='store_true', - help="Enable transformers tool test") + "--use_telemetry", action="store_true", help="Only official builds can set this flag to enable telemetry." + ) + parser.add_argument("--enable_wcos", action="store_true", help="Build for Windows Core OS.") + parser.add_argument("--enable_lto", action="store_true", help="Enable Link Time Optimization") + parser.add_argument("--enable_transformers_tool_test", action="store_true", help="Enable transformers tool test") parser.add_argument( - "--use_acl", nargs="?", const="ACL_1905", + "--use_acl", + nargs="?", + const="ACL_1905", choices=["ACL_1902", "ACL_1905", "ACL_1908", "ACL_2002"], - help="Build with ACL for ARM architectures.") - parser.add_argument( - "--acl_home", help="Path to ACL home dir") - parser.add_argument( - "--acl_libs", help="Path to ACL libraries") - parser.add_argument( - "--use_armnn", action='store_true', - help="Enable ArmNN Execution Provider.") + help="Build with ACL for ARM architectures.", + ) + parser.add_argument("--acl_home", help="Path to ACL home dir") + parser.add_argument("--acl_libs", help="Path to ACL libraries") + parser.add_argument("--use_armnn", action="store_true", help="Enable ArmNN Execution Provider.") parser.add_argument( - "--armnn_relu", action='store_true', - help="Use the Relu operator implementation from the ArmNN EP.") + "--armnn_relu", action="store_true", help="Use the Relu operator implementation from the ArmNN EP." + ) parser.add_argument( - "--armnn_bn", action='store_true', - help="Use the Batch Normalization operator implementation from the ArmNN EP.") + "--armnn_bn", action="store_true", help="Use the Batch Normalization operator implementation from the ArmNN EP." + ) + parser.add_argument("--armnn_home", help="Path to ArmNN home dir") + parser.add_argument("--armnn_libs", help="Path to ArmNN libraries") + parser.add_argument("--build_micro_benchmarks", action="store_true", help="Build ONNXRuntime micro-benchmarks.") + + # options to reduce binary size parser.add_argument( - "--armnn_home", help="Path to ArmNN home dir") + "--minimal_build", + default=None, + nargs="*", + type=str.lower, + help="Create a build that only supports ORT format models. " + "See https://onnxruntime.ai/docs/tutorials/mobile/ for more information. " + "RTTI is automatically disabled in a minimal build. " + "To enable execution providers that compile kernels at runtime (e.g. NNAPI) pass 'extended' " + "as a parameter. e.g. '--minimal_build extended'. " + "To enable support for custom operators pass 'custom_ops' as a parameter. " + "e.g. '--minimal_build custom_ops'. This can be combined with an 'extended' build by passing " + "'--minimal_build extended custom_ops'", + ) + parser.add_argument( - "--armnn_libs", help="Path to ArmNN libraries") + "--include_ops_by_config", + type=str, + help="Include ops from config file. " "See /docs/Reduced_Operator_Kernel_build.md for more information.", + ) parser.add_argument( - "--build_micro_benchmarks", action='store_true', - help="Build ONNXRuntime micro-benchmarks.") + "--enable_reduced_operator_type_support", + action="store_true", + help="If --include_ops_by_config is specified, and the configuration file has type reduction " + "information, limit the types individual operators support where possible to further " + "reduce the build size. " + "See /docs/Reduced_Operator_Kernel_build.md for more information.", + ) - # options to reduce binary size - parser.add_argument("--minimal_build", default=None, nargs='*', type=str.lower, - help="Create a build that only supports ORT format models. " - "See https://onnxruntime.ai/docs/tutorials/mobile/ for more information. " - "RTTI is automatically disabled in a minimal build. " - "To enable execution providers that compile kernels at runtime (e.g. NNAPI) pass 'extended' " - "as a parameter. e.g. '--minimal_build extended'. " - "To enable support for custom operators pass 'custom_ops' as a parameter. " - "e.g. '--minimal_build custom_ops'. This can be combined with an 'extended' build by passing " - "'--minimal_build extended custom_ops'") - - parser.add_argument("--include_ops_by_config", type=str, - help="Include ops from config file. " - "See /docs/Reduced_Operator_Kernel_build.md for more information.") - parser.add_argument("--enable_reduced_operator_type_support", action='store_true', - help='If --include_ops_by_config is specified, and the configuration file has type reduction ' - 'information, limit the types individual operators support where possible to further ' - 'reduce the build size. ' - 'See /docs/Reduced_Operator_Kernel_build.md for more information.') - - parser.add_argument("--disable_contrib_ops", action='store_true', - help="Disable contrib ops (reduces binary size)") - parser.add_argument("--disable_ml_ops", action='store_true', - help="Disable traditional ML ops (reduces binary size)") + parser.add_argument("--disable_contrib_ops", action="store_true", help="Disable contrib ops (reduces binary size)") + parser.add_argument( + "--disable_ml_ops", action="store_true", help="Disable traditional ML ops (reduces binary size)" + ) # Please note in our CMakeLists.txt this is already default on. But in this file we reverse it to default OFF. - parser.add_argument("--disable_rtti", action='store_true', help="Disable RTTI (reduces binary size)") - parser.add_argument("--disable_exceptions", action='store_true', - help="Disable exceptions to reduce binary size. Requires --minimal_build.") - + parser.add_argument("--disable_rtti", action="store_true", help="Disable RTTI (reduces binary size)") parser.add_argument( - "--rocm_version", help="The version of ROCM stack to use. ") - parser.add_argument("--use_rocm", action='store_true', help="Build with ROCm") + "--disable_exceptions", + action="store_true", + help="Disable exceptions to reduce binary size. Requires --minimal_build.", + ) + + parser.add_argument("--rocm_version", help="The version of ROCM stack to use. ") + parser.add_argument("--use_rocm", action="store_true", help="Build with ROCm") parser.add_argument("--rocm_home", help="Path to ROCm installation dir") # Code coverage - parser.add_argument("--code_coverage", action='store_true', - help="Generate code coverage when targetting Android (only).") parser.add_argument( - "--ms_experimental", action='store_true', help="Build microsoft experimental operators.")\ - + "--code_coverage", action="store_true", help="Generate code coverage when targetting Android (only)." + ) + parser.add_argument("--ms_experimental", action="store_true", help="Build microsoft experimental operators.") # eager mode + parser.add_argument("--build_eager_mode", action="store_true", help="Build ONNXRuntime micro-benchmarks.") + parser.add_argument( + "--eager_customop_module", default=None, help="Module containing custom op mappings for eager mode." + ) parser.add_argument( - "--build_eager_mode", action='store_true', - help="Build ONNXRuntime micro-benchmarks.") - parser.add_argument('--eager_customop_module', default=None, - help='Module containing custom op mappings for eager mode.') - parser.add_argument('--eager_customop_header', default=None, - help='Header containing custom op definitions for eager mode.') + "--eager_customop_header", default=None, help="Header containing custom op definitions for eager mode." + ) parser.add_argument( - "--enable_external_custom_op_schemas", action='store_true', + "--enable_external_custom_op_schemas", + action="store_true", help="Enable registering user defined custom operation schemas at shared library load time.\ - This feature is only supported/available on Ubuntu.") + This feature is only supported/available on Ubuntu.", + ) parser.add_argument( - "--external_graph_transformer_path", type=str, - help="path to the external graph transformer dir.") + "--external_graph_transformer_path", type=str, help="path to the external graph transformer dir." + ) parser.add_argument( - "--test_external_transformer_example", action='store_true', - help="run the example external transformer test, mainly used in CI pipeline.") + "--test_external_transformer_example", + action="store_true", + help="run the example external transformer test, mainly used in CI pipeline.", + ) parser.add_argument( - "--enable_cuda_profiling", action='store_true', help="enable cuda kernel profiling, \ - cupti library must be added to PATH beforehand.") + "--enable_cuda_profiling", + action="store_true", + help="enable cuda kernel profiling, \ + cupti library must be added to PATH beforehand.", + ) args = parser.parse_args() if args.android_sdk_path: @@ -617,8 +646,7 @@ def resolve_executable_path(command_or_path): if command_or_path and command_or_path.strip(): executable_path = shutil.which(command_or_path) if executable_path is None: - raise BuildError("Failed to resolve executable path for " - "'{}'.".format(command_or_path)) + raise BuildError("Failed to resolve executable path for " "'{}'.".format(command_or_path)) return os.path.abspath(executable_path) else: return None @@ -626,18 +654,16 @@ def resolve_executable_path(command_or_path): def get_linux_distro(): try: - with open('/etc/os-release', 'r') as f: - dist_info = dict( - line.strip().split('=', 1) for line in f.readlines()) - return dist_info.get('NAME', '').strip('"'), dist_info.get( - 'VERSION', '').strip('"') + with open("/etc/os-release", "r") as f: + dist_info = dict(line.strip().split("=", 1) for line in f.readlines()) + return dist_info.get("NAME", "").strip('"'), dist_info.get("VERSION", "").strip('"') except (IOError, ValueError): - return '', '' + return "", "" def is_ubuntu_1604(): dist, ver = get_linux_distro() - return dist == 'Ubuntu' and ver.startswith('16.04') + return dist == "Ubuntu" and ver.startswith("16.04") def get_config_build_dir(build_dir, config): @@ -645,8 +671,7 @@ def get_config_build_dir(build_dir, config): return os.path.join(build_dir, config) -def run_subprocess(args, cwd=None, capture_stdout=False, dll_path=None, - shell=False, env={}, python_path=None): +def run_subprocess(args, cwd=None, capture_stdout=False, dll_path=None, shell=False, env={}, python_path=None): if isinstance(args, str): raise ValueError("args should be a sequence of strings, not a string") @@ -675,68 +700,54 @@ def run_subprocess(args, cwd=None, capture_stdout=False, dll_path=None, def update_submodules(source_dir): run_subprocess(["git", "submodule", "sync", "--recursive"], cwd=source_dir) - run_subprocess(["git", "submodule", "update", "--init", "--recursive"], - cwd=source_dir) + run_subprocess(["git", "submodule", "update", "--init", "--recursive"], cwd=source_dir) def is_docker(): - path = '/proc/self/cgroup' - return ( - os.path.exists('/.dockerenv') or - os.path.isfile(path) and any('docker' in line for line in open(path)) - ) + path = "/proc/self/cgroup" + return os.path.exists("/.dockerenv") or os.path.isfile(path) and any("docker" in line for line in open(path)) def install_python_deps(numpy_version=""): - dep_packages = ['setuptools', 'wheel', 'pytest'] - dep_packages.append('numpy=={}'.format(numpy_version) if numpy_version - else 'numpy>=1.16.6') - dep_packages.append('sympy>=1.1') - dep_packages.append('packaging') - dep_packages.append('cerberus') - run_subprocess([sys.executable, '-m', 'pip', 'install'] + dep_packages) + dep_packages = ["setuptools", "wheel", "pytest"] + dep_packages.append("numpy=={}".format(numpy_version) if numpy_version else "numpy>=1.16.6") + dep_packages.append("sympy>=1.1") + dep_packages.append("packaging") + dep_packages.append("cerberus") + run_subprocess([sys.executable, "-m", "pip", "install"] + dep_packages) def setup_test_data(build_dir, configs): # create a shortcut for test models if there is a 'models' # folder in build_dir if is_windows(): - src_model_dir = os.path.join(build_dir, 'models') - if os.path.exists('C:\\local\\models') and not os.path.exists( - src_model_dir): - log.debug("creating shortcut %s -> %s" % ( - 'C:\\local\\models', src_model_dir)) - run_subprocess(['mklink', '/D', '/J', src_model_dir, - 'C:\\local\\models'], shell=True) + src_model_dir = os.path.join(build_dir, "models") + if os.path.exists("C:\\local\\models") and not os.path.exists(src_model_dir): + log.debug("creating shortcut %s -> %s" % ("C:\\local\\models", src_model_dir)) + run_subprocess(["mklink", "/D", "/J", src_model_dir, "C:\\local\\models"], shell=True) for config in configs: config_build_dir = get_config_build_dir(build_dir, config) os.makedirs(config_build_dir, exist_ok=True) - dest_model_dir = os.path.join(config_build_dir, 'models') - if os.path.exists('C:\\local\\models') and not os.path.exists( - dest_model_dir): - log.debug("creating shortcut %s -> %s" % ( - 'C:\\local\\models', dest_model_dir)) - run_subprocess(['mklink', '/D', '/J', dest_model_dir, - 'C:\\local\\models'], shell=True) - elif os.path.exists(src_model_dir) and not os.path.exists( - dest_model_dir): - log.debug("creating shortcut %s -> %s" % ( - src_model_dir, dest_model_dir)) - run_subprocess(['mklink', '/D', '/J', dest_model_dir, - src_model_dir], shell=True) + dest_model_dir = os.path.join(config_build_dir, "models") + if os.path.exists("C:\\local\\models") and not os.path.exists(dest_model_dir): + log.debug("creating shortcut %s -> %s" % ("C:\\local\\models", dest_model_dir)) + run_subprocess(["mklink", "/D", "/J", dest_model_dir, "C:\\local\\models"], shell=True) + elif os.path.exists(src_model_dir) and not os.path.exists(dest_model_dir): + log.debug("creating shortcut %s -> %s" % (src_model_dir, dest_model_dir)) + run_subprocess(["mklink", "/D", "/J", dest_model_dir, src_model_dir], shell=True) def use_dev_mode(args): if args.use_acl: - return 'OFF' + return "OFF" if args.use_armnn: - return 'OFF' + return "OFF" if args.ios and is_macOS(): - return 'OFF' - SYSTEM_COLLECTIONURI = os.getenv('SYSTEM_COLLECTIONURI') - if SYSTEM_COLLECTIONURI and not SYSTEM_COLLECTIONURI == 'https://dev.azure.com/onnxruntime/': - return 'OFF' - return 'ON' + return "OFF" + SYSTEM_COLLECTIONURI = os.getenv("SYSTEM_COLLECTIONURI") + if SYSTEM_COLLECTIONURI and not SYSTEM_COLLECTIONURI == "https://dev.azure.com/onnxruntime/": + return "OFF" + return "ON" def add_default_definition(definition_list, key, default_value): @@ -747,17 +758,35 @@ def add_default_definition(definition_list, key, default_value): def normalize_arg_list(nested_list): - return ([i for j in nested_list for i in j] - if nested_list else []) - - -def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home, rocm_home, - mpi_home, nccl_home, tensorrt_home, migraphx_home, acl_home, acl_libs, armnn_home, armnn_libs, - path_to_protoc_exe, configs, cmake_extra_defines, args, cmake_extra_args): + return [i for j in nested_list for i in j] if nested_list else [] + + +def generate_build_tree( + cmake_path, + source_dir, + build_dir, + cuda_home, + cudnn_home, + rocm_home, + mpi_home, + nccl_home, + tensorrt_home, + migraphx_home, + acl_home, + acl_libs, + armnn_home, + armnn_libs, + path_to_protoc_exe, + configs, + cmake_extra_defines, + args, + cmake_extra_args, +): log.info("Generating CMake build tree") cmake_dir = os.path.join(source_dir, "cmake") cmake_args = [ - cmake_path, cmake_dir, + cmake_path, + cmake_dir, "-Donnxruntime_RUN_ONNX_TESTS=" + ("ON" if args.enable_onnx_tests else "OFF"), "-Donnxruntime_BUILD_WINML_TESTS=" + ("OFF" if args.skip_winml_tests else "ON"), "-Donnxruntime_GENERATE_TEST_REPORTS=ON", @@ -781,12 +810,12 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_DNNL_OPENCL_ROOT=" + (args.dnnl_opencl_root if args.use_dnnl else ""), "-Donnxruntime_USE_NNAPI_BUILTIN=" + ("ON" if args.use_nnapi else "OFF"), "-Donnxruntime_USE_RKNPU=" + ("ON" if args.use_rknpu else "OFF"), - "-Donnxruntime_USE_OPENMP=" + ( - "ON" if args.use_openmp and not ( - args.use_nnapi or - args.android or (args.ios and is_macOS()) - or args.use_rknpu) - else "OFF"), + "-Donnxruntime_USE_OPENMP=" + + ( + "ON" + if args.use_openmp and not (args.use_nnapi or args.android or (args.ios and is_macOS()) or args.use_rknpu) + else "OFF" + ), "-Donnxruntime_USE_NUPHAR_TVM=" + ("ON" if args.use_nuphar else "OFF"), "-Donnxruntime_USE_LLVM=" + ("ON" if args.use_nuphar or args.use_tvm else "OFF"), "-Donnxruntime_ENABLE_MICROSOFT_INTERNAL=" + ("ON" if args.enable_msinternal else "OFF"), @@ -805,16 +834,19 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_CROSS_COMPILING=" + ("ON" if args.arm64 or args.arm64ec or args.arm else "OFF"), "-Donnxruntime_DISABLE_CONTRIB_OPS=" + ("ON" if args.disable_contrib_ops else "OFF"), "-Donnxruntime_DISABLE_ML_OPS=" + ("ON" if args.disable_ml_ops else "OFF"), - "-Donnxruntime_DISABLE_RTTI=" + ("ON" if args.disable_rtti or (args.minimal_build is not None - and not args.enable_pybind) else "OFF"), + "-Donnxruntime_DISABLE_RTTI=" + + ("ON" if args.disable_rtti or (args.minimal_build is not None and not args.enable_pybind) else "OFF"), "-Donnxruntime_DISABLE_EXCEPTIONS=" + ("ON" if args.disable_exceptions else "OFF"), # Need to use 'is not None' with minimal_build check as it could be an empty list. "-Donnxruntime_MINIMAL_BUILD=" + ("ON" if args.minimal_build is not None else "OFF"), - "-Donnxruntime_EXTENDED_MINIMAL_BUILD=" + ("ON" if args.minimal_build and 'extended' in args.minimal_build - else "OFF"), - "-Donnxruntime_MINIMAL_BUILD_CUSTOM_OPS=" + ("ON" if (args.minimal_build is not None and ('custom_ops' in - args.minimal_build or args.use_extensions)) - else "OFF"), + "-Donnxruntime_EXTENDED_MINIMAL_BUILD=" + + ("ON" if args.minimal_build and "extended" in args.minimal_build else "OFF"), + "-Donnxruntime_MINIMAL_BUILD_CUSTOM_OPS=" + + ( + "ON" + if (args.minimal_build is not None and ("custom_ops" in args.minimal_build or args.use_extensions)) + else "OFF" + ), "-Donnxruntime_REDUCED_OPS_BUILD=" + ("ON" if is_reduced_ops_build(args) else "OFF"), "-Donnxruntime_ENABLE_LANGUAGE_INTEROP_OPS=" + ("ON" if args.enable_language_interop_ops else "OFF"), "-Donnxruntime_USE_DML=" + ("ON" if args.use_dml else "OFF"), @@ -849,16 +881,16 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-Donnxruntime_BUILD_WEBASSEMBLY=" + ("ON" if args.build_wasm else "OFF"), "-Donnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB=" + ("ON" if args.build_wasm_static_lib else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_SIMD=" + ("ON" if args.enable_wasm_simd else "OFF"), - "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=" + ("OFF" if args.disable_wasm_exception_catching - else "ON"), - "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING=" + ("ON" if args.enable_wasm_exception_throwing_override - else "OFF"), + "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING=" + + ("OFF" if args.disable_wasm_exception_catching else "ON"), + "-Donnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_THROWING=" + + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"), "-Donnxruntime_ENABLE_EAGER_MODE=" + ("ON" if args.build_eager_mode else "OFF"), - "-Donnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS=" + ("ON" if args.enable_external_custom_op_schemas - else "OFF"), + "-Donnxruntime_ENABLE_EXTERNAL_CUSTOM_OP_SCHEMAS=" + + ("ON" if args.enable_external_custom_op_schemas else "OFF"), "-Donnxruntime_NVCC_THREADS=" + str(args.parallel), "-Donnxruntime_ENABLE_CUDA_PROFILING=" + ("ON" if args.enable_cuda_profiling else "OFF"), ] @@ -876,8 +908,9 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home if is_windows(): if args.enable_msvc_static_runtime: - add_default_definition(cmake_extra_defines, "CMAKE_MSVC_RUNTIME_LIBRARY", - "MultiThreaded$<$:Debug>") + add_default_definition( + cmake_extra_defines, "CMAKE_MSVC_RUNTIME_LIBRARY", "MultiThreaded$<$:Debug>" + ) add_default_definition(cmake_extra_defines, "ONNX_USE_MSVC_STATIC_RUNTIME", "ON") add_default_definition(cmake_extra_defines, "protobuf_MSVC_STATIC_RUNTIME", "ON") add_default_definition(cmake_extra_defines, "gtest_force_shared_crt", "OFF") @@ -903,57 +936,47 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home if args.use_mpi: cmake_args += ["-Donnxruntime_MPI_HOME=" + mpi_home] else: - log.warning("mpi_home is supplied but use_mpi is set to false." - " Build will continue without linking MPI libraries.") + log.warning( + "mpi_home is supplied but use_mpi is set to false." + " Build will continue without linking MPI libraries." + ) if nccl_home and os.path.exists(nccl_home): cmake_args += ["-Donnxruntime_NCCL_HOME=" + nccl_home] if args.winml_root_namespace_override: - cmake_args += ["-Donnxruntime_WINML_NAMESPACE_OVERRIDE=" + - args.winml_root_namespace_override] + cmake_args += ["-Donnxruntime_WINML_NAMESPACE_OVERRIDE=" + args.winml_root_namespace_override] if args.use_openvino: - cmake_args += ["-Donnxruntime_USE_OPENVINO=ON", - "-Donnxruntime_USE_OPENVINO_MYRIAD=" + ( - "ON" if args.use_openvino == "MYRIAD_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_GPU_FP32=" + ( - "ON" if args.use_openvino == "GPU_FP32" else "OFF"), - "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ( - "ON" if args.use_openvino == "GPU_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ( - "ON" if args.use_openvino == "CPU_FP32" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VAD_M=" + ( - "ON" if args.use_openvino == "VAD-M_FP16" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VAD_F=" + ( - "ON" if args.use_openvino == "VAD-F_FP32" else "OFF"), - "-Donnxruntime_USE_OPENVINO_MYRIAD_NP=" + ( - "ON" if args.use_openvino == "MYRIAD_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP=" + ( - "ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP=" + ( - "ON" if args.use_openvino == "GPU_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_CPU_FP32_NP=" + ( - "ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VAD_M_NP=" + ( - "ON" if args.use_openvino == "VAD-M_FP16_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_VAD_F_NP=" + ( - "ON" if args.use_openvino == "VAD-F_FP32_NO_PARTITION" else "OFF"), - "-Donnxruntime_USE_OPENVINO_HETERO=" + ( - "ON" if args.use_openvino.startswith("HETERO") else "OFF"), - "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino), - "-Donnxruntime_USE_OPENVINO_MULTI=" + ( - "ON" if args.use_openvino.startswith("MULTI") else "OFF"), - "-Donnxruntime_USE_OPENVINO_AUTO=" + ( - "ON" if args.use_openvino.startswith("AUTO") else "OFF")] + cmake_args += [ + "-Donnxruntime_USE_OPENVINO=ON", + "-Donnxruntime_USE_OPENVINO_MYRIAD=" + ("ON" if args.use_openvino == "MYRIAD_FP16" else "OFF"), + "-Donnxruntime_USE_OPENVINO_GPU_FP32=" + ("ON" if args.use_openvino == "GPU_FP32" else "OFF"), + "-Donnxruntime_USE_OPENVINO_GPU_FP16=" + ("ON" if args.use_openvino == "GPU_FP16" else "OFF"), + "-Donnxruntime_USE_OPENVINO_CPU_FP32=" + ("ON" if args.use_openvino == "CPU_FP32" else "OFF"), + "-Donnxruntime_USE_OPENVINO_VAD_M=" + ("ON" if args.use_openvino == "VAD-M_FP16" else "OFF"), + "-Donnxruntime_USE_OPENVINO_VAD_F=" + ("ON" if args.use_openvino == "VAD-F_FP32" else "OFF"), + "-Donnxruntime_USE_OPENVINO_MYRIAD_NP=" + + ("ON" if args.use_openvino == "MYRIAD_FP16_NO_PARTITION" else "OFF"), + "-Donnxruntime_USE_OPENVINO_GPU_FP32_NP=" + + ("ON" if args.use_openvino == "GPU_FP32_NO_PARTITION" else "OFF"), + "-Donnxruntime_USE_OPENVINO_GPU_FP16_NP=" + + ("ON" if args.use_openvino == "GPU_FP16_NO_PARTITION" else "OFF"), + "-Donnxruntime_USE_OPENVINO_CPU_FP32_NP=" + + ("ON" if args.use_openvino == "CPU_FP32_NO_PARTITION" else "OFF"), + "-Donnxruntime_USE_OPENVINO_VAD_M_NP=" + + ("ON" if args.use_openvino == "VAD-M_FP16_NO_PARTITION" else "OFF"), + "-Donnxruntime_USE_OPENVINO_VAD_F_NP=" + + ("ON" if args.use_openvino == "VAD-F_FP32_NO_PARTITION" else "OFF"), + "-Donnxruntime_USE_OPENVINO_HETERO=" + ("ON" if args.use_openvino.startswith("HETERO") else "OFF"), + "-Donnxruntime_USE_OPENVINO_DEVICE=" + (args.use_openvino), + "-Donnxruntime_USE_OPENVINO_MULTI=" + ("ON" if args.use_openvino.startswith("MULTI") else "OFF"), + "-Donnxruntime_USE_OPENVINO_AUTO=" + ("ON" if args.use_openvino.startswith("AUTO") else "OFF"), + ] # TensorRT and OpenVINO providers currently only support # full_protobuf option. - if (args.use_full_protobuf or args.use_tensorrt or - args.use_openvino or args.use_vitisai or args.gen_doc): - cmake_args += [ - "-Donnxruntime_USE_FULL_PROTOBUF=ON", - "-DProtobuf_USE_STATIC_LIBS=ON" - ] + if args.use_full_protobuf or args.use_tensorrt or args.use_openvino or args.use_vitisai or args.gen_doc: + cmake_args += ["-Donnxruntime_USE_FULL_PROTOBUF=ON", "-DProtobuf_USE_STATIC_LIBS=ON"] if (args.use_nuphar or args.use_tvm) and args.llvm_path is not None: cmake_args += ["-DLLVM_DIR=%s" % args.llvm_path] @@ -963,8 +986,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home cmake_args += ["-DCUDA_CUDA_LIBRARY=" + nvml_stub_path] if args.use_preinstalled_eigen: - cmake_args += ["-Donnxruntime_USE_PREINSTALLED_EIGEN=ON", - "-Deigen_SOURCE_PATH=" + args.eigen_path] + cmake_args += ["-Donnxruntime_USE_PREINSTALLED_EIGEN=ON", "-Deigen_SOURCE_PATH=" + args.eigen_path] if args.nnapi_min_api: cmake_args += ["-Donnxruntime_NNAPI_MIN_API=" + str(args.nnapi_min_api)] @@ -975,8 +997,8 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home if not args.android_sdk_path: raise BuildError("android_sdk_path required to build for Android") cmake_args += [ - "-DCMAKE_TOOLCHAIN_FILE=" + os.path.join( - args.android_ndk_path, 'build', 'cmake', 'android.toolchain.cmake'), + "-DCMAKE_TOOLCHAIN_FILE=" + + os.path.join(args.android_ndk_path, "build", "cmake", "android.toolchain.cmake"), "-DANDROID_PLATFORM=android-" + str(args.android_api), "-DANDROID_ABI=" + str(args.android_abi), "-DANDROID_MIN_SDK=" + str(args.android_api), @@ -994,10 +1016,10 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home if args.use_gdk: cmake_args += [ - "-DCMAKE_TOOLCHAIN_FILE=" + os.path.join(source_dir, 'cmake', 'gdk_toolchain.cmake'), + "-DCMAKE_TOOLCHAIN_FILE=" + os.path.join(source_dir, "cmake", "gdk_toolchain.cmake"), "-DGDK_EDITION=" + args.gdk_edition, "-DGDK_PLATFORM=" + args.gdk_platform, - "-Donnxruntime_BUILD_UNIT_TESTS=OFF" # gtest doesn't build for GDK + "-Donnxruntime_BUILD_UNIT_TESTS=OFF", # gtest doesn't build for GDK ] if args.use_dml and not args.dml_path: raise BuildError("You must set dml_path when building with the GDK.") @@ -1005,17 +1027,17 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home if is_macOS() and not args.android: cmake_args += ["-DCMAKE_OSX_ARCHITECTURES=" + args.osx_arch] if args.use_xcode: - cmake_ver = LooseVersion( - subprocess.check_output(['cmake', '--version']).decode('utf-8').split()[2]) + cmake_ver = LooseVersion(subprocess.check_output(["cmake", "--version"]).decode("utf-8").split()[2]) xcode_ver = LooseVersion( - subprocess.check_output(['xcrun', 'xcodebuild', '-version']).decode('utf-8').split()[1]) + subprocess.check_output(["xcrun", "xcodebuild", "-version"]).decode("utf-8").split()[1] + ) # Requires Cmake 3.21.1+ for XCode 13+ # The legacy build system is not longer supported on XCode 13+ - if xcode_ver >= LooseVersion('13') and cmake_ver < LooseVersion('3.21.1'): + if xcode_ver >= LooseVersion("13") and cmake_ver < LooseVersion("3.21.1"): raise BuildError("CMake 3.21.1+ required to use XCode 13+") # Use legacy build system for old CMake [3.19, 3.21.1) which uses new build system by default # CMake 3.18- use the legacy build system by default - if cmake_ver >= LooseVersion('3.19.0') and cmake_ver < LooseVersion('3.21.1'): + if cmake_ver >= LooseVersion("3.19.0") and cmake_ver < LooseVersion("3.21.1"): cmake_args += ["-T", "buildsystem=1"] if args.apple_deploy_target: cmake_args += ["-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target] @@ -1035,19 +1057,15 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home args.apple_deploy_target, ] arg_names = [ - "--use_xcode " + - "", - "--ios_sysroot " + - "", - "--apple_deploy_target " + - "", + "--use_xcode " + "", + "--ios_sysroot " + "", + "--apple_deploy_target " + "", ] if not all(needed_args): raise BuildError( - "iOS build on MacOS canceled due to missing arguments: " + - ', '.join( - val for val, cond in zip(arg_names, needed_args) - if not cond)) + "iOS build on MacOS canceled due to missing arguments: " + + ", ".join(val for val, cond in zip(arg_names, needed_args) if not cond) + ) cmake_args += [ "-DCMAKE_SYSTEM_NAME=iOS", "-Donnxruntime_BUILD_SHARED_LIB=ON", @@ -1055,18 +1073,16 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-DCMAKE_OSX_DEPLOYMENT_TARGET=" + args.apple_deploy_target, # we do not need protoc binary for ios cross build "-Dprotobuf_BUILD_PROTOC_BINARIES=OFF", - "-DCMAKE_TOOLCHAIN_FILE=" + ( - args.ios_toolchain_file if args.ios_toolchain_file - else "../cmake/onnxruntime_ios.toolchain.cmake") + "-DCMAKE_TOOLCHAIN_FILE=" + + (args.ios_toolchain_file if args.ios_toolchain_file else "../cmake/onnxruntime_ios.toolchain.cmake"), ] if args.build_wasm: emsdk_dir = os.path.join(cmake_dir, "external", "emsdk") - emscripten_cmake_toolchain_file = os.path.join(emsdk_dir, "upstream", "emscripten", "cmake", "Modules", - "Platform", "Emscripten.cmake") - cmake_args += [ - "-DCMAKE_TOOLCHAIN_FILE=" + emscripten_cmake_toolchain_file - ] + emscripten_cmake_toolchain_file = os.path.join( + emsdk_dir, "upstream", "emscripten", "cmake", "Modules", "Platform", "Emscripten.cmake" + ) + cmake_args += ["-DCMAKE_TOOLCHAIN_FILE=" + emscripten_cmake_toolchain_file] if args.disable_wasm_exception_catching: # WebAssembly unittest requires exception catching to work. If this feature is disabled, we do not build # unit test. @@ -1079,10 +1095,10 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home # set -s MALLOC if args.wasm_malloc is not None: - add_default_definition(emscripten_settings, 'MALLOC', args.wasm_malloc) - add_default_definition(emscripten_settings, 'MALLOC', 'dlmalloc') + add_default_definition(emscripten_settings, "MALLOC", args.wasm_malloc) + add_default_definition(emscripten_settings, "MALLOC", "dlmalloc") - if (emscripten_settings): + if emscripten_settings: cmake_args += [f"-Donnxruntime_EMSCRIPTEN_SETTINGS={';'.join(emscripten_settings)}"] # Append onnxruntime-extensions cmake options @@ -1096,32 +1112,32 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home # use absolute path here because onnxruntime-extensions is outside onnxruntime onnxruntime_extensions_path = os.path.abspath(args.extensions_overridden_path) - cmake_args += [ - "-Donnxruntime_EXTENSIONS_PATH=" + onnxruntime_extensions_path] - print('[onnxruntime-extensions] onnxruntime_extensions_path: ', onnxruntime_extensions_path) + cmake_args += ["-Donnxruntime_EXTENSIONS_PATH=" + onnxruntime_extensions_path] + print("[onnxruntime-extensions] onnxruntime_extensions_path: ", onnxruntime_extensions_path) if is_reduced_ops_build(args): operators_config_file = os.path.abspath(args.include_ops_by_config) - cmake_tool_dir = os.path.join(onnxruntime_extensions_path, 'tools') + cmake_tool_dir = os.path.join(onnxruntime_extensions_path, "tools") # generate _selectedoplist.cmake by operators config file - run_subprocess([sys.executable, 'gen_selectedops.py', operators_config_file], cwd=cmake_tool_dir) + run_subprocess([sys.executable, "gen_selectedops.py", operators_config_file], cwd=cmake_tool_dir) if path_to_protoc_exe: - cmake_args += [ - "-DONNX_CUSTOM_PROTOC_EXECUTABLE=%s" % path_to_protoc_exe] + cmake_args += ["-DONNX_CUSTOM_PROTOC_EXECUTABLE=%s" % path_to_protoc_exe] if args.fuzz_testing: - if not (args.build_shared_lib and - is_windows() and - args.cmake_generator == 'Visual Studio 16 2019' and - args.use_full_protobuf): - raise BuildError( - "Fuzz test has only be tested with build shared libs option using MSVC on windows") + if not ( + args.build_shared_lib + and is_windows() + and args.cmake_generator == "Visual Studio 16 2019" + and args.use_full_protobuf + ): + raise BuildError("Fuzz test has only be tested with build shared libs option using MSVC on windows") cmake_args += [ "-Donnxruntime_BUILD_UNIT_TESTS=ON", "-Donnxruntime_FUZZ_TEST=ON", - "-Donnxruntime_USE_FULL_PROTOBUF=ON"] + "-Donnxruntime_USE_FULL_PROTOBUF=ON", + ] if args.gen_doc: add_default_definition(cmake_extra_defines, "onnxruntime_PYBIND_EXPORT_OPSCHEMA", "ON") @@ -1130,8 +1146,9 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home if args.build_eager_mode: import torch + cmake_args += ["-Donnxruntime_PREBUILT_PYTORCH_PATH=%s" % os.path.dirname(torch.__file__)] - cmake_args += ['-D_GLIBCXX_USE_CXX11_ABI=' + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] + cmake_args += ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] cmake_args += ["-D{}".format(define) for define in cmake_extra_defines] @@ -1141,18 +1158,17 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home # (e.g. 191101-2300.1.master) and source version in environment # variables. If present, use these values to define the # WinML/ORT DLL versions. - build_number = os.getenv('Build_BuildNumber') - source_version = os.getenv('Build_SourceVersion') + build_number = os.getenv("Build_BuildNumber") + source_version = os.getenv("Build_SourceVersion") if build_number and source_version: - build_matches = re.fullmatch( - r"(\d\d)(\d\d)(\d\d)(\d\d)\.(\d+)", build_number) + build_matches = re.fullmatch(r"(\d\d)(\d\d)(\d\d)(\d\d)\.(\d+)", build_number) if build_matches: YY = build_matches.group(2) MM = build_matches.group(3) DD = build_matches.group(4) # Get ORT major and minor number - with open(os.path.join(source_dir, 'VERSION_NUMBER')) as f: + with open(os.path.join(source_dir, "VERSION_NUMBER")) as f: first_line = f.readline() ort_version_matches = re.match(r"(\d+).(\d+)", first_line) if not ort_version_matches: @@ -1171,9 +1187,7 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home "-DVERSION_MINOR_PART={}".format(ort_minor), "-DVERSION_BUILD_PART={}".format(YY), "-DVERSION_PRIVATE_PART={}{}".format(MM, DD), - "-DVERSION_STRING={}.{}.{}.{}".format( - ort_major, ort_minor, build_number, - source_version[0:7]) + "-DVERSION_STRING={}.{}.{}.{}".format(ort_major, ort_minor, build_number, source_version[0:7]), ] for config in configs: @@ -1181,30 +1195,40 @@ def generate_build_tree(cmake_path, source_dir, build_dir, cuda_home, cudnn_home os.makedirs(config_build_dir, exist_ok=True) if args.use_nuphar or args.use_tvm: os.environ["PATH"] = ( - os.path.join(config_build_dir, "_deps", "tvm-build") + os.pathsep + - os.path.join(config_build_dir, "_deps", "tvm-src") + os.pathsep + - os.path.dirname(sys.executable) + os.pathsep + os.environ["PATH"]) + os.path.join(config_build_dir, "_deps", "tvm-build") + + os.pathsep + + os.path.join(config_build_dir, "_deps", "tvm-src") + + os.pathsep + + os.path.dirname(sys.executable) + + os.pathsep + + os.environ["PATH"] + ) run_subprocess( - cmake_args + [ - "-Donnxruntime_ENABLE_MEMLEAK_CHECKER=" + - ("ON" if config.lower() == 'debug' and not (args.use_nuphar or args.use_tvm) and not - args.use_openvino and not - args.use_gdk and not - args.enable_msvc_static_runtime and not - args.disable_memleak_checker - else "OFF"), "-DCMAKE_BUILD_TYPE={}".format(config)], - cwd=config_build_dir) + cmake_args + + [ + "-Donnxruntime_ENABLE_MEMLEAK_CHECKER=" + + ( + "ON" + if config.lower() == "debug" + and not (args.use_nuphar or args.use_tvm) + and not args.use_openvino + and not args.use_gdk + and not args.enable_msvc_static_runtime + and not args.disable_memleak_checker + else "OFF" + ), + "-DCMAKE_BUILD_TYPE={}".format(config), + ], + cwd=config_build_dir, + ) def clean_targets(cmake_path, build_dir, configs): for config in configs: log.info("Cleaning targets for %s configuration", config) build_dir2 = get_config_build_dir(build_dir, config) - cmd_args = [cmake_path, - "--build", build_dir2, - "--config", config, - "--target", "clean"] + cmd_args = [cmake_path, "--build", build_dir2, "--config", config, "--target", "clean"] run_subprocess(cmd_args) @@ -1213,21 +1237,19 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe for config in configs: log.info("Building targets for %s configuration", config) build_dir2 = get_config_build_dir(build_dir, config) - cmd_args = [cmake_path, - "--build", build_dir2, - "--config", config] + cmd_args = [cmake_path, "--build", build_dir2, "--config", config] if target: - cmd_args.extend(['--target', target]) + cmd_args.extend(["--target", target]) build_tool_args = [] if num_parallel_jobs != 1: - if is_windows() and args.cmake_generator != 'Ninja' and not args.build_wasm: + if is_windows() and args.cmake_generator != "Ninja" and not args.build_wasm: build_tool_args += [ "/maxcpucount:{}".format(num_parallel_jobs), # if nodeReuse is true, msbuild processes will stay around for a bit after the build completes "/nodeReuse:False", ] - elif (is_macOS() and args.use_xcode): + elif is_macOS() and args.use_xcode: # CMake will generate correct build tool args for Xcode cmd_args += ["--parallel", str(num_parallel_jobs)] else: @@ -1239,8 +1261,8 @@ def build_targets(args, cmake_path, build_dir, configs, num_parallel_jobs, targe env = {} if args.android: - env['ANDROID_SDK_ROOT'] = args.android_sdk_path - env['ANDROID_NDK_HOME'] = args.android_ndk_path + env["ANDROID_SDK_ROOT"] = args.android_sdk_path + env["ANDROID_NDK_HOME"] = args.android_ndk_path run_subprocess(cmd_args, env=env) @@ -1255,21 +1277,19 @@ def setup_cuda_vars(args): cudnn_home = "" if args.use_cuda: - cuda_home = args.cuda_home if args.cuda_home else os.getenv( - "CUDA_HOME") - cudnn_home = args.cudnn_home if args.cudnn_home else os.getenv( - "CUDNN_HOME") + cuda_home = args.cuda_home if args.cuda_home else os.getenv("CUDA_HOME") + cudnn_home = args.cudnn_home if args.cudnn_home else os.getenv("CUDNN_HOME") - cuda_home_valid = (cuda_home is not None and os.path.exists(cuda_home)) - cudnn_home_valid = (cudnn_home is not None and os.path.exists( - cudnn_home)) + cuda_home_valid = cuda_home is not None and os.path.exists(cuda_home) + cudnn_home_valid = cudnn_home is not None and os.path.exists(cudnn_home) if not cuda_home_valid or not cudnn_home_valid: raise BuildError( "cuda_home and cudnn_home paths must be specified and valid.", - "cuda_home='{}' valid={}. cudnn_home='{}' valid={}" - .format( - cuda_home, cuda_home_valid, cudnn_home, cudnn_home_valid)) + "cuda_home='{}' valid={}. cudnn_home='{}' valid={}".format( + cuda_home, cuda_home_valid, cudnn_home, cudnn_home_valid + ), + ) return cuda_home, cudnn_home @@ -1277,15 +1297,13 @@ def setup_cuda_vars(args): def setup_tensorrt_vars(args): tensorrt_home = "" if args.use_tensorrt: - tensorrt_home = (args.tensorrt_home if args.tensorrt_home - else os.getenv("TENSORRT_HOME")) - tensorrt_home_valid = (tensorrt_home is not None and - os.path.exists(tensorrt_home)) + tensorrt_home = args.tensorrt_home if args.tensorrt_home else os.getenv("TENSORRT_HOME") + tensorrt_home_valid = tensorrt_home is not None and os.path.exists(tensorrt_home) if not tensorrt_home_valid: raise BuildError( "tensorrt_home paths must be specified and valid.", - "tensorrt_home='{}' valid={}." - .format(tensorrt_home, tensorrt_home_valid)) + "tensorrt_home='{}' valid={}.".format(tensorrt_home, tensorrt_home_valid), + ) # Set maximum workspace size in byte for # TensorRT (1GB = 1073741824 bytes). @@ -1309,17 +1327,18 @@ def setup_migraphx_vars(args): migraphx_home = None - if (args.use_migraphx): + if args.use_migraphx: print("migraphx_home = {}".format(args.migraphx_home)) migraphx_home = args.migraphx_home or os.getenv("MIGRAPHX_HOME") or None - migraphx_home_not_valid = (migraphx_home and not os.path.exists(migraphx_home)) + migraphx_home_not_valid = migraphx_home and not os.path.exists(migraphx_home) - if (migraphx_home_not_valid): - raise BuildError("migraphx_home paths must be specified and valid.", - "migraphx_home='{}' valid={}." - .format(migraphx_home, migraphx_home_not_valid)) - return migraphx_home or '' + if migraphx_home_not_valid: + raise BuildError( + "migraphx_home paths must be specified and valid.", + "migraphx_home='{}' valid={}.".format(migraphx_home, migraphx_home_not_valid), + ) + return migraphx_home or "" def setup_dml_build(args, cmake_path, build_dir, configs): @@ -1330,17 +1349,22 @@ def setup_dml_build(args, cmake_path, build_dir, configs): for expected_file in ["bin/DirectML.dll", "lib/DirectML.lib", "include/DirectML.h"]: file_path = os.path.join(args.dml_path, expected_file) if not os.path.exists(file_path): - raise BuildError("dml_path is invalid.", - "dml_path='{}' expected_file='{}'." - .format(args.dml_path, file_path)) + raise BuildError( + "dml_path is invalid.", "dml_path='{}' expected_file='{}'.".format(args.dml_path, file_path) + ) else: for config in configs: # Run the RESTORE_PACKAGES target to perform the initial # NuGet setup. - cmd_args = [cmake_path, - "--build", get_config_build_dir(build_dir, config), - "--config", config, - "--target", "RESTORE_PACKAGES"] + cmd_args = [ + cmake_path, + "--build", + get_config_build_dir(build_dir, config), + "--config", + config, + "--target", + "RESTORE_PACKAGES", + ] run_subprocess(cmd_args) @@ -1348,34 +1372,35 @@ def setup_rocm_build(args, configs): rocm_home = None - if (args.use_rocm): + if args.use_rocm: print("rocm_home = {}".format(args.rocm_home)) rocm_home = args.rocm_home or None - rocm_home_not_valid = (rocm_home and not os.path.exists(rocm_home)) + rocm_home_not_valid = rocm_home and not os.path.exists(rocm_home) - if (rocm_home_not_valid): - raise BuildError("rocm_home paths must be specified and valid.", - "rocm_home='{}' valid={}." - .format(rocm_home, rocm_home_not_valid)) + if rocm_home_not_valid: + raise BuildError( + "rocm_home paths must be specified and valid.", + "rocm_home='{}' valid={}.".format(rocm_home, rocm_home_not_valid), + ) for config in configs: amd_hipify(get_config_build_dir(args.build_dir, config)) - return rocm_home or '' + return rocm_home or "" def run_android_tests(args, source_dir, build_dir, config, cwd): sdk_tool_paths = android.get_sdk_tool_paths(args.android_sdk_path) - device_dir = '/data/local/tmp' + device_dir = "/data/local/tmp" def adb_push(src, dest, **kwargs): - return run_subprocess([sdk_tool_paths.adb, 'push', src, dest], **kwargs) + return run_subprocess([sdk_tool_paths.adb, "push", src, dest], **kwargs) def adb_shell(*args, **kwargs): - return run_subprocess([sdk_tool_paths.adb, 'shell', *args], **kwargs) + return run_subprocess([sdk_tool_paths.adb, "shell", *args], **kwargs) def adb_install(*args, **kwargs): - return run_subprocess([sdk_tool_paths.adb, 'install', *args], **kwargs) + return run_subprocess([sdk_tool_paths.adb, "install", *args], **kwargs) def run_adb_shell(cmd): # GCOV_PREFIX_STRIP specifies the depth of the directory hierarchy to strip and @@ -1383,89 +1408,129 @@ def run_adb_shell(cmd): # for creating the runtime code coverage files. if args.code_coverage: adb_shell( - 'cd {0} && GCOV_PREFIX={0} GCOV_PREFIX_STRIP={1} {2}'.format( - device_dir, cwd.count(os.sep) + 1, cmd)) + "cd {0} && GCOV_PREFIX={0} GCOV_PREFIX_STRIP={1} {2}".format(device_dir, cwd.count(os.sep) + 1, cmd) + ) else: - adb_shell('cd {} && {}'.format(device_dir, cmd)) + adb_shell("cd {} && {}".format(device_dir, cmd)) - if args.android_abi == 'x86_64': + if args.android_abi == "x86_64": with contextlib.ExitStack() as context_stack: if args.android_run_emulator: avd_name = "ort_android" - system_image = "system-images;android-{};google_apis;{}".format( - args.android_api, args.android_abi) + system_image = "system-images;android-{};google_apis;{}".format(args.android_api, args.android_abi) android.create_virtual_device(sdk_tool_paths, system_image, avd_name) emulator_proc = context_stack.enter_context( android.start_emulator( sdk_tool_paths=sdk_tool_paths, avd_name=avd_name, - extra_args=[ - "-partition-size", "2047", - "-wipe-data"])) + extra_args=["-partition-size", "2047", "-wipe-data"], + ) + ) context_stack.callback(android.stop_emulator, emulator_proc) - adb_push('testdata', device_dir, cwd=cwd) + adb_push("testdata", device_dir, cwd=cwd) adb_push( - os.path.join(source_dir, 'cmake', 'external', 'onnx', 'onnx', 'backend', 'test'), - device_dir, cwd=cwd) - adb_push('onnxruntime_test_all', device_dir, cwd=cwd) - adb_shell('chmod +x {}/onnxruntime_test_all'.format(device_dir)) - adb_push('onnx_test_runner', device_dir, cwd=cwd) - adb_shell('chmod +x {}/onnx_test_runner'.format(device_dir)) - run_adb_shell('{0}/onnxruntime_test_all'.format(device_dir)) + os.path.join(source_dir, "cmake", "external", "onnx", "onnx", "backend", "test"), device_dir, cwd=cwd + ) + adb_push("onnxruntime_test_all", device_dir, cwd=cwd) + adb_shell("chmod +x {}/onnxruntime_test_all".format(device_dir)) + adb_push("onnx_test_runner", device_dir, cwd=cwd) + adb_shell("chmod +x {}/onnx_test_runner".format(device_dir)) + run_adb_shell("{0}/onnxruntime_test_all".format(device_dir)) if args.build_java: - gradle_executable = 'gradle' + gradle_executable = "gradle" # use the gradle wrapper if it exists, the gradlew should be setup under /java - gradlew_path = os.path.join(source_dir, 'java', - 'gradlew.bat' if is_windows() else 'gradlew') + gradlew_path = os.path.join(source_dir, "java", "gradlew.bat" if is_windows() else "gradlew") if os.path.exists(gradlew_path): gradle_executable = gradlew_path android_test_path = os.path.join(cwd, "java", "androidtest", "android") - run_subprocess([gradle_executable, '--no-daemon', - '-DminSdkVer={}'.format(args.android_api), - 'clean', 'connectedDebugAndroidTest'], - cwd=android_test_path) + run_subprocess( + [ + gradle_executable, + "--no-daemon", + "-DminSdkVer={}".format(args.android_api), + "clean", + "connectedDebugAndroidTest", + ], + cwd=android_test_path, + ) if args.use_nnapi: - adb_shell('cd {0} && {0}/onnx_test_runner -e nnapi {0}/test'.format(device_dir)) + adb_shell("cd {0} && {0}/onnx_test_runner -e nnapi {0}/test".format(device_dir)) else: - adb_shell('cd {0} && {0}/onnx_test_runner {0}/test'.format(device_dir)) + adb_shell("cd {0} && {0}/onnx_test_runner {0}/test".format(device_dir)) # run shared_lib_test if necessary if args.build_shared_lib: - adb_push('libonnxruntime.so', device_dir, cwd=cwd) - adb_push('onnxruntime_shared_lib_test', device_dir, cwd=cwd) - adb_shell('chmod +x {}/onnxruntime_shared_lib_test'.format(device_dir)) - run_adb_shell( - 'LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{0} {0}/onnxruntime_shared_lib_test'.format( - device_dir)) + adb_push("libonnxruntime.so", device_dir, cwd=cwd) + adb_push("onnxruntime_shared_lib_test", device_dir, cwd=cwd) + adb_shell("chmod +x {}/onnxruntime_shared_lib_test".format(device_dir)) + run_adb_shell("LD_LIBRARY_PATH=$LD_LIBRARY_PATH:{0} {0}/onnxruntime_shared_lib_test".format(device_dir)) def run_ios_tests(args, source_dir, config, cwd): - run_subprocess(["xcodebuild", "test-without-building", "-project", "./onnxruntime.xcodeproj", - "-configuration", config, - "-scheme", "onnxruntime_test_all_xc", "-destination", - "platform=iOS Simulator,OS=latest,name=iPhone SE (2nd generation)"], cwd=cwd) + run_subprocess( + [ + "xcodebuild", + "test-without-building", + "-project", + "./onnxruntime.xcodeproj", + "-configuration", + config, + "-scheme", + "onnxruntime_test_all_xc", + "-destination", + "platform=iOS Simulator,OS=latest,name=iPhone SE (2nd generation)", + ], + cwd=cwd, + ) - run_subprocess(["xcodebuild", "test-without-building", "-project", "./onnxruntime.xcodeproj", - "-configuration", config, - "-scheme", "onnxruntime_shared_lib_test_xc", "-destination", - "platform=iOS Simulator,OS=latest,name=iPhone SE (2nd generation)"], cwd=cwd) + run_subprocess( + [ + "xcodebuild", + "test-without-building", + "-project", + "./onnxruntime.xcodeproj", + "-configuration", + config, + "-scheme", + "onnxruntime_shared_lib_test_xc", + "-destination", + "platform=iOS Simulator,OS=latest,name=iPhone SE (2nd generation)", + ], + cwd=cwd, + ) if args.build_apple_framework: - package_test_py = os.path.join(source_dir, 'tools', 'ci_build', 'github', 'apple', 'test_ios_packages.py') - framework_info_file = os.path.join(cwd, 'framework_info.json') - dynamic_framework_dir = os.path.join(cwd, config + '-' + args.ios_sysroot) - static_framework_dir = os.path.join(cwd, config + '-' + args.ios_sysroot, 'static_framework') + package_test_py = os.path.join(source_dir, "tools", "ci_build", "github", "apple", "test_ios_packages.py") + framework_info_file = os.path.join(cwd, "framework_info.json") + dynamic_framework_dir = os.path.join(cwd, config + "-" + args.ios_sysroot) + static_framework_dir = os.path.join(cwd, config + "-" + args.ios_sysroot, "static_framework") # test dynamic framework - run_subprocess([sys.executable, package_test_py, - '--c_framework_dir', dynamic_framework_dir, - '--framework_info_file', framework_info_file], cwd=cwd) + run_subprocess( + [ + sys.executable, + package_test_py, + "--c_framework_dir", + dynamic_framework_dir, + "--framework_info_file", + framework_info_file, + ], + cwd=cwd, + ) # test static framework - run_subprocess([sys.executable, package_test_py, - '--c_framework_dir', static_framework_dir, - '--framework_info_file', framework_info_file], cwd=cwd) + run_subprocess( + [ + sys.executable, + package_test_py, + "--c_framework_dir", + static_framework_dir, + "--framework_info_file", + framework_info_file, + ], + cwd=cwd, + ) def run_orttraining_test_orttrainer_frontend_separately(cwd): @@ -1475,9 +1540,9 @@ def __init__(self): def pytest_collection_modifyitems(self, items): for item in items: - print('item.name: ', item.name) + print("item.name: ", item.name) test_name = item.name - start = test_name.find('[') + start = test_name.find("[") if start > 0: test_name = test_name[:start] self.collected.add(test_name) @@ -1486,12 +1551,12 @@ def pytest_collection_modifyitems(self, items): plugin = TestNameCollecterPlugin() test_script_filename = os.path.join(cwd, "orttraining_test_orttrainer_frontend.py") - pytest.main(['--collect-only', test_script_filename], plugins=[plugin]) + pytest.main(["--collect-only", test_script_filename], plugins=[plugin]) for test_name in plugin.collected: - run_subprocess([ - sys.executable, '-m', 'pytest', - 'orttraining_test_orttrainer_frontend.py', '-v', '-k', test_name], cwd=cwd) + run_subprocess( + [sys.executable, "-m", "pytest", "orttraining_test_orttrainer_frontend.py", "-v", "-k", test_name], cwd=cwd + ) def run_training_python_frontend_tests(cwd): @@ -1502,31 +1567,46 @@ def run_training_python_frontend_tests(cwd): # HTTP Error 404: Not Found # run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer.py'], cwd=cwd) - run_subprocess([sys.executable, 'onnxruntime_test_training_unit_tests.py'], cwd=cwd) - run_subprocess([ - sys.executable, 'orttraining_test_transformers.py', - 'BertModelTest.test_for_pretraining_full_precision_list_input'], cwd=cwd) - run_subprocess([ - sys.executable, 'orttraining_test_transformers.py', - 'BertModelTest.test_for_pretraining_full_precision_dict_input'], cwd=cwd) - run_subprocess([ - sys.executable, 'orttraining_test_transformers.py', - 'BertModelTest.test_for_pretraining_full_precision_list_and_dict_input'], cwd=cwd) + run_subprocess([sys.executable, "onnxruntime_test_training_unit_tests.py"], cwd=cwd) + run_subprocess( + [ + sys.executable, + "orttraining_test_transformers.py", + "BertModelTest.test_for_pretraining_full_precision_list_input", + ], + cwd=cwd, + ) + run_subprocess( + [ + sys.executable, + "orttraining_test_transformers.py", + "BertModelTest.test_for_pretraining_full_precision_dict_input", + ], + cwd=cwd, + ) + run_subprocess( + [ + sys.executable, + "orttraining_test_transformers.py", + "BertModelTest.test_for_pretraining_full_precision_list_and_dict_input", + ], + cwd=cwd, + ) # TODO: use run_orttraining_test_orttrainer_frontend_separately to work around a sporadic segfault. # shall revert to run_subprocess call once the segfault issue is resolved. run_orttraining_test_orttrainer_frontend_separately(cwd) # run_subprocess([sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_frontend.py'], cwd=cwd) - run_subprocess([sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_bert_toy_onnx.py'], cwd=cwd) + run_subprocess([sys.executable, "-m", "pytest", "-sv", "orttraining_test_orttrainer_bert_toy_onnx.py"], cwd=cwd) - run_subprocess([sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_checkpoint_storage.py'], cwd=cwd) + run_subprocess([sys.executable, "-m", "pytest", "-sv", "orttraining_test_checkpoint_storage.py"], cwd=cwd) - run_subprocess([ - sys.executable, '-m', 'pytest', '-sv', 'orttraining_test_orttrainer_checkpoint_functions.py'], cwd=cwd) + run_subprocess( + [sys.executable, "-m", "pytest", "-sv", "orttraining_test_orttrainer_checkpoint_functions.py"], cwd=cwd + ) # Not technically training related, but it needs torch to be installed. - run_subprocess([ - sys.executable, '-m', 'pytest', '-sv', 'test_pytorch_export_contrib_ops.py'], cwd=cwd) + run_subprocess([sys.executable, "-m", "pytest", "-sv", "test_pytorch_export_contrib_ops.py"], cwd=cwd) def run_training_python_frontend_e2e_tests(cwd): @@ -1534,23 +1614,43 @@ def run_training_python_frontend_e2e_tests(cwd): log.info("Running python frontend e2e tests.") run_subprocess( - [sys.executable, 'orttraining_run_frontend_batch_size_test.py', '-v'], - cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'}) + [sys.executable, "orttraining_run_frontend_batch_size_test.py", "-v"], + cwd=cwd, + env={"CUDA_VISIBLE_DEVICES": "0"}, + ) import torch + ngpus = torch.cuda.device_count() if ngpus > 1: - bert_pretrain_script = 'orttraining_run_bert_pretrain.py' + bert_pretrain_script = "orttraining_run_bert_pretrain.py" # TODO: this test will be replaced with convergence test ported from backend - log.debug('RUN: mpirun -n {} ''-x' 'NCCL_DEBUG=INFO'' {} {} {}'.format( - ngpus, sys.executable, bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_convergence')) - run_subprocess([ - 'mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable, - bert_pretrain_script, 'ORTBertPretrainTest.test_pretrain_convergence'], cwd=cwd) - - log.debug('RUN: mpirun -n {} {} orttraining_run_glue.py'.format(ngpus, sys.executable)) - run_subprocess([ - 'mpirun', '-n', str(ngpus), '-x', 'NCCL_DEBUG=INFO', sys.executable, 'orttraining_run_glue.py'], cwd=cwd) + log.debug( + "RUN: mpirun -n {} " + "-x" + "NCCL_DEBUG=INFO" + " {} {} {}".format( + ngpus, sys.executable, bert_pretrain_script, "ORTBertPretrainTest.test_pretrain_convergence" + ) + ) + run_subprocess( + [ + "mpirun", + "-n", + str(ngpus), + "-x", + "NCCL_DEBUG=INFO", + sys.executable, + bert_pretrain_script, + "ORTBertPretrainTest.test_pretrain_convergence", + ], + cwd=cwd, + ) + + log.debug("RUN: mpirun -n {} {} orttraining_run_glue.py".format(ngpus, sys.executable)) + run_subprocess( + ["mpirun", "-n", str(ngpus), "-x", "NCCL_DEBUG=INFO", sys.executable, "orttraining_run_glue.py"], cwd=cwd + ) # with orttraining_run_glue.py. # 1. we like to force to use single GPU (with CUDA_VISIBLE_DEVICES) @@ -1558,30 +1658,41 @@ def run_training_python_frontend_e2e_tests(cwd): # 2. need to run test separately (not to mix between fp16 # and full precision runs. this need to be investigated). run_subprocess( - [sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_bert_with_mrpc', '-v'], - cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'}) + [sys.executable, "orttraining_run_glue.py", "ORTGlueTest.test_bert_with_mrpc", "-v"], + cwd=cwd, + env={"CUDA_VISIBLE_DEVICES": "0"}, + ) run_subprocess( - [sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_bert_fp16_with_mrpc', '-v'], - cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'}) + [sys.executable, "orttraining_run_glue.py", "ORTGlueTest.test_bert_fp16_with_mrpc", "-v"], + cwd=cwd, + env={"CUDA_VISIBLE_DEVICES": "0"}, + ) run_subprocess( - [sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_roberta_with_mrpc', '-v'], - cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'}) + [sys.executable, "orttraining_run_glue.py", "ORTGlueTest.test_roberta_with_mrpc", "-v"], + cwd=cwd, + env={"CUDA_VISIBLE_DEVICES": "0"}, + ) run_subprocess( - [sys.executable, 'orttraining_run_glue.py', 'ORTGlueTest.test_roberta_fp16_with_mrpc', '-v'], - cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'}) + [sys.executable, "orttraining_run_glue.py", "ORTGlueTest.test_roberta_fp16_with_mrpc", "-v"], + cwd=cwd, + env={"CUDA_VISIBLE_DEVICES": "0"}, + ) run_subprocess( - [sys.executable, 'orttraining_run_multiple_choice.py', 'ORTMultipleChoiceTest.test_bert_fp16_with_swag', '-v'], - cwd=cwd, env={'CUDA_VISIBLE_DEVICES': '0'}) + [sys.executable, "orttraining_run_multiple_choice.py", "ORTMultipleChoiceTest.test_bert_fp16_with_swag", "-v"], + cwd=cwd, + env={"CUDA_VISIBLE_DEVICES": "0"}, + ) - run_subprocess([sys.executable, 'onnxruntime_test_ort_trainer_with_mixed_precision.py'], cwd=cwd) + run_subprocess([sys.executable, "onnxruntime_test_ort_trainer_with_mixed_precision.py"], cwd=cwd) - run_subprocess([ - sys.executable, 'orttraining_test_transformers.py', - 'BertModelTest.test_for_pretraining_mixed_precision'], cwd=cwd) + run_subprocess( + [sys.executable, "orttraining_test_transformers.py", "BertModelTest.test_for_pretraining_mixed_precision"], + cwd=cwd, + ) def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): @@ -1600,13 +1711,14 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if args.use_nuphar: dll_path_list.append(os.path.join(build_dir, "_deps", "tvm-build")) if args.use_tensorrt: - dll_path_list.append(os.path.join(args.tensorrt_home, 'lib')) + dll_path_list.append(os.path.join(args.tensorrt_home, "lib")) # Adding the torch lib path for loading DLLs for onnxruntime in eager mode # This works for Python 3.7 and below, and doesn't work for Python 3.8+ # User will need to import torch before onnxruntime and it will work for all versions if args.build_eager_mode and is_windows(): import torch - dll_path_list.append(os.path.join(os.path.dirname(torch.__file__), 'lib')) + + dll_path_list.append(os.path.join(os.path.dirname(torch.__file__), "lib")) dll_path = None if len(dll_path_list) > 0: @@ -1615,32 +1727,44 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if not ctest_path: if is_windows(): # Get the "Google Test Adapter" for vstest. - if not os.path.exists(os.path.join(cwd, - 'googletestadapter.0.17.1')): + if not os.path.exists(os.path.join(cwd, "googletestadapter.0.17.1")): run_subprocess( - ['nuget.exe', 'restore', - os.path.join(source_dir, 'packages.config'), - '-ConfigFile', os.path.join(source_dir, 'NuGet.config'), - '-PackagesDirectory', cwd]) + [ + "nuget.exe", + "restore", + os.path.join(source_dir, "packages.config"), + "-ConfigFile", + os.path.join(source_dir, "NuGet.config"), + "-PackagesDirectory", + cwd, + ] + ) cwd2 = os.path.join(cwd, config) - executables = ['onnxruntime_test_all.exe', 'onnxruntime_mlas_test.exe'] + executables = ["onnxruntime_test_all.exe", "onnxruntime_mlas_test.exe"] if args.build_shared_lib: - executables.append('onnxruntime_shared_lib_test.exe') - executables.append('onnxruntime_global_thread_pools_test.exe') - executables.append('onnxruntime_api_tests_without_env.exe') + executables.append("onnxruntime_shared_lib_test.exe") + executables.append("onnxruntime_global_thread_pools_test.exe") + executables.append("onnxruntime_api_tests_without_env.exe") run_subprocess( - ['vstest.console.exe', '--parallel', - '--TestAdapterPath:..\\googletestadapter.0.17.1\\build\\_common', # noqa - '/Logger:trx', '/Enablecodecoverage', '/Platform:x64', - "/Settings:%s" % os.path.join( - source_dir, 'cmake\\codeconv.runsettings')] + executables, - cwd=cwd2, dll_path=dll_path) + [ + "vstest.console.exe", + "--parallel", + "--TestAdapterPath:..\\googletestadapter.0.17.1\\build\\_common", # noqa + "/Logger:trx", + "/Enablecodecoverage", + "/Platform:x64", + "/Settings:%s" % os.path.join(source_dir, "cmake\\codeconv.runsettings"), + ] + + executables, + cwd=cwd2, + dll_path=dll_path, + ) else: - executables = ['onnxruntime_test_all', 'onnxruntime_mlas_test'] + executables = ["onnxruntime_test_all", "onnxruntime_mlas_test"] if args.build_shared_lib: - executables.append('onnxruntime_shared_lib_test') - executables.append('onnxruntime_global_thread_pools_test') - executables.append('onnxruntime_api_tests_without_env') + executables.append("onnxruntime_shared_lib_test") + executables.append("onnxruntime_global_thread_pools_test") + executables.append("onnxruntime_api_tests_without_env") for exe in executables: run_subprocess([os.path.join(cwd, exe)], cwd=cwd, dll_path=dll_path) @@ -1666,29 +1790,30 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if is_windows(): cwd = os.path.join(cwd, config) - run_subprocess([sys.executable, 'onnxruntime_test_python.py'], - cwd=cwd, dll_path=dll_path, python_path=python_path) + run_subprocess( + [sys.executable, "onnxruntime_test_python.py"], cwd=cwd, dll_path=dll_path, python_path=python_path + ) if not args.disable_contrib_ops: - run_subprocess([sys.executable, 'onnxruntime_test_python_sparse_matmul.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_sparse_matmul.py"], cwd=cwd, dll_path=dll_path) if args.enable_symbolic_shape_infer_tests: - run_subprocess([sys.executable, 'onnxruntime_test_python_symbolic_shape_infer.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess( + [sys.executable, "onnxruntime_test_python_symbolic_shape_infer.py"], cwd=cwd, dll_path=dll_path + ) # For CUDA enabled builds test IOBinding feature if args.use_cuda: # We need to have Torch installed to test the IOBinding feature # which currently uses Torch's allocator to allocate GPU memory for testing log.info("Testing IOBinding feature") - run_subprocess([sys.executable, 'onnxruntime_test_python_iobinding.py'], cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_iobinding.py"], cwd=cwd, dll_path=dll_path) log.info("Testing CUDA Graph feature") - run_subprocess([sys.executable, 'onnxruntime_test_python_cudagraph.py'], cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_cudagraph.py"], cwd=cwd, dll_path=dll_path) if not args.disable_ml_ops: - run_subprocess([sys.executable, 'onnxruntime_test_python_mlops.py'], cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_mlops.py"], cwd=cwd, dll_path=dll_path) # The following test has multiple failures on Windows if args.enable_training and args.use_cuda and not is_windows(): @@ -1697,20 +1822,29 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): if args.build_eager_mode: # run eager mode test - args_list = [sys.executable, os.path.join(cwd, 'eager_test')] + args_list = [sys.executable, os.path.join(cwd, "eager_test")] run_subprocess(args_list, cwd=cwd, dll_path=dll_path, python_path=cwd) if args.test_external_transformer_example: - run_subprocess([sys.executable, - os.path.join(source_dir, - 'orttraining', - 'orttraining', - 'test', - 'external_transformer', - 'test', - 'external_transformers_test.py')], cwd=cwd, dll_path=dll_path) + run_subprocess( + [ + sys.executable, + os.path.join( + source_dir, + "orttraining", + "orttraining", + "test", + "external_transformer", + "test", + "external_transformers_test.py", + ), + ], + cwd=cwd, + dll_path=dll_path, + ) try: import onnx # noqa + onnx_test = True except ImportError as error: log.exception(error) @@ -1718,105 +1852,130 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): onnx_test = False if onnx_test: - run_subprocess([sys.executable, 'onnxruntime_test_python_backend.py'], cwd=cwd, dll_path=dll_path, - python_path=python_path) + run_subprocess( + [sys.executable, "onnxruntime_test_python_backend.py"], + cwd=cwd, + dll_path=dll_path, + python_path=python_path, + ) if not args.disable_contrib_ops: - run_subprocess([sys.executable, '-m', 'unittest', 'discover', '-s', 'quantization'], - cwd=cwd, dll_path=dll_path) + run_subprocess( + [sys.executable, "-m", "unittest", "discover", "-s", "quantization"], cwd=cwd, dll_path=dll_path + ) if args.enable_transformers_tool_test: - import numpy import google.protobuf + import numpy + numpy_init_version = numpy.__version__ pb_init_version = google.protobuf.__version__ - run_subprocess([sys.executable, '-m', 'pip', 'install', '-r', 'requirements.txt'], - cwd=SCRIPT_DIR) - run_subprocess([sys.executable, '-m', 'pytest', 'transformers'], cwd=cwd) + run_subprocess( + [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR + ) + run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it - run_subprocess([sys.executable, '-m', 'pip', 'install', 'numpy==' + numpy_init_version]) - run_subprocess([sys.executable, '-m', 'pip', 'install', 'protobuf==' + pb_init_version]) + run_subprocess([sys.executable, "-m", "pip", "install", "numpy==" + numpy_init_version]) + run_subprocess([sys.executable, "-m", "pip", "install", "protobuf==" + pb_init_version]) if not args.disable_ml_ops: - run_subprocess([sys.executable, 'onnxruntime_test_python_backend_mlops.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess( + [sys.executable, "onnxruntime_test_python_backend_mlops.py"], cwd=cwd, dll_path=dll_path + ) - run_subprocess([sys.executable, - os.path.join(source_dir, 'onnxruntime', 'test', 'onnx', 'gen_test_models.py'), - '--output_dir', 'test_models'], cwd=cwd) + run_subprocess( + [ + sys.executable, + os.path.join(source_dir, "onnxruntime", "test", "onnx", "gen_test_models.py"), + "--output_dir", + "test_models", + ], + cwd=cwd, + ) if not args.skip_onnx_tests: - run_subprocess([os.path.join(cwd, 'onnx_test_runner'), 'test_models'], cwd=cwd) - if config != 'Debug': - run_subprocess([sys.executable, 'onnx_backend_test_series.py'], cwd=cwd, dll_path=dll_path) + run_subprocess([os.path.join(cwd, "onnx_test_runner"), "test_models"], cwd=cwd) + if config != "Debug": + run_subprocess([sys.executable, "onnx_backend_test_series.py"], cwd=cwd, dll_path=dll_path) if not args.skip_keras_test: try: - import onnxmltools # noqa import keras # noqa + import onnxmltools # noqa + onnxml_test = True except ImportError: - log.warning( - "onnxmltools and keras are not installed. " - "The keras tests will be skipped.") + log.warning("onnxmltools and keras are not installed. " "The keras tests will be skipped.") onnxml_test = False if onnxml_test: - run_subprocess( - [sys.executable, 'onnxruntime_test_python_keras.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_keras.py"], cwd=cwd, dll_path=dll_path) def nuphar_run_python_tests(build_dir, configs): for config in configs: - if config == 'Debug': + if config == "Debug": continue cwd = get_config_build_dir(build_dir, config) if is_windows(): cwd = os.path.join(cwd, config) dll_path = os.path.join(build_dir, config, "_deps", "tvm-build", config) - run_subprocess( - [sys.executable, 'onnxruntime_test_python_nuphar.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_nuphar.py"], cwd=cwd, dll_path=dll_path) def tvm_run_python_tests(build_dir, configs): for config in configs: - if config == 'Debug': + if config == "Debug": continue cwd = get_config_build_dir(build_dir, config) if is_windows(): cwd = os.path.join(cwd, config) dll_path = os.path.join(build_dir, config, "_deps", "tvm-build", config) - run_subprocess( - [sys.executable, 'onnxruntime_test_python_tvm.py'], - cwd=cwd, dll_path=dll_path) + run_subprocess([sys.executable, "onnxruntime_test_python_tvm.py"], cwd=cwd, dll_path=dll_path) def run_nodejs_tests(nodejs_binding_dir): - args = ['npm', 'test', '--', '--timeout=30000'] + args = ["npm", "test", "--", "--timeout=30000"] if is_windows(): - args = ['cmd', '/c'] + args + args = ["cmd", "/c"] + args run_subprocess(args, cwd=nodejs_binding_dir) def build_python_wheel( - source_dir, build_dir, configs, use_cuda, cuda_version, use_rocm, rocm_version, use_dnnl, - use_tensorrt, use_openvino, use_nuphar, use_tvm, use_vitisai, use_acl, use_armnn, use_dml, - wheel_name_suffix, enable_training, nightly_build=False, default_training_package_device=False, - use_ninja=False, build_eager_mode=False): + source_dir, + build_dir, + configs, + use_cuda, + cuda_version, + use_rocm, + rocm_version, + use_dnnl, + use_tensorrt, + use_openvino, + use_nuphar, + use_tvm, + use_vitisai, + use_acl, + use_armnn, + use_dml, + wheel_name_suffix, + enable_training, + nightly_build=False, + default_training_package_device=False, + use_ninja=False, + build_eager_mode=False, +): for config in configs: cwd = get_config_build_dir(build_dir, config) if is_windows() and not use_ninja: cwd = os.path.join(cwd, config) - args = [sys.executable, os.path.join(source_dir, 'setup.py'), - 'bdist_wheel'] + args = [sys.executable, os.path.join(source_dir, "setup.py"), "bdist_wheel"] # Any combination of the following arguments can be applied if nightly_build: - args.append('--nightly_build') + args.append("--nightly_build") if default_training_package_device: - args.append('--default_training_package_device') + args.append("--default_training_package_device") if wheel_name_suffix: - args.append('--wheel_name_suffix={}'.format(wheel_name_suffix)) + args.append("--wheel_name_suffix={}".format(wheel_name_suffix)) if enable_training: args.append("--enable_training") if build_eager_mode: @@ -1825,48 +1984,49 @@ def build_python_wheel( # The following arguments are mutually exclusive if use_cuda: # The following line assumes no other EP is enabled - args.append('--wheel_name_suffix=gpu') + args.append("--wheel_name_suffix=gpu") if cuda_version: - args.append('--cuda_version={}'.format(cuda_version)) + args.append("--cuda_version={}".format(cuda_version)) elif use_rocm: - args.append('--use_rocm') + args.append("--use_rocm") if rocm_version: - args.append('--rocm_version={}'.format(rocm_version)) + args.append("--rocm_version={}".format(rocm_version)) elif use_openvino: - args.append('--use_openvino') + args.append("--use_openvino") elif use_dnnl: - args.append('--use_dnnl') + args.append("--use_dnnl") elif use_nuphar: - args.append('--use_nuphar') + args.append("--use_nuphar") elif use_tvm: - args.append('--use_tvm') + args.append("--use_tvm") elif use_vitisai: - args.append('--use_vitisai') + args.append("--use_vitisai") elif use_acl: - args.append('--use_acl') + args.append("--use_acl") elif use_armnn: - args.append('--use_armnn') + args.append("--use_armnn") elif use_dml: - args.append('--wheel_name_suffix=directml') + args.append("--wheel_name_suffix=directml") run_subprocess(args, cwd=cwd) def derive_linux_build_property(): if is_windows(): - return "/p:IsLinuxBuild=\"false\"" + return '/p:IsLinuxBuild="false"' else: - return "/p:IsLinuxBuild=\"true\"" + return '/p:IsLinuxBuild="true"' -def build_nuget_package(source_dir, build_dir, configs, use_cuda, use_openvino, use_tensorrt, use_dnnl, use_nuphar, - use_tvm, use_winml): +def build_nuget_package( + source_dir, build_dir, configs, use_cuda, use_openvino, use_tensorrt, use_dnnl, use_nuphar, use_tvm, use_winml +): if not (is_windows() or is_linux()): raise BuildError( - 'Currently csharp builds and nuget package creation is only supportted ' - 'on Windows and Linux platforms.') + "Currently csharp builds and nuget package creation is only supportted " "on Windows and Linux platforms." + ) - csharp_build_dir = os.path.join(source_dir, 'csharp') + csharp_build_dir = os.path.join(source_dir, "csharp") is_linux_build = derive_linux_build_property() # in most cases we don't want/need to include the Xamarin mobile targets, as doing so means the Xamarin @@ -1876,33 +2036,33 @@ def build_nuget_package(source_dir, build_dir, configs, use_cuda, use_openvino, # derive package name and execution provider based on the build args target_name = "/t:CreatePackage" - execution_provider = "/p:ExecutionProvider=\"None\"" - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime\"" + execution_provider = '/p:ExecutionProvider="None"' + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime"' if use_winml: - package_name = "/p:OrtPackageId=\"Microsoft.AI.MachineLearning\"" + package_name = '/p:OrtPackageId="Microsoft.AI.MachineLearning"' target_name = "/t:CreateWindowsAIPackage" elif use_openvino: - execution_provider = "/p:ExecutionProvider=\"openvino\"" - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime.OpenVino\"" + execution_provider = '/p:ExecutionProvider="openvino"' + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.OpenVino"' elif use_tensorrt: - execution_provider = "/p:ExecutionProvider=\"tensorrt\"" - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime.TensorRT\"" + execution_provider = '/p:ExecutionProvider="tensorrt"' + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.TensorRT"' elif use_dnnl: - execution_provider = "/p:ExecutionProvider=\"dnnl\"" - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime.DNNL\"" + execution_provider = '/p:ExecutionProvider="dnnl"' + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.DNNL"' elif use_cuda: - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime.Gpu\"" + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Gpu"' elif use_nuphar: - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime.Nuphar\"" + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Nuphar"' elif use_tvm: - package_name = "/p:OrtPackageId=\"Microsoft.ML.OnnxRuntime.Tvm\"" + package_name = '/p:OrtPackageId="Microsoft.ML.OnnxRuntime.Tvm"' else: # use the solution file that includes Xamarin mobile targets sln = "OnnxRuntime.CSharp.sln" # set build directory based on build_dir arg native_dir = os.path.normpath(os.path.join(source_dir, build_dir)) - ort_build_dir = "/p:OnnxRuntimeBuildDirectory=\"" + native_dir + "\"" + ort_build_dir = '/p:OnnxRuntimeBuildDirectory="' + native_dir + '"' # dotnet restore cmd_args = ["dotnet", "restore", sln, "--configfile", "Nuget.CSharp.config"] @@ -1915,7 +2075,7 @@ def build_nuget_package(source_dir, build_dir, configs, use_cuda, use_openvino, cmd_args = ["make", "install", "DESTDIR=.//nuget-staging"] run_subprocess(cmd_args, cwd=native_build_dir) - configuration = "/p:Configuration=\"" + config + "\"" + configuration = '/p:Configuration="' + config + '"' if not use_winml: cmd_args = ["dotnet", "msbuild", sln, configuration, package_name, is_linux_build, ort_build_dir] @@ -1924,8 +2084,15 @@ def build_nuget_package(source_dir, build_dir, configs, use_cuda, use_openvino, winml_interop_dir = os.path.join(source_dir, "csharp", "src", "Microsoft.AI.MachineLearning.Interop") winml_interop_project = os.path.join(winml_interop_dir, "Microsoft.AI.MachineLearning.Interop.csproj") winml_interop_project = os.path.normpath(winml_interop_project) - cmd_args = ["dotnet", "msbuild", winml_interop_project, configuration, "/p:Platform=\"Any CPU\"", - ort_build_dir, "-restore"] + cmd_args = [ + "dotnet", + "msbuild", + winml_interop_project, + configuration, + '/p:Platform="Any CPU"', + ort_build_dir, + "-restore", + ] run_subprocess(cmd_args, cwd=csharp_build_dir) if is_windows(): @@ -1939,11 +2106,20 @@ def build_nuget_package(source_dir, build_dir, configs, use_cuda, use_openvino, # user needs to make sure nuget is installed and can be found nuget_exe = "nuget" - nuget_exe_arg = "/p:NugetExe=\"" + nuget_exe + "\"" + nuget_exe_arg = '/p:NugetExe="' + nuget_exe + '"' cmd_args = [ - "dotnet", "msbuild", "OnnxRuntime.CSharp.proj", target_name, - package_name, configuration, execution_provider, is_linux_build, ort_build_dir, nuget_exe_arg] + "dotnet", + "msbuild", + "OnnxRuntime.CSharp.proj", + target_name, + package_name, + configuration, + execution_provider, + is_linux_build, + ort_build_dir, + nuget_exe_arg, + ] run_subprocess(cmd_args, cwd=csharp_build_dir) @@ -1951,7 +2127,7 @@ def run_csharp_tests(source_dir, build_dir, use_cuda, use_openvino, use_tensorrt # Currently only running tests on windows. if not is_windows(): return - csharp_source_dir = os.path.join(source_dir, 'csharp') + csharp_source_dir = os.path.join(source_dir, "csharp") is_linux_build = derive_linux_build_property() # define macros based on build args @@ -1967,17 +2143,24 @@ def run_csharp_tests(source_dir, build_dir, use_cuda, use_openvino, use_tensorrt define_constants = "" if macros != "": - define_constants = "/p:DefineConstants=\"" + macros + "\"" + define_constants = '/p:DefineConstants="' + macros + '"' # set build directory based on build_dir arg native_build_dir = os.path.normpath(os.path.join(source_dir, build_dir)) - ort_build_dir = "/p:OnnxRuntimeBuildDirectory=\"" + native_build_dir + "\"" + ort_build_dir = '/p:OnnxRuntimeBuildDirectory="' + native_build_dir + '"' # Skip pretrained models test. Only run unit tests as part of the build # add "--verbosity", "detailed" to this command if required - cmd_args = ["dotnet", "test", "test\\Microsoft.ML.OnnxRuntime.Tests\\Microsoft.ML.OnnxRuntime.Tests.csproj", - "--filter", "FullyQualifiedName!=Microsoft.ML.OnnxRuntime.Tests.InferenceTest.TestPreTrainedModels", - is_linux_build, define_constants, ort_build_dir] + cmd_args = [ + "dotnet", + "test", + "test\\Microsoft.ML.OnnxRuntime.Tests\\Microsoft.ML.OnnxRuntime.Tests.csproj", + "--filter", + "FullyQualifiedName!=Microsoft.ML.OnnxRuntime.Tests.InferenceTest.TestPreTrainedModels", + is_linux_build, + define_constants, + ort_build_dir, + ] run_subprocess(cmd_args, cwd=csharp_source_dir) @@ -1992,59 +2175,54 @@ def is_cross_compiling_on_apple(args): def build_protoc_for_host(cmake_path, source_dir, build_dir, args): - if (args.arm or args.arm64 or args.arm64ec) and \ - not (is_windows() or is_cross_compiling_on_apple(args)): + if (args.arm or args.arm64 or args.arm64ec) and not (is_windows() or is_cross_compiling_on_apple(args)): raise BuildError( - 'Currently only support building protoc for Windows host while ' - 'cross-compiling for ARM/ARM64/Store and linux cross-compiling iOS') + "Currently only support building protoc for Windows host while " + "cross-compiling for ARM/ARM64/Store and linux cross-compiling iOS" + ) - log.info( - "Building protoc for host to be used in cross-compiled build process") - protoc_build_dir = os.path.join(os.getcwd(), build_dir, 'host_protoc') + log.info("Building protoc for host to be used in cross-compiled build process") + protoc_build_dir = os.path.join(os.getcwd(), build_dir, "host_protoc") os.makedirs(protoc_build_dir, exist_ok=True) # Generate step cmd_args = [ cmake_path, - os.path.join(source_dir, 'cmake', 'external', 'protobuf', 'cmake'), - '-Dprotobuf_BUILD_TESTS=OFF', - '-Dprotobuf_WITH_ZLIB_DEFAULT=OFF', - '-Dprotobuf_BUILD_SHARED_LIBS=OFF' + os.path.join(source_dir, "cmake", "external", "protobuf", "cmake"), + "-Dprotobuf_BUILD_TESTS=OFF", + "-Dprotobuf_WITH_ZLIB_DEFAULT=OFF", + "-Dprotobuf_BUILD_SHARED_LIBS=OFF", ] - is_ninja = args.cmake_generator == 'Ninja' + is_ninja = args.cmake_generator == "Ninja" if args.cmake_generator is not None and not (is_macOS() and args.use_xcode): - cmd_args += ['-G', args.cmake_generator] + cmd_args += ["-G", args.cmake_generator] if is_windows(): if not is_ninja: - cmd_args += ['-T', 'host=x64'] + cmd_args += ["-T", "host=x64"] elif is_macOS(): if args.use_xcode: - cmd_args += ['-G', 'Xcode'] + cmd_args += ["-G", "Xcode"] # CMake < 3.18 has a bug setting system arch to arm64 (if not specified) for Xcode 12, # protoc for host should be built using host architecture # Explicitly specify the CMAKE_OSX_ARCHITECTURES for x86_64 Mac. - cmd_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format( - 'arm64' if platform.machine() == 'arm64' else 'x86_64')] + cmd_args += ["-DCMAKE_OSX_ARCHITECTURES={}".format("arm64" if platform.machine() == "arm64" else "x86_64")] run_subprocess(cmd_args, cwd=protoc_build_dir) # Build step - cmd_args = [cmake_path, - "--build", protoc_build_dir, - "--config", "Release", - "--target", "protoc"] + cmd_args = [cmake_path, "--build", protoc_build_dir, "--config", "Release", "--target", "protoc"] run_subprocess(cmd_args) # Absolute protoc path is needed for cmake - config_dir = '' - suffix = '' + config_dir = "" + suffix = "" if (is_windows() and not is_ninja) or (is_macOS() and args.use_xcode): - config_dir = 'Release' + config_dir = "Release" if is_windows(): - suffix = '.exe' + suffix = ".exe" - expected_protoc_path = os.path.join(protoc_build_dir, config_dir, 'protoc' + suffix) + expected_protoc_path = os.path.join(protoc_build_dir, config_dir, "protoc" + suffix) if not os.path.exists(expected_protoc_path): raise BuildError("Couldn't find {}. Host build of protoc failed.".format(expected_protoc_path)) @@ -2059,42 +2237,48 @@ def generate_documentation(source_dir, build_dir, configs, validate): if is_windows(): cwd = os.path.join(cwd, config) - contrib_op_doc_path = os.path.join(source_dir, 'docs', 'ContribOperators.md') - opkernel_doc_path = os.path.join(source_dir, 'docs', 'OperatorKernels.md') - shutil.copy(os.path.join(source_dir, 'tools', 'python', 'gen_contrib_doc.py'), cwd) - shutil.copy(os.path.join(source_dir, 'tools', 'python', 'gen_opkernel_doc.py'), cwd) + contrib_op_doc_path = os.path.join(source_dir, "docs", "ContribOperators.md") + opkernel_doc_path = os.path.join(source_dir, "docs", "OperatorKernels.md") + shutil.copy(os.path.join(source_dir, "tools", "python", "gen_contrib_doc.py"), cwd) + shutil.copy(os.path.join(source_dir, "tools", "python", "gen_opkernel_doc.py"), cwd) # limit to just com.microsoft (excludes purely internal stuff like com.microsoft.nchwc). - run_subprocess([sys.executable, 'gen_contrib_doc.py', '--output_path', contrib_op_doc_path, - '--domains', 'com.microsoft'], cwd=cwd) + run_subprocess( + [sys.executable, "gen_contrib_doc.py", "--output_path", contrib_op_doc_path, "--domains", "com.microsoft"], + cwd=cwd, + ) # we currently limit the documentation created by a build to the CPU and CUDA EPs. # Run get_opkernel_doc.py directly if you need/want documentation from other EPs that are enabled in the build. - run_subprocess([sys.executable, 'gen_opkernel_doc.py', '--output_path', opkernel_doc_path, - '--providers', 'CPU', 'CUDA'], cwd=cwd) + run_subprocess( + [sys.executable, "gen_opkernel_doc.py", "--output_path", opkernel_doc_path, "--providers", "CPU", "CUDA"], + cwd=cwd, + ) if validate: try: have_diff = False - def diff_file(path, regenerate_qualifiers=''): - diff = subprocess.check_output(['git', 'diff', path], cwd=source_dir) + def diff_file(path, regenerate_qualifiers=""): + diff = subprocess.check_output(["git", "diff", path], cwd=source_dir) if diff: nonlocal have_diff have_diff = True - log.warning('The updated document {} is different from the checked in version. ' - 'Please regenerate the file{}, or copy the updated version from the ' - 'CI build\'s published artifacts if applicable.'.format(path, regenerate_qualifiers)) - log.debug('diff:\n' + str(diff)) + log.warning( + "The updated document {} is different from the checked in version. " + "Please regenerate the file{}, or copy the updated version from the " + "CI build's published artifacts if applicable.".format(path, regenerate_qualifiers) + ) + log.debug("diff:\n" + str(diff)) - diff_file(opkernel_doc_path, ' with CPU and CUDA execution providers enabled') + diff_file(opkernel_doc_path, " with CPU and CUDA execution providers enabled") diff_file(contrib_op_doc_path) if have_diff: # Output for the CI to publish the updated md files as an artifact - print('##vso[task.setvariable variable=DocUpdateNeeded]true') - raise BuildError('Generated documents have diffs. Check build output for details.') + print("##vso[task.setvariable variable=DocUpdateNeeded]true") + raise BuildError("Generated documents have diffs. Check build output for details.") except subprocess.CalledProcessError: - raise BuildError('git diff returned non-zero error code') + raise BuildError("git diff returned non-zero error code") def main(): @@ -2111,7 +2295,7 @@ def main(): args.update = True args.build = True if cross_compiling: - args.test = args.android_abi == 'x86_64' or args.android_abi == 'arm64-v8a' + args.test = args.android_abi == "x86_64" or args.android_abi == "arm64-v8a" else: args.test = True @@ -2131,13 +2315,13 @@ def main(): args.build_shared_lib = True if args.build_nuget and cross_compiling: - raise BuildError('Currently nuget package creation is not supported while cross-compiling') + raise BuildError("Currently nuget package creation is not supported while cross-compiling") if args.enable_pybind and args.disable_rtti: raise BuildError("Python bindings use typeid so you can't disable RTTI") if args.enable_pybind and args.disable_exceptions: - raise BuildError('Python bindings require exceptions to be enabled.') + raise BuildError("Python bindings require exceptions to be enabled.") if args.nnapi_min_api: if not args.use_nnapi: @@ -2163,14 +2347,14 @@ def main(): if args.wasm_malloc is not None: # mark --wasm_malloc as deprecated log.warning( - "Flag '--wasm_malloc=' is deprecated. " - "Please use '--emscripten_settings MALLOC='.") + "Flag '--wasm_malloc=' is deprecated. " "Please use '--emscripten_settings MALLOC='." + ) if args.code_coverage and not args.android: raise BuildError("Using --code_coverage requires --android") if args.gen_api_doc and len(args.config) != 1: - raise BuildError('Using --get-api-doc requires a single build config') + raise BuildError("Using --get-api-doc requires a single build config") # Disabling unit tests for VAD-F as FPGA only supports # models with NCHW layout @@ -2191,8 +2375,7 @@ def main(): # cmake_path and ctest_path can be None. For example, if a person only wants to run the tests, he/she doesn't need # to have cmake/ctest. cmake_path = resolve_executable_path(args.cmake_path) - ctest_path = None if args.use_vstest else resolve_executable_path( - args.ctest_path) + ctest_path = None if args.use_vstest else resolve_executable_path(args.ctest_path) build_dir = args.build_dir script_dir = os.path.realpath(os.path.dirname(__file__)) source_dir = os.path.normpath(os.path.join(script_dir, "..", "..")) @@ -2227,12 +2410,14 @@ def main(): if args.update: if is_reduced_ops_build(args): from reduce_op_kernels import reduce_ops + for config in configs: reduce_ops( config_path=args.include_ops_by_config, build_dir=get_config_build_dir(build_dir, config), enable_type_reduction=args.enable_reduced_operator_type_support, - use_cuda=args.use_cuda) + use_cuda=args.use_cuda, + ) cmake_extra_args = [] path_to_protoc_exe = args.path_to_protoc_exe @@ -2241,60 +2426,55 @@ def main(): if is_windows(): cpu_arch = platform.architecture()[0] if args.build_wasm: - cmake_extra_args = ['-G', 'Ninja'] - elif args.cmake_generator == 'Ninja': - if cpu_arch == '32bit' or args.arm or args.arm64 or args.arm64ec: + cmake_extra_args = ["-G", "Ninja"] + elif args.cmake_generator == "Ninja": + if cpu_arch == "32bit" or args.arm or args.arm64 or args.arm64ec: raise BuildError( "To cross-compile with Ninja, load the toolset " "environment for the target processor (e.g. Cross " - "Tools Command Prompt for VS)") - cmake_extra_args = ['-G', args.cmake_generator] + "Tools Command Prompt for VS)" + ) + cmake_extra_args = ["-G", args.cmake_generator] elif args.arm or args.arm64 or args.arm64ec: # Cross-compiling for ARM(64) architecture # First build protoc for host to use during cross-compilation if path_to_protoc_exe is None: - path_to_protoc_exe = build_protoc_for_host( - cmake_path, source_dir, build_dir, args) + path_to_protoc_exe = build_protoc_for_host(cmake_path, source_dir, build_dir, args) if args.arm: - cmake_extra_args = ['-A', 'ARM'] + cmake_extra_args = ["-A", "ARM"] elif args.arm64: - cmake_extra_args = ['-A', 'ARM64'] + cmake_extra_args = ["-A", "ARM64"] elif args.arm64ec: - cmake_extra_args = ['-A', 'ARM64EC'] - cmake_extra_args += ['-G', args.cmake_generator] + cmake_extra_args = ["-A", "ARM64EC"] + cmake_extra_args += ["-G", args.cmake_generator] # Cannot test on host build machine for cross-compiled # builds (Override any user-defined behaviour for test if any) if args.test: log.warning( "Cannot test on host build machine for cross-compiled " - "ARM(64) builds. Will skip test running after build.") + "ARM(64) builds. Will skip test running after build." + ) args.test = False - elif cpu_arch == '32bit' or args.x86: - cmake_extra_args = [ - '-A', 'Win32', '-T', 'host=x64', '-G', args.cmake_generator - ] + elif cpu_arch == "32bit" or args.x86: + cmake_extra_args = ["-A", "Win32", "-T", "host=x64", "-G", args.cmake_generator] else: if args.msvc_toolset: - toolset = 'host=x64,version=' + args.msvc_toolset + toolset = "host=x64,version=" + args.msvc_toolset else: - toolset = 'host=x64' + toolset = "host=x64" if args.cuda_version: - toolset += ',cuda=' + args.cuda_version - cmake_extra_args = [ - '-A', 'x64', '-T', toolset, '-G', args.cmake_generator - ] + toolset += ",cuda=" + args.cuda_version + cmake_extra_args = ["-A", "x64", "-T", toolset, "-G", args.cmake_generator] if args.enable_wcos: - cmake_extra_defines.append('CMAKE_USER_MAKE_RULES_OVERRIDE=wcos_rules_override.cmake') + cmake_extra_defines.append("CMAKE_USER_MAKE_RULES_OVERRIDE=wcos_rules_override.cmake") elif args.cmake_generator is not None and not (is_macOS() and args.use_xcode): - cmake_extra_args += ['-G', args.cmake_generator] + cmake_extra_args += ["-G", args.cmake_generator] elif is_macOS(): if args.use_xcode: - cmake_extra_args += ['-G', 'Xcode'] - if not args.ios and not args.android and \ - args.osx_arch == 'arm64' and platform.machine() == 'x86_64': + cmake_extra_args += ["-G", "Xcode"] + if not args.ios and not args.android and args.osx_arch == "arm64" and platform.machine() == "x86_64": if args.test: - log.warning( - "Cannot test ARM64 build on X86_64. Will skip test running after build.") + log.warning("Cannot test ARM64 build on X86_64. Will skip test running after build.") args.test = False if args.build_wasm: @@ -2307,17 +2487,17 @@ def main(): log.info("Activating emsdk...") run_subprocess([emsdk_file, "activate", emsdk_version], cwd=emsdk_dir) - if (args.android or args.ios or args.build_wasm - or is_cross_compiling_on_apple(args)) and args.path_to_protoc_exe is None: + if ( + args.android or args.ios or args.build_wasm or is_cross_compiling_on_apple(args) + ) and args.path_to_protoc_exe is None: # Cross-compiling for Android, iOS, and WebAssembly - path_to_protoc_exe = build_protoc_for_host( - cmake_path, source_dir, build_dir, args) + path_to_protoc_exe = build_protoc_for_host(cmake_path, source_dir, build_dir, args) if is_ubuntu_1604(): - if (args.arm or args.arm64): + if args.arm or args.arm64: raise BuildError( - "Only Windows ARM(64) cross-compiled builds supported " - "currently through this script") + "Only Windows ARM(64) cross-compiled builds supported " "currently through this script" + ) if not is_docker() and not args.use_acl and not args.use_armnn: install_python_deps() @@ -2339,19 +2519,27 @@ def main(): if args.build_eager_mode: eager_root_dir = os.path.join(source_dir, "orttraining", "orttraining", "eager") if args.eager_customop_module and not args.eager_customop_header: - raise Exception('eager_customop_header must be provided when eager_customop_module is') + raise Exception("eager_customop_header must be provided when eager_customop_module is") elif args.eager_customop_header and not args.eager_customop_module: - raise Exception('eager_customop_module must be provided when eager_customop_header is') + raise Exception("eager_customop_module must be provided when eager_customop_header is") def gen_ops(gen_cpp_name: str, header_file: str, ops_module: str, custom_ops: bool): - gen_cpp_scratch_name = gen_cpp_name + '.working' - print(f'Generating ORT ATen overrides (output_file: {gen_cpp_name}, header_file: {header_file},' - f'ops_module: {ops_module}), custom_ops: {custom_ops}') - - cmd = [sys.executable, os.path.join(os.path.join(eager_root_dir, 'opgen', 'opgen.py')), - '--output_file', gen_cpp_scratch_name, - '--ops_module', ops_module, - '--header_file', header_file] + gen_cpp_scratch_name = gen_cpp_name + ".working" + print( + f"Generating ORT ATen overrides (output_file: {gen_cpp_name}, header_file: {header_file}," + f"ops_module: {ops_module}), custom_ops: {custom_ops}" + ) + + cmd = [ + sys.executable, + os.path.join(os.path.join(eager_root_dir, "opgen", "opgen.py")), + "--output_file", + gen_cpp_scratch_name, + "--ops_module", + ops_module, + "--header_file", + header_file, + ] if custom_ops: cmd += ["--custom_ops"] @@ -2359,8 +2547,10 @@ def gen_ops(gen_cpp_name: str, header_file: str, ops_module: str, custom_ops: bo subprocess.check_call(cmd) import filecmp - if (not os.path.isfile(gen_cpp_name) or - not filecmp.cmp(gen_cpp_name, gen_cpp_scratch_name, shallow=False)): + + if not os.path.isfile(gen_cpp_name) or not filecmp.cmp( + gen_cpp_name, gen_cpp_scratch_name, shallow=False + ): os.rename(gen_cpp_scratch_name, gen_cpp_name) else: os.remove(gen_cpp_scratch_name) @@ -2368,32 +2558,53 @@ def gen_ops(gen_cpp_name: str, header_file: str, ops_module: str, custom_ops: bo def gen_ort_ops(): # generate native aten ops import torch - regdecs_path = os.path.join(os.path.dirname(torch.__file__), 'include/ATen/RegistrationDeclarations.h') - ops_module = os.path.join(eager_root_dir, 'opgen/opgen/atenops.py') - gen_ops(os.path.join(eager_root_dir, 'ort_aten.g.cpp'), regdecs_path, ops_module, False) + regdecs_path = os.path.join(os.path.dirname(torch.__file__), "include/ATen/RegistrationDeclarations.h") + + ops_module = os.path.join(eager_root_dir, "opgen/opgen/atenops.py") + gen_ops(os.path.join(eager_root_dir, "ort_aten.g.cpp"), regdecs_path, ops_module, False) # generate custom ops if not args.eager_customop_header: - args.eager_customop_header = os.path.realpath(os.path.join( - eager_root_dir, - "opgen", - "CustomOpDeclarations.h")) + args.eager_customop_header = os.path.realpath( + os.path.join(eager_root_dir, "opgen", "CustomOpDeclarations.h") + ) if not args.eager_customop_module: - args.eager_customop_module = os.path.join(eager_root_dir, 'opgen/opgen/custom_ops.py') + args.eager_customop_module = os.path.join(eager_root_dir, "opgen/opgen/custom_ops.py") - gen_ops(os.path.join(eager_root_dir, 'ort_customops.g.cpp'), - args.eager_customop_header, args.eager_customop_module, True) + gen_ops( + os.path.join(eager_root_dir, "ort_customops.g.cpp"), + args.eager_customop_header, + args.eager_customop_module, + True, + ) gen_ort_ops() if args.enable_external_custom_op_schemas and not is_linux(): raise BuildError("Registering external custom op schemas is only supported on Linux.") generate_build_tree( - cmake_path, source_dir, build_dir, cuda_home, cudnn_home, rocm_home, mpi_home, nccl_home, - tensorrt_home, migraphx_home, acl_home, acl_libs, armnn_home, armnn_libs, - path_to_protoc_exe, configs, cmake_extra_defines, args, cmake_extra_args) + cmake_path, + source_dir, + build_dir, + cuda_home, + cudnn_home, + rocm_home, + mpi_home, + nccl_home, + tensorrt_home, + migraphx_home, + acl_home, + acl_libs, + armnn_home, + armnn_libs, + path_to_protoc_exe, + configs, + cmake_extra_defines, + args, + cmake_extra_args, + ) if args.clean: clean_targets(cmake_path, build_dir, configs) @@ -2427,8 +2638,8 @@ def gen_ort_ops(): # either. if args.build: if args.build_wheel: - nightly_build = bool(os.getenv('NIGHTLY_BUILD') == '1') - default_training_package_device = bool(os.getenv('DEFAULT_TRAINING_PACKAGE_DEVICE') == '1') + nightly_build = bool(os.getenv("NIGHTLY_BUILD") == "1") + default_training_package_device = bool(os.getenv("DEFAULT_TRAINING_PACKAGE_DEVICE") == "1") build_python_wheel( source_dir, build_dir, @@ -2450,8 +2661,8 @@ def gen_ort_ops(): args.enable_training, nightly_build=nightly_build, default_training_package_device=default_training_package_device, - use_ninja=(args.cmake_generator == 'Ninja'), - build_eager_mode=args.build_eager_mode + use_ninja=(args.cmake_generator == "Ninja"), + build_eager_mode=args.build_eager_mode, ) if args.build_nuget: build_nuget_package( @@ -2468,22 +2679,18 @@ def gen_ort_ops(): ) if args.test and args.build_nuget: - run_csharp_tests( - source_dir, - build_dir, - args.use_cuda, - args.use_openvino, - args.use_tensorrt, - args.use_dnnl) + run_csharp_tests(source_dir, build_dir, args.use_cuda, args.use_openvino, args.use_tensorrt, args.use_dnnl) if args.gen_doc and (args.build or args.test): - generate_documentation(source_dir, build_dir, configs, args.gen_doc == 'validate') + generate_documentation(source_dir, build_dir, configs, args.gen_doc == "validate") if args.gen_api_doc and (args.build or args.test): - print('Generating Python doc for ORTModule...') - docbuild_dir = os.path.join(source_dir, 'tools', 'doc') - run_subprocess(['bash', 'builddoc.sh', os.path.dirname(sys.executable), - source_dir, build_dir, args.config[0]], cwd=docbuild_dir) + print("Generating Python doc for ORTModule...") + docbuild_dir = os.path.join(source_dir, "tools", "doc") + run_subprocess( + ["bash", "builddoc.sh", os.path.dirname(sys.executable), source_dir, build_dir, args.config[0]], + cwd=docbuild_dir, + ) log.info("Build complete") diff --git a/tools/ci_build/clean_docker_image_cache.py b/tools/ci_build/clean_docker_image_cache.py index 59a20484b3f28..8ad3bb11c8c46 100755 --- a/tools/ci_build/clean_docker_image_cache.py +++ b/tools/ci_build/clean_docker_image_cache.py @@ -10,8 +10,8 @@ import re import sys import tempfile -from logger import get_logger +from logger import get_logger SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) @@ -21,7 +21,6 @@ from util import run # noqa: E402 - log = get_logger("clean_docker_image_cache") @@ -35,36 +34,26 @@ def parse_args(): "retained or removed. " "For an image to be retained, it must have been accessed at least N " "times (specified by --cache-min-access-count) over the past K days " - "(specified by --cache-history-days).") + "(specified by --cache-history-days)." + ) - parser.add_argument( - "--container-registry", required=True, - help="The container registry name.") + parser.add_argument("--container-registry", required=True, help="The container registry name.") + parser.add_argument("--log-storage-account", required=True, help="The storage account name.") + parser.add_argument("--log-storage-account-container", required=True, help="The storage account container name.") parser.add_argument( - "--log-storage-account", required=True, - help="The storage account name.") - parser.add_argument( - "--log-storage-account-container", required=True, - help="The storage account container name.") - parser.add_argument( - "--log-storage-path-pattern", default="*.json", - help="The log path pattern in the storage account container.") + "--log-storage-path-pattern", default="*.json", help="The log path pattern in the storage account container." + ) - parser.add_argument( - "--cache-history-days", type=int, default=7, - help="The length of the cache history in days.") + parser.add_argument("--cache-history-days", type=int, default=7, help="The length of the cache history in days.") parser.add_argument( - "--cache-min-access-count", type=int, default=1, - help="The minimum access count over the cache history.") + "--cache-min-access-count", type=int, default=1, help="The minimum access count over the cache history." + ) - parser.add_argument( - "--dry-run", action="store_true", - help="Do a dry-run and do not remove any images.") + parser.add_argument("--dry-run", action="store_true", help="Do a dry-run and do not remove any images.") - parser.add_argument( - "--az-path", default="az", help="Path to the az client.") + parser.add_argument("--az-path", default="az", help="Path to the az client.") return parser.parse_args() @@ -78,12 +67,19 @@ def az(*args, parse_output=True, az_path): def download_logs(storage_account, container, log_path_pattern, target_dir, az_path): log_paths = az( - "storage", "blob", "download-batch", - "--destination", target_dir, - "--source", container, - "--account-name", storage_account, - "--pattern", log_path_pattern, - az_path=az_path) + "storage", + "blob", + "download-batch", + "--destination", + target_dir, + "--source", + container, + "--account-name", + storage_account, + "--pattern", + log_path_pattern, + az_path=az_path, + ) return [os.path.join(target_dir, log_path) for log_path in log_paths] @@ -95,7 +91,8 @@ def get_image_name(image_info): timestamp_pattern = re.compile( - r"^(?P\d+)-(?P\d+)-(?P\d+)T(?P\d+):(?P\d+):(?P\d+)") + r"^(?P\d+)-(?P\d+)-(?P\d+)T(?P\d+):(?P\d+):(?P\d+)" +) def parse_timestamp(timestamp_str): @@ -104,9 +101,14 @@ def parse_timestamp(timestamp_str): return None return datetime.datetime( - year=int(match['year']), month=int(match['month']), day=int(match['day']), - hour=int(match['hour']), minute=int(match['minute']), second=int(match['second']), - tzinfo=datetime.timezone.utc) + year=int(match["year"]), + month=int(match["month"]), + day=int(match["day"]), + hour=int(match["hour"]), + minute=int(match["minute"]), + second=int(match["second"]), + tzinfo=datetime.timezone.utc, + ) def parse_log_line(line, min_datetime): @@ -117,11 +119,12 @@ def check_time(value): return timestamp is not None and timestamp >= min_datetime for field_name, expected_value_or_checker in [ - ("category", "ContainerRegistryRepositoryEvents"), - ("operationName", lambda value: value in ["Pull", "Push"]), - ("resultType", "HttpStatusCode"), - ("resultDescription", lambda value: value in ["200", "201"]), - ("time", check_time)]: + ("category", "ContainerRegistryRepositoryEvents"), + ("operationName", lambda value: value in ["Pull", "Push"]), + ("resultType", "HttpStatusCode"), + ("resultDescription", lambda value: value in ["200", "201"]), + ("time", check_time), + ]: value = entry.get(field_name, "") if callable(expected_value_or_checker): if not expected_value_or_checker(value): @@ -156,29 +159,41 @@ def get_valid_images_from_logs(log_paths, min_datetime, min_access_count): def get_registry_images(container_registry, az_path): registry_images = set() # set of ImageInfo - repositories = az( - "acr", "repository", "list", "--name", container_registry, - az_path=az_path) + repositories = az("acr", "repository", "list", "--name", container_registry, az_path=az_path) for repository in repositories: digests = az( - "acr", "repository", "show-manifests", - "--repository", repository, "--name", container_registry, - "--query", "[*].digest", - az_path=az_path) - - registry_images.update( - [ImageInfo(repository, digest) for digest in digests]) + "acr", + "repository", + "show-manifests", + "--repository", + repository, + "--name", + container_registry, + "--query", + "[*].digest", + az_path=az_path, + ) + + registry_images.update([ImageInfo(repository, digest) for digest in digests]) return registry_images def clean_images(container_registry, image_names, az_path): for image_name in image_names: - az("acr", "repository", "delete", "--name", container_registry, - "--image", image_name, "--yes", - az_path=az_path, - parse_output=False) + az( + "acr", + "repository", + "delete", + "--name", + container_registry, + "--image", + image_name, + "--yes", + az_path=az_path, + parse_output=False, + ) # Note: @@ -208,31 +223,27 @@ def main(): args.log_storage_account_container, args.log_storage_path_pattern, tmp_dir, - args.az_path) + args.az_path, + ) cache_history = datetime.timedelta(days=args.cache_history_days) - min_timestamp = \ - datetime.datetime.now(tz=datetime.timezone.utc) - cache_history + min_timestamp = datetime.datetime.now(tz=datetime.timezone.utc) - cache_history - valid_images = get_valid_images_from_logs( - log_paths, min_timestamp, args.cache_min_access_count) + valid_images = get_valid_images_from_logs(log_paths, min_timestamp, args.cache_min_access_count) all_images = get_registry_images(args.container_registry, args.az_path) def sorted_image_names(image_infos): return sorted([get_image_name(image_info) for image_info in image_infos]) - log.debug("All images:\n{}".format( - "\n".join(sorted_image_names(all_images)))) - log.debug("Valid images:\n{}".format( - "\n".join(sorted_image_names(valid_images)))) + log.debug("All images:\n{}".format("\n".join(sorted_image_names(all_images)))) + log.debug("Valid images:\n{}".format("\n".join(sorted_image_names(valid_images)))) images_to_clean = all_images - valid_images image_names_to_clean = sorted_image_names(images_to_clean) - log.info("Images to clean:\n{}".format( - "\n".join(image_names_to_clean))) + log.info("Images to clean:\n{}".format("\n".join(image_names_to_clean))) if args.dry_run: log.info("Dry run, no images will be cleaned.") diff --git a/tools/ci_build/coverage.py b/tools/ci_build/coverage.py index 13aee41f9cbff..48d919aa7358b 100644 --- a/tools/ci_build/coverage.py +++ b/tools/ci_build/coverage.py @@ -8,9 +8,10 @@ # 2. The tests are run on the target emulator and *.gcda files are available on the emulator # 3. The emulator which ran tests must be running. Otherwise this script will fail +import argparse import os import sys -import argparse + from build import run_subprocess SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) @@ -23,14 +24,14 @@ def parse_arguments(): parser = argparse.ArgumentParser() + parser.add_argument("--build_dir", required=True, help="Path to the build directory.") parser.add_argument( - "--build_dir", required=True, help="Path to the build directory.") - parser.add_argument( - "--config", default="Debug", + "--config", + default="Debug", choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], - help="Configuration(s) to run code coverage.") - parser.add_argument( - "--android_sdk_path", required=True, help="The Android SDK root.") + help="Configuration(s) to run code coverage.", + ) + parser.add_argument("--android_sdk_path", required=True, help="The Android SDK root.") return parser.parse_args() @@ -40,18 +41,18 @@ def main(): sdk_tool_paths = android.get_sdk_tool_paths(args.android_sdk_path) def adb_pull(src, dest, **kwargs): - return run_subprocess([sdk_tool_paths.adb, 'pull', src, dest], **kwargs) + return run_subprocess([sdk_tool_paths.adb, "pull", src, dest], **kwargs) def adb_shell(*args, **kwargs): - return run_subprocess([sdk_tool_paths.adb, 'shell', *args], **kwargs) + return run_subprocess([sdk_tool_paths.adb, "shell", *args], **kwargs) script_dir = os.path.realpath(os.path.dirname(__file__)) source_dir = os.path.normpath(os.path.join(script_dir, "..", "..")) cwd = os.path.abspath(os.path.join(args.build_dir, args.config)) - adb_shell('cd /data/local/tmp && tar -zcf gcda_files.tar.gz *.dir') - adb_pull('/data/local/tmp/gcda_files.tar.gz', cwd) + adb_shell("cd /data/local/tmp && tar -zcf gcda_files.tar.gz *.dir") + adb_pull("/data/local/tmp/gcda_files.tar.gz", cwd) os.chdir(cwd) - run_subprocess("tar -zxf gcda_files.tar.gz -C CMakeFiles".split(' ')) + run_subprocess("tar -zxf gcda_files.tar.gz -C CMakeFiles".split(" ")) cmd = ["gcovr", "-s", "-r"] cmd.append(os.path.join(source_dir, "onnxruntime")) cmd.extend([".", "-o"]) diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index 950b2b6b79c2d..6ef2c3d991436 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -16,15 +16,15 @@ def parse_arguments(): args = parse_arguments() print("Generating symbol file for %s" % str(args.config)) -with open(args.version_file, 'r') as f: +with open(args.version_file, "r") as f: VERSION_STRING = f.read().strip() print("VERSION:%s" % VERSION_STRING) symbols = set() for c in args.config: - file_name = os.path.join(args.src_root, 'core', 'providers', c, 'symbols.txt') - with open(file_name, 'r') as file: + file_name = os.path.join(args.src_root, "core", "providers", c, "symbols.txt") + with open(file_name, "r") as file: for line in file: line = line.strip() if line in symbols: @@ -34,31 +34,31 @@ def parse_arguments(): symbols = sorted(symbols) symbol_index = 1 -with open(args.output, 'w') as file: - if args.style == 'vc': - file.write('LIBRARY\n') - file.write('EXPORTS\n') - elif args.style == 'xcode': - pass # xcode compile don't has any header. +with open(args.output, "w") as file: + if args.style == "vc": + file.write("LIBRARY\n") + file.write("EXPORTS\n") + elif args.style == "xcode": + pass # xcode compile don't has any header. else: - file.write('VERS_%s {\n' % VERSION_STRING) - file.write(' global:\n') + file.write("VERS_%s {\n" % VERSION_STRING) + file.write(" global:\n") for symbol in symbols: - if args.style == 'vc': + if args.style == "vc": file.write(" %s @%d\n" % (symbol, symbol_index)) - elif args.style == 'xcode': + elif args.style == "xcode": file.write("_%s\n" % symbol) else: file.write(" %s;\n" % symbol) symbol_index += 1 - if args.style == 'gcc': + if args.style == "gcc": file.write(" local:\n") file.write(" *;\n") file.write("}; \n") -with open(args.output_source, 'w') as file: +with open(args.output_source, "w") as file: file.write("#include \n") for c in args.config: # WinML adapter should not be exported in platforms other than Windows. @@ -69,6 +69,6 @@ def parse_arguments(): file.write("void* GetFunctionEntryByName(const char* name){\n") for symbol in symbols: if symbol != "OrtGetWinMLAdapter": - file.write("if(strcmp(name,\"%s\") ==0) return (void*)&%s;\n" % (symbol, symbol)) + file.write('if(strcmp(name,"%s") ==0) return (void*)&%s;\n' % (symbol, symbol)) file.write("return NULL;\n") file.write("}\n") diff --git a/tools/ci_build/get_docker_image.py b/tools/ci_build/get_docker_image.py index c028c243167a9..e0eb06fc14fa7 100755 --- a/tools/ci_build/get_docker_image.py +++ b/tools/ci_build/get_docker_image.py @@ -8,8 +8,8 @@ import os import shlex import sys -from logger import get_logger +from logger import get_logger SCRIPT_DIR = os.path.dirname(os.path.realpath(__file__)) REPO_DIR = os.path.normpath(os.path.join(SCRIPT_DIR, "..", "..")) @@ -19,7 +19,6 @@ from util import run # noqa: E402 - log = get_logger("get_docker_image") @@ -34,36 +33,36 @@ def parse_args(): "This script checks whether an image with that tag is initially " "present in the container registry to determine whether to pull or " "build the image. " - "The user must be logged in to the container registry.") + "The user must be logged in to the container registry." + ) + parser.add_argument("--dockerfile", default="Dockerfile", help="Path to the Dockerfile.") + parser.add_argument("--context", default=".", help="Path to the build context.") parser.add_argument( - "--dockerfile", default="Dockerfile", help="Path to the Dockerfile.") - parser.add_argument( - "--context", default=".", help="Path to the build context.") - parser.add_argument( - "--docker-build-args", default="", + "--docker-build-args", + default="", help="String of Docker build args which may affect the image content. " "These will be used in differentiating images from one another. " - "For example, '--build-arg'.") + "For example, '--build-arg'.", + ) parser.add_argument( - "--docker-build-args-not-affecting-image-content", default="", - help="String of Docker build args which do not affect the image " - "content.") + "--docker-build-args-not-affecting-image-content", + default="", + help="String of Docker build args which do not affect the image " "content.", + ) parser.add_argument( "--container-registry", - help="The Azure container registry name. " - "If not provided, no container registry will be used.") - parser.add_argument( - "--repository", required=True, help="The image repository name.") + help="The Azure container registry name. " "If not provided, no container registry will be used.", + ) + parser.add_argument("--repository", required=True, help="The image repository name.") - parser.add_argument( - "--docker-path", default="docker", help="Path to docker.") + parser.add_argument("--docker-path", default="docker", help="Path to docker.") return parser.parse_args() -FileInfo = collections.namedtuple('FileInfo', ['path', 'mode']) +FileInfo = collections.namedtuple("FileInfo", ["path", "mode"]) def file_info_str(file_info: FileInfo): @@ -113,19 +112,15 @@ def update_hash_with_file(file_info: FileInfo, hash_obj): def generate_tag(dockerfile_path, context_path, docker_build_args_str): hash_obj = hashlib.sha256() hash_obj.update(docker_build_args_str.encode()) - update_hash_with_file( - make_file_info_from_path(dockerfile_path), hash_obj) - update_hash_with_directory( - make_file_info_from_path(context_path), hash_obj) + update_hash_with_file(make_file_info_from_path(dockerfile_path), hash_obj) + update_hash_with_directory(make_file_info_from_path(context_path), hash_obj) return "image_content_digest_{}".format(hash_obj.hexdigest()) def container_registry_has_image(full_image_name, docker_path): env = os.environ.copy() env["DOCKER_CLI_EXPERIMENTAL"] = "enabled" # needed for "docker manifest" - proc = run( - docker_path, "manifest", "inspect", "--insecure", full_image_name, - env=env, check=False, quiet=True) + proc = run(docker_path, "manifest", "inspect", "--insecure", full_image_name, env=env, check=False, quiet=True) image_found = proc.returncode == 0 log.debug("Image {} in registry".format("found" if image_found else "not found")) return image_found @@ -134,8 +129,11 @@ def container_registry_has_image(full_image_name, docker_path): def main(): args = parse_args() - log.debug("Dockerfile: {}, context: {}, docker build args: '{}'".format( - args.dockerfile, args.context, args.docker_build_args)) + log.debug( + "Dockerfile: {}, context: {}, docker build args: '{}'".format( + args.dockerfile, args.context, args.docker_build_args + ) + ) use_container_registry = args.container_registry is not None @@ -144,10 +142,11 @@ def main(): tag = generate_tag(args.dockerfile, args.context, args.docker_build_args) - full_image_name = \ - "{}.azurecr.io/{}:{}".format(args.container_registry, args.repository, tag) \ - if use_container_registry else \ - "{}:{}".format(args.repository, tag) + full_image_name = ( + "{}.azurecr.io/{}:{}".format(args.container_registry, args.repository, tag) + if use_container_registry + else "{}:{}".format(args.repository, tag) + ) log.info("Image: {}".format(full_image_name)) @@ -156,13 +155,18 @@ def main(): run(args.docker_path, "pull", full_image_name) else: log.info("Building image...") - run(args.docker_path, "build", + run( + args.docker_path, + "build", "--pull", *shlex.split(args.docker_build_args), *shlex.split(args.docker_build_args_not_affecting_image_content), - "--tag", full_image_name, - "--file", args.dockerfile, - args.context) + "--tag", + full_image_name, + "--file", + args.dockerfile, + args.context + ) if use_container_registry: # avoid pushing if an identically tagged image has been pushed since the last check diff --git a/tools/ci_build/github/android/build_aar_package.py b/tools/ci_build/github/android/build_aar_package.py index efaf4579a9371..c0cacb4231665 100644 --- a/tools/ci_build/github/android/build_aar_package.py +++ b/tools/ci_build/github/android/build_aar_package.py @@ -35,42 +35,43 @@ def _parse_build_settings(args): setting_file = args.build_settings_file.resolve() if not setting_file.is_file(): - raise FileNotFoundError('Build config file {} is not a file.'.format(setting_file)) + raise FileNotFoundError("Build config file {} is not a file.".format(setting_file)) with open(setting_file) as f: build_settings_data = json.load(f) build_settings = {} - if 'build_abis' in build_settings_data: - build_settings['build_abis'] = build_settings_data['build_abis'] + if "build_abis" in build_settings_data: + build_settings["build_abis"] = build_settings_data["build_abis"] else: - build_settings['build_abis'] = DEFAULT_BUILD_ABIS + build_settings["build_abis"] = DEFAULT_BUILD_ABIS build_params = [] - if 'build_params' in build_settings_data: - build_params += build_settings_data['build_params'] + if "build_params" in build_settings_data: + build_params += build_settings_data["build_params"] else: - raise ValueError('build_params is required in the build config file') + raise ValueError("build_params is required in the build config file") - if 'android_min_sdk_version' in build_settings_data: - build_settings['android_min_sdk_version'] = build_settings_data['android_min_sdk_version'] + if "android_min_sdk_version" in build_settings_data: + build_settings["android_min_sdk_version"] = build_settings_data["android_min_sdk_version"] else: - build_settings['android_min_sdk_version'] = DEFAULT_ANDROID_MIN_SDK_VER - build_params += ['--android_api=' + str(build_settings['android_min_sdk_version'])] + build_settings["android_min_sdk_version"] = DEFAULT_ANDROID_MIN_SDK_VER + build_params += ["--android_api=" + str(build_settings["android_min_sdk_version"])] - if 'android_target_sdk_version' in build_settings_data: - build_settings['android_target_sdk_version'] = build_settings_data['android_target_sdk_version'] + if "android_target_sdk_version" in build_settings_data: + build_settings["android_target_sdk_version"] = build_settings_data["android_target_sdk_version"] else: - build_settings['android_target_sdk_version'] = DEFAULT_ANDROID_TARGET_SDK_VER + build_settings["android_target_sdk_version"] = DEFAULT_ANDROID_TARGET_SDK_VER - if build_settings['android_min_sdk_version'] > build_settings['android_target_sdk_version']: + if build_settings["android_min_sdk_version"] > build_settings["android_target_sdk_version"]: raise ValueError( - 'android_min_sdk_version {} cannot be larger than android_target_sdk_version {}'.format( - build_settings['android_min_sdk_version'], build_settings['android_target_sdk_version'] - )) + "android_min_sdk_version {} cannot be larger than android_target_sdk_version {}".format( + build_settings["android_min_sdk_version"], build_settings["android_target_sdk_version"] + ) + ) - build_settings['build_params'] = build_params + build_settings["build_params"] = build_params build_settings["build_variant"] = build_settings_data.get("build_variant", DEFAULT_BUILD_VARIANT) return build_settings @@ -83,28 +84,25 @@ def _build_aar(args): # Setup temp environment for building temp_env = os.environ.copy() - temp_env['ANDROID_HOME'] = os.path.abspath(args.android_sdk_path) - temp_env['ANDROID_NDK_HOME'] = os.path.abspath(args.android_ndk_path) + temp_env["ANDROID_HOME"] = os.path.abspath(args.android_sdk_path) + temp_env["ANDROID_NDK_HOME"] = os.path.abspath(args.android_ndk_path) # Temp dirs to hold building results - intermediates_dir = os.path.join(build_dir, 'intermediates') + intermediates_dir = os.path.join(build_dir, "intermediates") build_config = args.config - aar_dir = os.path.join(intermediates_dir, 'aar', build_config) - jnilibs_dir = os.path.join(intermediates_dir, 'jnilibs', build_config) - exe_dir = os.path.join(intermediates_dir, 'executables', build_config) - base_build_command = [sys.executable, BUILD_PY] + build_settings['build_params'] + ['--config=' + build_config] - header_files_path = '' + aar_dir = os.path.join(intermediates_dir, "aar", build_config) + jnilibs_dir = os.path.join(intermediates_dir, "jnilibs", build_config) + exe_dir = os.path.join(intermediates_dir, "executables", build_config) + base_build_command = [sys.executable, BUILD_PY] + build_settings["build_params"] + ["--config=" + build_config] + header_files_path = "" # Build binary for each ABI, one by one - for abi in build_settings['build_abis']: + for abi in build_settings["build_abis"]: abi_build_dir = os.path.join(intermediates_dir, abi) - abi_build_command = base_build_command + [ - '--android_abi=' + abi, - '--build_dir=' + abi_build_dir - ] + abi_build_command = base_build_command + ["--android_abi=" + abi, "--build_dir=" + abi_build_dir] if ops_config_path is not None: - abi_build_command += ['--include_ops_by_config=' + ops_config_path] + abi_build_command += ["--include_ops_by_config=" + ops_config_path] subprocess.run(abi_build_command, env=temp_env, shell=False, check=True, cwd=REPO_DIR) @@ -112,7 +110,7 @@ def _build_aar(args): # to jnilibs/[abi] for later compiling the aar package abi_jnilibs_dir = os.path.join(jnilibs_dir, abi) os.makedirs(abi_jnilibs_dir, exist_ok=True) - for lib_name in ['libonnxruntime.so', 'libonnxruntime4j_jni.so']: + for lib_name in ["libonnxruntime.so", "libonnxruntime4j_jni.so"]: target_lib_name = os.path.join(abi_jnilibs_dir, lib_name) # If the symbolic already exists, delete it first # For some reason, os.path.exists will return false for a symbolic link in Linux, @@ -123,72 +121,85 @@ def _build_aar(args): # copy executables for each abi, in case we want to publish those as well abi_exe_dir = os.path.join(exe_dir, abi) - for exe_name in ['libonnxruntime.so', 'onnxruntime_perf_test', 'onnx_test_runner']: + for exe_name in ["libonnxruntime.so", "onnxruntime_perf_test", "onnx_test_runner"]: os.makedirs(abi_exe_dir, exist_ok=True) target_exe_name = os.path.join(abi_exe_dir, exe_name) shutil.copyfile(os.path.join(abi_build_dir, build_config, exe_name), target_exe_name) # we only need to define the header files path once if not header_files_path: - header_files_path = os.path.join(abi_build_dir, build_config, 'android', 'headers') + header_files_path = os.path.join(abi_build_dir, build_config, "android", "headers") # The directory to publish final AAR - aar_publish_dir = os.path.join(build_dir, 'aar_out', build_config) + aar_publish_dir = os.path.join(build_dir, "aar_out", build_config) os.makedirs(aar_publish_dir, exist_ok=True) # get the common gradle command args gradle_command = [ - 'gradle', - '--no-daemon', - '-b=build-android.gradle', - '-c=settings-android.gradle', - '-DjniLibsDir=' + jnilibs_dir, - '-DbuildDir=' + aar_dir, - '-DheadersDir=' + header_files_path, - '-DpublishDir=' + aar_publish_dir, - '-DminSdkVer=' + str(build_settings['android_min_sdk_version']), - '-DtargetSdkVer=' + str(build_settings['android_target_sdk_version']), - '-DbuildVariant=' + str(build_settings['build_variant']) + "gradle", + "--no-daemon", + "-b=build-android.gradle", + "-c=settings-android.gradle", + "-DjniLibsDir=" + jnilibs_dir, + "-DbuildDir=" + aar_dir, + "-DheadersDir=" + header_files_path, + "-DpublishDir=" + aar_publish_dir, + "-DminSdkVer=" + str(build_settings["android_min_sdk_version"]), + "-DtargetSdkVer=" + str(build_settings["android_target_sdk_version"]), + "-DbuildVariant=" + str(build_settings["build_variant"]), ] # If not using shell on Window, will not be able to find gradle in path use_shell = True if is_windows() else False # clean, build, and publish to a local directory - subprocess.run(gradle_command + ['clean'], env=temp_env, shell=use_shell, check=True, cwd=JAVA_ROOT) - subprocess.run(gradle_command + ['build'], env=temp_env, shell=use_shell, check=True, cwd=JAVA_ROOT) - subprocess.run(gradle_command + ['publish'], env=temp_env, shell=use_shell, check=True, cwd=JAVA_ROOT) + subprocess.run(gradle_command + ["clean"], env=temp_env, shell=use_shell, check=True, cwd=JAVA_ROOT) + subprocess.run(gradle_command + ["build"], env=temp_env, shell=use_shell, check=True, cwd=JAVA_ROOT) + subprocess.run(gradle_command + ["publish"], env=temp_env, shell=use_shell, check=True, cwd=JAVA_ROOT) def parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), - description='''Create Android Archive (AAR) package for one or more Android ABI(s) + description="""Create Android Archive (AAR) package for one or more Android ABI(s) and building properties specified in the given build config file, see tools/ci_build/github/android/default_mobile_aar_build_settings.json for details. The output of the final AAR package can be found under [build_dir]/aar_out - ''' + """, ) - parser.add_argument("--android_sdk_path", type=str, default=os.environ.get("ANDROID_HOME", ""), - help="Path to the Android SDK") + parser.add_argument( + "--android_sdk_path", type=str, default=os.environ.get("ANDROID_HOME", ""), help="Path to the Android SDK" + ) - parser.add_argument("--android_ndk_path", type=str, default=os.environ.get("ANDROID_NDK_HOME", ""), - help="Path to the Android NDK") + parser.add_argument( + "--android_ndk_path", type=str, default=os.environ.get("ANDROID_NDK_HOME", ""), help="Path to the Android NDK" + ) - parser.add_argument('--build_dir', type=str, default=os.path.join(REPO_DIR, 'build/android_aar'), - help='Provide the root directory for build output') + parser.add_argument( + "--build_dir", + type=str, + default=os.path.join(REPO_DIR, "build/android_aar"), + help="Provide the root directory for build output", + ) parser.add_argument( - "--include_ops_by_config", type=str, - help="Include ops from config file. See /docs/Reduced_Operator_Kernel_build.md for more information.") + "--include_ops_by_config", + type=str, + help="Include ops from config file. See /docs/Reduced_Operator_Kernel_build.md for more information.", + ) - parser.add_argument("--config", type=str, default="Release", - choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], - help="Configuration to build.") + parser.add_argument( + "--config", + type=str, + default="Release", + choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], + help="Configuration to build.", + ) - parser.add_argument('build_settings_file', type=pathlib.Path, - help='Provide the file contains settings for building AAR') + parser.add_argument( + "build_settings_file", type=pathlib.Path, help="Provide the file contains settings for building AAR" + ) return parser.parse_args() @@ -198,12 +209,12 @@ def main(): # Android SDK and NDK path are required if not args.android_sdk_path: - raise ValueError('android_sdk_path is required') + raise ValueError("android_sdk_path is required") if not args.android_ndk_path: - raise ValueError('android_ndk_path is required') + raise ValueError("android_ndk_path is required") _build_aar(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/ci_build/github/apple/build_and_assemble_ios_pods.py b/tools/ci_build/github/apple/build_and_assemble_ios_pods.py index 7a3a0ee597687..e316aa5d2c324 100755 --- a/tools/ci_build/github/apple/build_and_assemble_ios_pods.py +++ b/tools/ci_build/github/apple/build_and_assemble_ios_pods.py @@ -10,55 +10,71 @@ import sys import tempfile - from c.assemble_c_pod_package import assemble_c_pod_package from objectivec.assemble_objc_pod_package import assemble_objc_pod_package -from package_assembly_utils import get_ort_version, PackageVariant - +from package_assembly_utils import PackageVariant, get_ort_version SCRIPT_PATH = pathlib.Path(__file__).resolve() SCRIPT_DIR = SCRIPT_PATH.parent REPO_DIR = SCRIPT_PATH.parents[4] -logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) +logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) log = logging.getLogger(SCRIPT_PATH.stem) def parse_args(): parser = argparse.ArgumentParser( description="Builds an iOS framework and uses it to assemble iOS pod package files.", - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument("--build-dir", type=pathlib.Path, default=REPO_DIR / "build" / "ios_framework", - help="The build directory. This will contain the iOS framework build output.") - parser.add_argument("--staging-dir", type=pathlib.Path, default=REPO_DIR / "build" / "ios_pod_staging", - help="The staging directory. This will contain the iOS pod package files. " - "The pod package files do not have dependencies on files in the build directory.") - - parser.add_argument("--pod-version", default=f"{get_ort_version()}-local", - help="The version string of the pod. The same version is used for all pods.") - - parser.add_argument("--variant", choices=PackageVariant.release_variant_names(), - default=PackageVariant.Mobile.name, - help="Pod package variant.") - - parser.add_argument("--test", action="store_true", - help="Run tests on the framework and pod package files.") + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--build-dir", + type=pathlib.Path, + default=REPO_DIR / "build" / "ios_framework", + help="The build directory. This will contain the iOS framework build output.", + ) + parser.add_argument( + "--staging-dir", + type=pathlib.Path, + default=REPO_DIR / "build" / "ios_pod_staging", + help="The staging directory. This will contain the iOS pod package files. " + "The pod package files do not have dependencies on files in the build directory.", + ) + + parser.add_argument( + "--pod-version", + default=f"{get_ort_version()}-local", + help="The version string of the pod. The same version is used for all pods.", + ) + + parser.add_argument( + "--variant", + choices=PackageVariant.release_variant_names(), + default=PackageVariant.Mobile.name, + help="Pod package variant.", + ) + + parser.add_argument("--test", action="store_true", help="Run tests on the framework and pod package files.") build_framework_group = parser.add_argument_group( title="iOS framework build arguments", - description="See the corresponding arguments in build_ios_framework.py for details.") + description="See the corresponding arguments in build_ios_framework.py for details.", + ) build_framework_group.add_argument("--include-ops-by-config") - build_framework_group.add_argument("--build-settings-file", required=True, - help="The positional argument of build_ios_framework.py.") - build_framework_group.add_argument("-b", "--build-ios-framework-arg", action="append", - dest="build_ios_framework_extra_args", default=[], - help="Pass an argument through to build_ios_framework.py. " - "This may be specified multiple times.") + build_framework_group.add_argument( + "--build-settings-file", required=True, help="The positional argument of build_ios_framework.py." + ) + build_framework_group.add_argument( + "-b", + "--build-ios-framework-arg", + action="append", + dest="build_ios_framework_extra_args", + default=[], + help="Pass an argument through to build_ios_framework.py. " "This may be specified multiple times.", + ) args = parser.parse_args() @@ -70,8 +86,11 @@ def run(arg_list, cwd=None): import shlex import subprocess - log.info("Running subprocess in '{0}'\n {1}".format( - cwd or os.getcwd(), " ".join([shlex.quote(arg) for arg in arg_list]))) + log.info( + "Running subprocess in '{0}'\n {1}".format( + cwd or os.getcwd(), " ".join([shlex.quote(arg) for arg in arg_list]) + ) + ) return subprocess.run(arg_list, check=True, cwd=cwd) @@ -88,23 +107,30 @@ def main(): log.info("Building iOS framework.") - build_ios_framework_args = \ - [sys.executable, str(SCRIPT_DIR / "build_ios_framework.py")] + args.build_ios_framework_extra_args + build_ios_framework_args = [ + sys.executable, + str(SCRIPT_DIR / "build_ios_framework.py"), + ] + args.build_ios_framework_extra_args if args.include_ops_by_config is not None: build_ios_framework_args += ["--include_ops_by_config", args.include_ops_by_config] - build_ios_framework_args += ["--build_dir", str(build_dir), - args.build_settings_file] + build_ios_framework_args += ["--build_dir", str(build_dir), args.build_settings_file] run(build_ios_framework_args) if args.test: - test_ios_packages_args = [sys.executable, str(SCRIPT_DIR / "test_ios_packages.py"), - "--fail_if_cocoapods_missing", - "--framework_info_file", str(framework_info_file), - "--c_framework_dir", str(build_dir / "framework_out"), - "--variant", package_variant.name] + test_ios_packages_args = [ + sys.executable, + str(SCRIPT_DIR / "test_ios_packages.py"), + "--fail_if_cocoapods_missing", + "--framework_info_file", + str(framework_info_file), + "--c_framework_dir", + str(build_dir / "framework_out"), + "--variant", + package_variant.name, + ] run(test_ios_packages_args) @@ -122,7 +148,8 @@ def main(): framework_info_file=framework_info_file, framework_dir=build_dir / "framework_out" / "onnxruntime.xcframework", public_headers_dir=build_dir / "framework_out" / "Headers", - package_variant=package_variant) + package_variant=package_variant, + ) if args.test: test_c_pod_args = ["pod", "lib", "lint", "--verbose"] @@ -136,7 +163,8 @@ def main(): staging_dir=objc_pod_staging_dir, pod_version=args.pod_version, framework_info_file=framework_info_file, - package_variant=package_variant) + package_variant=package_variant, + ) if args.test: test_objc_pod_args = ["pod", "lib", "lint", "--verbose", f"--include-podspecs={c_pod_podspec}"] diff --git a/tools/ci_build/github/apple/build_ios_framework.py b/tools/ci_build/github/apple/build_ios_framework.py index 4447f2f3b8a16..b5d84ff1c8a28 100644 --- a/tools/ci_build/github/apple/build_ios_framework.py +++ b/tools/ci_build/github/apple/build_ios_framework.py @@ -17,8 +17,8 @@ # We by default will build below 3 archs DEFAULT_BUILD_OSX_ARCHS = { - 'iphoneos': ['arm64'], - 'iphonesimulator': ['arm64', 'x86_64'], + "iphoneos": ["arm64"], + "iphonesimulator": ["arm64", "x86_64"], } @@ -31,30 +31,31 @@ def _parse_build_settings(args): build_settings["build_osx_archs"] = build_settings_data.get("build_osx_archs", DEFAULT_BUILD_OSX_ARCHS) build_params = [] - if 'build_params' in build_settings_data: - build_params += build_settings_data['build_params'] + if "build_params" in build_settings_data: + build_params += build_settings_data["build_params"] else: - raise ValueError('build_params is required in the build config file') + raise ValueError("build_params is required in the build config file") - build_settings['build_params'] = build_params + build_settings["build_params"] = build_params return build_settings # Build fat framework for all archs of a single sysroot # For example, arm64 and x86_64 for iphonesimulator -def _build_for_ios_sysroot(build_config, intermediates_dir, base_build_command, - sysroot, archs, build_dynamic_framework): +def _build_for_ios_sysroot( + build_config, intermediates_dir, base_build_command, sysroot, archs, build_dynamic_framework +): # paths of the onnxruntime libraries for different archs ort_libs = [] - info_plist_path = '' + info_plist_path = "" # Build binary for each arch, one by one for current_arch in archs: build_dir_current_arch = os.path.join(intermediates_dir, sysroot + "_" + current_arch) build_command = base_build_command + [ - '--ios_sysroot=' + sysroot, - '--osx_arch=' + current_arch, - '--build_dir=' + build_dir_current_arch + "--ios_sysroot=" + sysroot, + "--osx_arch=" + current_arch, + "--build_dir=" + build_dir_current_arch, ] # the actual build process for current arch @@ -62,19 +63,23 @@ def _build_for_ios_sysroot(build_config, intermediates_dir, base_build_command, # get the compiled lib path framework_dir = os.path.join( - build_dir_current_arch, build_config, build_config + "-" + sysroot, - 'onnxruntime.framework' if build_dynamic_framework - else os.path.join('static_framework', 'onnxruntime.framework')) - ort_libs.append(os.path.join(framework_dir, 'onnxruntime')) + build_dir_current_arch, + build_config, + build_config + "-" + sysroot, + "onnxruntime.framework" + if build_dynamic_framework + else os.path.join("static_framework", "onnxruntime.framework"), + ) + ort_libs.append(os.path.join(framework_dir, "onnxruntime")) # We only need to copy Info.plist, framework_info.json, and headers once since they are the same if not info_plist_path: - info_plist_path = os.path.join(build_dir_current_arch, build_config, 'Info.plist') - framework_info_path = os.path.join(build_dir_current_arch, build_config, 'framework_info.json') - headers = glob.glob(os.path.join(framework_dir, 'Headers', '*.h')) + info_plist_path = os.path.join(build_dir_current_arch, build_config, "Info.plist") + framework_info_path = os.path.join(build_dir_current_arch, build_config, "framework_info.json") + headers = glob.glob(os.path.join(framework_dir, "Headers", "*.h")) # manually create the fat framework - framework_dir = os.path.join(intermediates_dir, 'frameworks', sysroot, 'onnxruntime.framework') + framework_dir = os.path.join(intermediates_dir, "frameworks", sysroot, "onnxruntime.framework") # remove the existing framework if any if os.path.exists(framework_dir): shutil.rmtree(framework_dir) @@ -83,15 +88,15 @@ def _build_for_ios_sysroot(build_config, intermediates_dir, base_build_command, # copy the Info.plist, framework_info.json, and header files shutil.copy(info_plist_path, framework_dir) shutil.copy(framework_info_path, os.path.dirname(framework_dir)) - header_dir = os.path.join(framework_dir, 'Headers') + header_dir = os.path.join(framework_dir, "Headers") pathlib.Path(header_dir).mkdir(parents=True, exist_ok=True) for _header in headers: shutil.copy(_header, header_dir) # use lipo to create a fat ort library - lipo_command = ['lipo', '-create'] + lipo_command = ["lipo", "-create"] lipo_command += ort_libs - lipo_command += ['-output', os.path.join(framework_dir, 'onnxruntime')] + lipo_command += ["-output", os.path.join(framework_dir, "onnxruntime")] subprocess.run(lipo_command, shell=False, check=True) return framework_dir @@ -102,47 +107,51 @@ def _build_package(args): build_dir = os.path.abspath(args.build_dir) # Temp dirs to hold building results - intermediates_dir = os.path.join(build_dir, 'intermediates') + intermediates_dir = os.path.join(build_dir, "intermediates") build_config = args.config - base_build_command = [sys.executable, BUILD_PY] + build_settings['build_params'] + ['--config=' + build_config] + base_build_command = [sys.executable, BUILD_PY] + build_settings["build_params"] + ["--config=" + build_config] if args.include_ops_by_config is not None: - base_build_command += ['--include_ops_by_config=' + str(args.include_ops_by_config.resolve())] + base_build_command += ["--include_ops_by_config=" + str(args.include_ops_by_config.resolve())] if args.path_to_protoc_exe is not None: - base_build_command += ['--path_to_protoc_exe=' + str(args.path_to_protoc_exe.resolve())] + base_build_command += ["--path_to_protoc_exe=" + str(args.path_to_protoc_exe.resolve())] # build framework for individual sysroot framework_dirs = [] - framework_info_path = '' - public_headers_path = '' - for sysroot in build_settings['build_osx_archs']: + framework_info_path = "" + public_headers_path = "" + for sysroot in build_settings["build_osx_archs"]: framework_dir = _build_for_ios_sysroot( - build_config, intermediates_dir, base_build_command, sysroot, - build_settings['build_osx_archs'][sysroot], args.build_dynamic_framework) + build_config, + intermediates_dir, + base_build_command, + sysroot, + build_settings["build_osx_archs"][sysroot], + args.build_dynamic_framework, + ) framework_dirs.append(framework_dir) # podspec and headers for each sysroot are the same, pick one of them if not framework_info_path: - framework_info_path = os.path.join(os.path.dirname(framework_dir), 'framework_info.json') - public_headers_path = os.path.join(os.path.dirname(framework_dir), 'onnxruntime.framework', 'Headers') + framework_info_path = os.path.join(os.path.dirname(framework_dir), "framework_info.json") + public_headers_path = os.path.join(os.path.dirname(framework_dir), "onnxruntime.framework", "Headers") # create the folder for xcframework and copy the LICENSE and podspec file - xcframework_dir = os.path.join(build_dir, 'framework_out') + xcframework_dir = os.path.join(build_dir, "framework_out") pathlib.Path(xcframework_dir).mkdir(parents=True, exist_ok=True) - shutil.copy(os.path.join(REPO_DIR, 'LICENSE'), xcframework_dir) - shutil.copytree(public_headers_path, os.path.join(xcframework_dir, 'Headers'), dirs_exist_ok=True) + shutil.copy(os.path.join(REPO_DIR, "LICENSE"), xcframework_dir) + shutil.copytree(public_headers_path, os.path.join(xcframework_dir, "Headers"), dirs_exist_ok=True) shutil.copy(framework_info_path, build_dir) # remove existing xcframework if any - xcframework_path = os.path.join(xcframework_dir, 'onnxruntime.xcframework') + xcframework_path = os.path.join(xcframework_dir, "onnxruntime.xcframework") if os.path.exists(xcframework_path): shutil.rmtree(xcframework_path) # Assemble the final xcframework - build_xcframework_cmd = ['xcrun', 'xcodebuild', '-create-xcframework', - '-output', xcframework_path] + build_xcframework_cmd = ["xcrun", "xcodebuild", "-create-xcframework", "-output", xcframework_path] for framework_dir in framework_dirs: - build_xcframework_cmd.extend(['-framework', framework_dir]) + build_xcframework_cmd.extend(["-framework", framework_dir]) subprocess.run(build_xcframework_cmd, shell=False, check=True, cwd=REPO_DIR) @@ -150,42 +159,56 @@ def _build_package(args): def parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), - description='''Create iOS framework and podspec for one or more osx_archs (xcframework) + description="""Create iOS framework and podspec for one or more osx_archs (xcframework) and building properties specified in the given build config file, see tools/ci_build/github/apple/default_mobile_ios_framework_build_settings.json for details. The output of the final xcframework and podspec can be found under [build_dir]/framework_out. Please note, this building script will only work on macOS. - ''' + """, ) - parser.add_argument('--build_dir', type=pathlib.Path, default=os.path.join(REPO_DIR, 'build/iOS_framework'), - help='Provide the root directory for build output') + parser.add_argument( + "--build_dir", + type=pathlib.Path, + default=os.path.join(REPO_DIR, "build/iOS_framework"), + help="Provide the root directory for build output", + ) parser.add_argument( - "--include_ops_by_config", type=pathlib.Path, - help="Include ops from config file. See /docs/Reduced_Operator_Kernel_build.md for more information.") + "--include_ops_by_config", + type=pathlib.Path, + help="Include ops from config file. See /docs/Reduced_Operator_Kernel_build.md for more information.", + ) - parser.add_argument("--config", type=str, default="Release", - choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], - help="Configuration to build.") + parser.add_argument( + "--config", + type=str, + default="Release", + choices=["Debug", "MinSizeRel", "Release", "RelWithDebInfo"], + help="Configuration to build.", + ) - parser.add_argument("--build_dynamic_framework", action='store_true', - help="Build Dynamic Framework (default is build static framework).") + parser.add_argument( + "--build_dynamic_framework", + action="store_true", + help="Build Dynamic Framework (default is build static framework).", + ) - parser.add_argument('build_settings_file', type=pathlib.Path, - help='Provide the file contains settings for building iOS framework') + parser.add_argument( + "build_settings_file", type=pathlib.Path, help="Provide the file contains settings for building iOS framework" + ) parser.add_argument("--path_to_protoc_exe", type=pathlib.Path, help="Path to protoc exe.") args = parser.parse_args() if not args.build_settings_file.resolve().is_file(): - raise FileNotFoundError('Build config file {} is not a file.'.format(args.build_settings_file.resolve())) + raise FileNotFoundError("Build config file {} is not a file.".format(args.build_settings_file.resolve())) if args.include_ops_by_config is not None: include_ops_by_config_file = args.include_ops_by_config.resolve() if not include_ops_by_config_file.is_file(): - raise FileNotFoundError('Include ops config file {} is not a file.'.format(include_ops_by_config_file)) + raise FileNotFoundError("Include ops config file {} is not a file.".format(include_ops_by_config_file)) return args @@ -195,5 +218,5 @@ def main(): _build_package(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/ci_build/github/apple/c/assemble_c_pod_package.py b/tools/ci_build/github/apple/c/assemble_c_pod_package.py index b1187c139f72d..dd80123c25591 100644 --- a/tools/ci_build/github/apple/c/assemble_c_pod_package.py +++ b/tools/ci_build/github/apple/c/assemble_c_pod_package.py @@ -8,20 +8,22 @@ import shutil import sys - _script_dir = pathlib.Path(__file__).parent.resolve(strict=True) sys.path.append(str(_script_dir.parent)) from package_assembly_utils import ( # noqa: E402 - copy_repo_relative_to_dir, gen_file_from_template, load_json_config, - PackageVariant) + PackageVariant, + copy_repo_relative_to_dir, + gen_file_from_template, + load_json_config, +) def get_pod_config_file(package_variant: PackageVariant): - ''' + """ Gets the pod configuration file path for the given package variant. - ''' + """ if package_variant == PackageVariant.Full: return _script_dir / "onnxruntime-c.config.json" elif package_variant == PackageVariant.Mobile: @@ -32,11 +34,15 @@ def get_pod_config_file(package_variant: PackageVariant): raise ValueError(f"Unhandled package variant: {package_variant}") -def assemble_c_pod_package(staging_dir: pathlib.Path, pod_version: str, - framework_info_file: pathlib.Path, - public_headers_dir: pathlib.Path, framework_dir: pathlib.Path, - package_variant: PackageVariant): - ''' +def assemble_c_pod_package( + staging_dir: pathlib.Path, + pod_version: str, + framework_info_file: pathlib.Path, + public_headers_dir: pathlib.Path, + framework_dir: pathlib.Path, + package_variant: PackageVariant, +): + """ Assembles the files for the C/C++ pod package in a staging directory. :param staging_dir Path to the staging directory for the C/C++ pod files. @@ -46,7 +52,7 @@ def assemble_c_pod_package(staging_dir: pathlib.Path, pod_version: str, :param framework_dir Path to the onnxruntime framework directory to include in the pod. :param package_variant The pod package variant. :return Tuple of (package name, path to the podspec file). - ''' + """ staging_dir = staging_dir.resolve() framework_info_file = framework_info_file.resolve(strict=True) public_headers_dir = public_headers_dir.resolve(strict=True) @@ -88,25 +94,42 @@ def assemble_c_pod_package(staging_dir: pathlib.Path, pod_version: str, def parse_args(): - parser = argparse.ArgumentParser(description=""" + parser = argparse.ArgumentParser( + description=""" Assembles the files for the C/C++ pod package in a staging directory. This directory can be validated (e.g., with `pod lib lint`) and then zipped to create a package for release. - """) - - parser.add_argument("--staging-dir", type=pathlib.Path, - default=pathlib.Path("./c-staging"), - help="Path to the staging directory for the C/C++ pod files.") - parser.add_argument("--pod-version", required=True, - help="C/C++ pod version.") - parser.add_argument("--framework-info-file", type=pathlib.Path, required=True, - help="Path to the framework_info.json file containing additional values for the podspec. " - "This file should be generated by CMake in the build directory.") - parser.add_argument("--public-headers-dir", type=pathlib.Path, required=True, - help="Path to the public headers directory to include in the pod.") - parser.add_argument("--framework-dir", type=pathlib.Path, required=True, - help="Path to the onnxruntime framework directory to include in the pod.") - parser.add_argument("--variant", choices=PackageVariant.all_variant_names(), required=True, - help="Pod package variant.") + """ + ) + + parser.add_argument( + "--staging-dir", + type=pathlib.Path, + default=pathlib.Path("./c-staging"), + help="Path to the staging directory for the C/C++ pod files.", + ) + parser.add_argument("--pod-version", required=True, help="C/C++ pod version.") + parser.add_argument( + "--framework-info-file", + type=pathlib.Path, + required=True, + help="Path to the framework_info.json file containing additional values for the podspec. " + "This file should be generated by CMake in the build directory.", + ) + parser.add_argument( + "--public-headers-dir", + type=pathlib.Path, + required=True, + help="Path to the public headers directory to include in the pod.", + ) + parser.add_argument( + "--framework-dir", + type=pathlib.Path, + required=True, + help="Path to the onnxruntime framework directory to include in the pod.", + ) + parser.add_argument( + "--variant", choices=PackageVariant.all_variant_names(), required=True, help="Pod package variant." + ) return parser.parse_args() @@ -114,12 +137,14 @@ def parse_args(): def main(): args = parse_args() - assemble_c_pod_package(staging_dir=args.staging_dir, - pod_version=args.pod_version, - framework_info_file=args.framework_info_file, - public_headers_dir=args.public_headers_dir, - framework_dir=args.framework_dir, - package_variant=PackageVariant[args.variant]) + assemble_c_pod_package( + staging_dir=args.staging_dir, + pod_version=args.pod_version, + framework_info_file=args.framework_info_file, + public_headers_dir=args.public_headers_dir, + framework_dir=args.framework_dir, + package_variant=PackageVariant[args.variant], + ) return 0 diff --git a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py index 5888c409250f0..2cfa44a135d3e 100755 --- a/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py +++ b/tools/ci_build/github/apple/objectivec/assemble_objc_pod_package.py @@ -7,15 +7,17 @@ import pathlib import sys - _script_dir = pathlib.Path(__file__).parent.resolve(strict=True) sys.path.append(str(_script_dir.parent)) from c.assemble_c_pod_package import get_pod_config_file as get_c_pod_config_file # noqa: E402 from package_assembly_utils import ( # noqa: E402 - copy_repo_relative_to_dir, gen_file_from_template, load_json_config, PackageVariant) - + PackageVariant, + copy_repo_relative_to_dir, + gen_file_from_template, + load_json_config, +) # these variables contain paths or path patterns that are relative to the repo root @@ -57,9 +59,9 @@ def get_pod_config_file(package_variant: PackageVariant): - ''' + """ Gets the pod configuration file path for the given package variant. - ''' + """ if package_variant == PackageVariant.Full: return _script_dir / "onnxruntime-objc.config.json" elif package_variant == PackageVariant.Mobile: @@ -68,10 +70,10 @@ def get_pod_config_file(package_variant: PackageVariant): raise ValueError(f"Unhandled package variant: {package_variant}") -def assemble_objc_pod_package(staging_dir: pathlib.Path, pod_version: str, - framework_info_file: pathlib.Path, - package_variant: PackageVariant): - ''' +def assemble_objc_pod_package( + staging_dir: pathlib.Path, pod_version: str, framework_info_file: pathlib.Path, package_variant: PackageVariant +): + """ Assembles the files for the Objective-C pod package in a staging directory. :param staging_dir Path to the staging directory for the Objective-C pod files. @@ -79,7 +81,7 @@ def assemble_objc_pod_package(staging_dir: pathlib.Path, pod_version: str, :param framework_info_file Path to the framework_info.json file containing additional values for the podspec. :param package_variant The pod package variant. :return Tuple of (package name, path to the podspec file). - ''' + """ staging_dir = staging_dir.resolve() framework_info_file = framework_info_file.resolve(strict=True) @@ -94,9 +96,7 @@ def assemble_objc_pod_package(staging_dir: pathlib.Path, pod_version: str, print("Warning: staging directory already exists", file=sys.stderr) # copy the necessary files to the staging directory - copy_repo_relative_to_dir( - [license_file] + source_files + test_source_files + test_resource_files, - staging_dir) + copy_repo_relative_to_dir([license_file] + source_files + test_source_files + test_resource_files, staging_dir) # generate the podspec file from the template @@ -127,21 +127,30 @@ def path_patterns_as_variable_value(patterns: list[str]): def parse_args(): - parser = argparse.ArgumentParser(description=""" + parser = argparse.ArgumentParser( + description=""" Assembles the files for the Objective-C pod package in a staging directory. This directory can be validated (e.g., with `pod lib lint`) and then zipped to create a package for release. - """) - - parser.add_argument("--staging-dir", type=pathlib.Path, - default=pathlib.Path("./onnxruntime-mobile-objc-staging"), - help="Path to the staging directory for the Objective-C pod files.") - parser.add_argument("--pod-version", required=True, - help="Objective-C pod version.") - parser.add_argument("--framework-info-file", type=pathlib.Path, required=True, - help="Path to the framework_info.json file containing additional values for the podspec. " - "This file should be generated by CMake in the build directory.") - parser.add_argument("--variant", choices=PackageVariant.release_variant_names(), required=True, - help="Pod package variant.") + """ + ) + + parser.add_argument( + "--staging-dir", + type=pathlib.Path, + default=pathlib.Path("./onnxruntime-mobile-objc-staging"), + help="Path to the staging directory for the Objective-C pod files.", + ) + parser.add_argument("--pod-version", required=True, help="Objective-C pod version.") + parser.add_argument( + "--framework-info-file", + type=pathlib.Path, + required=True, + help="Path to the framework_info.json file containing additional values for the podspec. " + "This file should be generated by CMake in the build directory.", + ) + parser.add_argument( + "--variant", choices=PackageVariant.release_variant_names(), required=True, help="Pod package variant." + ) return parser.parse_args() @@ -149,10 +158,12 @@ def parse_args(): def main(): args = parse_args() - assemble_objc_pod_package(staging_dir=args.staging_dir, - pod_version=args.pod_version, - framework_info_file=args.framework_info_file, - package_variant=PackageVariant[args.variant]) + assemble_objc_pod_package( + staging_dir=args.staging_dir, + pod_version=args.pod_version, + framework_info_file=args.framework_info_file, + package_variant=PackageVariant[args.variant], + ) return 0 diff --git a/tools/ci_build/github/apple/package_assembly_utils.py b/tools/ci_build/github/apple/package_assembly_utils.py index a2a31ae112272..18a5b39ae2333 100644 --- a/tools/ci_build/github/apple/package_assembly_utils.py +++ b/tools/ci_build/github/apple/package_assembly_utils.py @@ -7,7 +7,6 @@ import pathlib import re import shutil - from typing import Dict, List _script_dir = pathlib.Path(__file__).parent.resolve(strict=True) @@ -31,10 +30,10 @@ def all_variant_names(cls): _template_variable_pattern = re.compile(r"@(\w+)@") # match "@var@" -def gen_file_from_template(template_file: pathlib.Path, output_file: pathlib.Path, - variable_substitutions: Dict[str, str], - strict: bool = True): - ''' +def gen_file_from_template( + template_file: pathlib.Path, output_file: pathlib.Path, variable_substitutions: Dict[str, str], strict: bool = True +): + """ Generates a file from a template file. The template file may contain template variables that will be substituted with the provided values in the generated output file. @@ -46,7 +45,7 @@ def gen_file_from_template(template_file: pathlib.Path, output_file: pathlib.Pat :param variable_substitutions The mapping from template variable name to value. :param strict Whether to require the set of template variable names in the file and the keys of `variable_substitutions` to be equal. - ''' + """ with open(template_file, mode="r") as template: content = template.read() @@ -61,23 +60,25 @@ def replace_template_variable(match): if strict and variables_in_file != variable_substitutions.keys(): variables_in_substitutions = set(variable_substitutions.keys()) - raise ValueError(f"Template file variables and substitution variables do not match. " - f"Only in template file: {sorted(variables_in_file - variables_in_substitutions)}. " - f"Only in substitutions: {sorted(variables_in_substitutions - variables_in_file)}.") + raise ValueError( + f"Template file variables and substitution variables do not match. " + f"Only in template file: {sorted(variables_in_file - variables_in_substitutions)}. " + f"Only in substitutions: {sorted(variables_in_substitutions - variables_in_file)}." + ) with open(output_file, mode="w") as output: output.write(content) def copy_repo_relative_to_dir(patterns: List[str], dest_dir: pathlib.Path): - ''' + """ Copies file paths relative to the repo root to a directory. The given paths or path patterns are relative to the repo root, and the repo root-relative intermediate directory structure is maintained. :param patterns The paths or path patterns relative to the repo root. :param dest_dir The destination directory. - ''' + """ paths = [path for pattern in patterns for path in repo_root.glob(pattern)] for path in paths: repo_relative_path = path.relative_to(repo_root) @@ -87,21 +88,21 @@ def copy_repo_relative_to_dir(patterns: List[str], dest_dir: pathlib.Path): def load_json_config(json_config_file: pathlib.Path): - ''' + """ Loads configuration info from a JSON file. :param json_config_file The JSON configuration file path. :return The configuration info values. - ''' + """ with open(json_config_file, mode="r") as config: return json.load(config) def get_ort_version(): - ''' + """ Gets the ONNX Runtime version string from the repo. :return The ONNX Runtime version string. - ''' + """ with open(repo_root / "VERSION_NUMBER", mode="r") as version_file: return version_file.read().strip() diff --git a/tools/ci_build/github/apple/test_ios_packages.py b/tools/ci_build/github/apple/test_ios_packages.py index 988dfce77e372..5bfeb8b42bd29 100644 --- a/tools/ci_build/github/apple/test_ios_packages.py +++ b/tools/ci_build/github/apple/test_ios_packages.py @@ -10,10 +10,8 @@ import subprocess import tempfile - from c.assemble_c_pod_package import assemble_c_pod_package -from package_assembly_utils import gen_file_from_template, get_ort_version, PackageVariant - +from package_assembly_utils import PackageVariant, gen_file_from_template, get_ort_version SCRIPT_PATH = pathlib.Path(__file__).resolve(strict=True) REPO_DIR = SCRIPT_PATH.parents[4] @@ -21,30 +19,29 @@ def _test_ios_packages(args): # check if CocoaPods is installed - if shutil.which('pod') is None: + if shutil.which("pod") is None: if args.fail_if_cocoapods_missing: - raise ValueError('CocoaPods is required for this test') + raise ValueError("CocoaPods is required for this test") else: - print('CocoaPods is not installed, ignore this test') + print("CocoaPods is not installed, ignore this test") return # Now we need to create a zip file contains the framework and the podspec file, both of these 2 files # should be under the c_framework_dir c_framework_dir = args.c_framework_dir.resolve() if not c_framework_dir.is_dir(): - raise FileNotFoundError('c_framework_dir {} is not a folder.'.format(c_framework_dir)) + raise FileNotFoundError("c_framework_dir {} is not a folder.".format(c_framework_dir)) - has_framework = (c_framework_dir / 'onnxruntime.framework').exists() - has_xcframework = (c_framework_dir / 'onnxruntime.xcframework').exists() + has_framework = (c_framework_dir / "onnxruntime.framework").exists() + has_xcframework = (c_framework_dir / "onnxruntime.xcframework").exists() if not has_framework and not has_xcframework: - raise FileNotFoundError('{} does not have onnxruntime.framework/xcframework'.format(c_framework_dir)) + raise FileNotFoundError("{} does not have onnxruntime.framework/xcframework".format(c_framework_dir)) if has_framework and has_xcframework: - raise ValueError('Cannot proceed when both onnxruntime.framework ' - 'and onnxruntime.xcframework exist') + raise ValueError("Cannot proceed when both onnxruntime.framework " "and onnxruntime.xcframework exist") - framework_name = 'onnxruntime.framework' if has_framework else 'onnxruntime.xcframework' + framework_name = "onnxruntime.framework" if has_framework else "onnxruntime.xcframework" # create a temp folder @@ -59,14 +56,14 @@ def _test_ios_packages(args): os.makedirs(stage_dir) # assemble the test project here - target_proj_path = stage_dir / 'ios_package_test' + target_proj_path = stage_dir / "ios_package_test" # copy the test project source files to target_proj_path - test_proj_path = pathlib.Path(REPO_DIR, 'onnxruntime/test/platform/ios/ios_package_test') + test_proj_path = pathlib.Path(REPO_DIR, "onnxruntime/test/platform/ios/ios_package_test") shutil.copytree(test_proj_path, target_proj_path) # assemble local pod files here - local_pods_dir = stage_dir / 'local_pods' + local_pods_dir = stage_dir / "local_pods" # We will only publish xcframework, however, assembly of the xcframework is a post process # and it cannot be done by CMake for now. See, https://gitlab.kitware.com/cmake/cmake/-/issues/21752 @@ -76,75 +73,107 @@ def _test_ios_packages(args): framework_dir = args.c_framework_dir / framework_name public_headers_dir = framework_dir / "Headers" if has_framework else args.c_framework_dir / "Headers" - pod_name, podspec = assemble_c_pod_package(staging_dir=local_pods_dir, - pod_version=get_ort_version(), - framework_info_file=args.framework_info_file, - public_headers_dir=public_headers_dir, - framework_dir=framework_dir, - package_variant=PackageVariant[args.variant]) + pod_name, podspec = assemble_c_pod_package( + staging_dir=local_pods_dir, + pod_version=get_ort_version(), + framework_info_file=args.framework_info_file, + public_headers_dir=public_headers_dir, + framework_dir=framework_dir, + package_variant=PackageVariant[args.variant], + ) # move podspec out to target_proj_path first podspec = shutil.move(podspec, target_proj_path / podspec.name) # create a zip file contains the framework - zip_file_path = local_pods_dir / f'{pod_name}.zip' + zip_file_path = local_pods_dir / f"{pod_name}.zip" # shutil.make_archive require target file as full path without extension - shutil.make_archive(zip_file_path.with_suffix(''), 'zip', root_dir=local_pods_dir) + shutil.make_archive(zip_file_path.with_suffix(""), "zip", root_dir=local_pods_dir) # update the podspec to point to the local framework zip file - with open(podspec, 'r') as file: + with open(podspec, "r") as file: file_data = file.read() - file_data = file_data.replace('file:///http_source_placeholder', f'file:///{zip_file_path}') + file_data = file_data.replace("file:///http_source_placeholder", f"file:///{zip_file_path}") - with open(podspec, 'w') as file: + with open(podspec, "w") as file: file.write(file_data) # generate Podfile to point to pod - gen_file_from_template(target_proj_path / "Podfile.template", target_proj_path / "Podfile", - {"C_POD_NAME": pod_name, - "C_POD_PODSPEC": f"./{podspec.name}"}) + gen_file_from_template( + target_proj_path / "Podfile.template", + target_proj_path / "Podfile", + {"C_POD_NAME": pod_name, "C_POD_PODSPEC": f"./{podspec.name}"}, + ) # clean the Cocoapods cache first, in case the same pod was cached in previous runs - subprocess.run(['pod', 'cache', 'clean', '--all'], shell=False, check=True, cwd=target_proj_path) + subprocess.run(["pod", "cache", "clean", "--all"], shell=False, check=True, cwd=target_proj_path) # install pods - subprocess.run(['pod', 'install'], shell=False, check=True, cwd=target_proj_path) + subprocess.run(["pod", "install"], shell=False, check=True, cwd=target_proj_path) # run the tests if not args.prepare_test_project_only: - subprocess.run(['xcrun', 'xcodebuild', 'test', - '-workspace', './ios_package_test.xcworkspace', - '-scheme', 'ios_package_test', - '-destination', 'platform=iOS Simulator,OS=latest,name=iPhone SE (2nd generation)'], - shell=False, check=True, cwd=target_proj_path) + subprocess.run( + [ + "xcrun", + "xcodebuild", + "test", + "-workspace", + "./ios_package_test.xcworkspace", + "-scheme", + "ios_package_test", + "-destination", + "platform=iOS Simulator,OS=latest,name=iPhone SE (2nd generation)", + ], + shell=False, + check=True, + cwd=target_proj_path, + ) def parse_args(): parser = argparse.ArgumentParser( - os.path.basename(__file__), - description='Test iOS framework using CocoaPods package.' + os.path.basename(__file__), description="Test iOS framework using CocoaPods package." ) - parser.add_argument('--fail_if_cocoapods_missing', action='store_true', - help='This script will fail if CocoaPods is not installed, ' - 'will not throw error unless fail_if_cocoapod_missing is set.') + parser.add_argument( + "--fail_if_cocoapods_missing", + action="store_true", + help="This script will fail if CocoaPods is not installed, " + "will not throw error unless fail_if_cocoapod_missing is set.", + ) - parser.add_argument("--framework_info_file", type=pathlib.Path, required=True, - help="Path to the framework_info.json file containing additional values for the podspec. " - "This file should be generated by CMake in the build directory.") + parser.add_argument( + "--framework_info_file", + type=pathlib.Path, + required=True, + help="Path to the framework_info.json file containing additional values for the podspec. " + "This file should be generated by CMake in the build directory.", + ) - parser.add_argument('--c_framework_dir', type=pathlib.Path, required=True, - help='Provide the parent directory for C/C++ framework') + parser.add_argument( + "--c_framework_dir", type=pathlib.Path, required=True, help="Provide the parent directory for C/C++ framework" + ) - parser.add_argument("--variant", choices=PackageVariant.all_variant_names(), default=PackageVariant.Test.name, - help="Pod package variant.") + parser.add_argument( + "--variant", + choices=PackageVariant.all_variant_names(), + default=PackageVariant.Test.name, + help="Pod package variant.", + ) - parser.add_argument('--test_project_stage_dir', type=pathlib.Path, - help='The stage dir for the test project, if not specified, will use a temporary path') + parser.add_argument( + "--test_project_stage_dir", + type=pathlib.Path, + help="The stage dir for the test project, if not specified, will use a temporary path", + ) - parser.add_argument('--prepare_test_project_only', action='store_true', - help='Prepare the test project only, without running the tests') + parser.add_argument( + "--prepare_test_project_only", + action="store_true", + help="Prepare the test project only, without running the tests", + ) return parser.parse_args() @@ -154,5 +183,5 @@ def main(): _test_ios_packages(args) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/ci_build/github/linux/docker/build_scripts/python-tag-abi-tag.py b/tools/ci_build/github/linux/docker/build_scripts/python-tag-abi-tag.py index 942394bbb0e84..f405f033441de 100644 --- a/tools/ci_build/github/linux/docker/build_scripts/python-tag-abi-tag.py +++ b/tools/ci_build/github/linux/docker/build_scripts/python-tag-abi-tag.py @@ -4,7 +4,6 @@ from wheel.vendored.packaging.tags import sys_tags - # first tag is always the more specific tag tag = next(sys_tags()) print("{0}-{1}".format(tag.interpreter, tag.abi)) diff --git a/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py b/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py index e430eb51c898c..6d7c70ae953fb 100644 --- a/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py +++ b/tools/ci_build/github/linux/ort_minimal/check_build_binary_size.py @@ -12,54 +12,61 @@ def _check_binary_size(path, readelf, threshold, os_str, arch, build_config): - print('Checking binary size of {} using {}'.format(path, readelf)) + print("Checking binary size of {} using {}".format(path, readelf)) ondisk_size = os.path.getsize(path) - print('Section:size in bytes') + print("Section:size in bytes") # call get_section_sizes to dump the section info to stdout sections = readelf_utils.get_section_sizes(path, readelf, sys.stdout) sections_total = sum(sections.values()) - print('Sections total={} bytes'.format(sections_total)) - print('File size={} bytes'.format(ondisk_size)) + print("Sections total={} bytes".format(sections_total)) + print("File size={} bytes".format(ondisk_size)) # Write the binary size to a file for uploading later # On-disk binary size jumps in 4KB increments so we use the total of the sections as it has finer granularity. # Note that the sum of the section is slightly larger than the on-disk size # due to packing and/or alignment adjustments. - with open(os.path.join(os.path.dirname(path), 'binary_size_data.txt'), 'w') as file: - file.writelines([ - 'os,arch,build_config,size\n', - '{},{},{},{}\n'.format(os_str, arch, build_config, sections_total) - ]) + with open(os.path.join(os.path.dirname(path), "binary_size_data.txt"), "w") as file: + file.writelines( + ["os,arch,build_config,size\n", "{},{},{},{}\n".format(os_str, arch, build_config, sections_total)] + ) if threshold is not None and sections_total > threshold: - raise RuntimeError('Sections total size for {} of {} exceeds threshold of {} by {}. On-disk size={}' - .format(path, sections_total, threshold, sections_total - threshold, ondisk_size)) + raise RuntimeError( + "Sections total size for {} of {} exceeds threshold of {} by {}. On-disk size={}".format( + path, sections_total, threshold, sections_total - threshold, ondisk_size + ) + ) def main(): - argparser = argparse.ArgumentParser(description='Check the binary size for provided path and ' - 'create a text file for upload to the performance dashboard.') + argparser = argparse.ArgumentParser( + description="Check the binary size for provided path and " + "create a text file for upload to the performance dashboard." + ) # optional - argparser.add_argument('-t', '--threshold', type=int, - help='Return error if binary size exceeds this threshold.') - argparser.add_argument('-r', '--readelf_path', type=str, default='readelf', help='Path to readelf executable.') - argparser.add_argument('--os', type=str, default='android', - help='OS value to include in binary_size_data.txt') - argparser.add_argument('--arch', type=str, default='arm64-v8a', - help='Arch value to include in binary_size_data.txt') - argparser.add_argument('--build_config', type=str, default='minimal-baseline', - help='Build_config value to include in binary_size_data.txt') + argparser.add_argument("-t", "--threshold", type=int, help="Return error if binary size exceeds this threshold.") + argparser.add_argument("-r", "--readelf_path", type=str, default="readelf", help="Path to readelf executable.") + argparser.add_argument("--os", type=str, default="android", help="OS value to include in binary_size_data.txt") + argparser.add_argument( + "--arch", type=str, default="arm64-v8a", help="Arch value to include in binary_size_data.txt" + ) + argparser.add_argument( + "--build_config", + type=str, + default="minimal-baseline", + help="Build_config value to include in binary_size_data.txt", + ) # file to analyze - argparser.add_argument('path', type=os.path.realpath, help='Path to binary to check.') + argparser.add_argument("path", type=os.path.realpath, help="Path to binary to check.") args = argparser.parse_args() _check_binary_size(args.path, args.readelf_path, args.threshold, args.os, args.arch, args.build_config) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py index 50de3871d26a3..4c0d77805afe7 100644 --- a/tools/ci_build/github/linux/ort_minimal/readelf_utils.py +++ b/tools/ci_build/github/linux/ort_minimal/readelf_utils.py @@ -1,37 +1,37 @@ #!/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -''' +""" Utilities to help analyze the sections in a binary using readelf. -''' +""" import argparse import collections import os import re -import sys import subprocess +import sys def get_section_sizes(binary_path, readelf_path, dump_to_file=None): - ''' + """ Get the size of each section using readelf. :param binary_path: Path to binary to analyze. :param readelf_path: Path to readelf binary. Default is 'readelf'. :param dump_to_file: File object to write section sizes and diagnostic info to. Defaults to None. :return: - ''' + """ - cmd = [readelf_path, '--sections', '--wide', binary_path] + cmd = [readelf_path, "--sections", "--wide", binary_path] result = subprocess.run(cmd, stdout=subprocess.PIPE) result.check_returncode() - output = result.stdout.decode('utf-8') + output = result.stdout.decode("utf-8") section_sizes = {} # Parse output in this format: # [Nr] Name Type Address Off Size ES Flg Lk Inf Al - for match in re.finditer(r'\[[\s\d]+\] (\..*)$', output, re.MULTILINE): + for match in re.finditer(r"\[[\s\d]+\] (\..*)$", output, re.MULTILINE): items = match.group(1).split() name = items[0] # convert size from hex to int @@ -39,20 +39,20 @@ def get_section_sizes(binary_path, readelf_path, dump_to_file=None): section_sizes[name] = size if dump_to_file: - print('{}:{}'.format(name, size), file=dump_to_file) + print("{}:{}".format(name, size), file=dump_to_file) return section_sizes -def diff_sections_total_size(base_binary_path, binary_path, readelf_path='readelf'): - ''' +def diff_sections_total_size(base_binary_path, binary_path, readelf_path="readelf"): + """ Diff the sections entries for two binaries. :param base_binary_path: Path to base binary for diff. :param binary_path: Path to binary to diff using. :param readelf_path: Path to 'readelf' binary. Defaults to 'readelf' :return: Ordered dictionary containing size of diff for all sections with a diff, the diff for the sum of the sections in the 'Sections total' entry, and the diff for the on-disk file size in the 'File size' entry - ''' + """ filesize = os.path.getsize(binary_path) base_filesize = os.path.getsize(base_binary_path) @@ -75,45 +75,50 @@ def diff_sections_total_size(base_binary_path, binary_path, readelf_path='readel if size != base_size: results[section] = size - base_size - results['Sections total'] = total - base_total - results['File size'] = filesize - base_filesize + results["Sections total"] = total - base_total + results["File size"] = filesize - base_filesize return results def main(): argparser = argparse.ArgumentParser( - description='Analyze sections in a binary using readelf. ' - 'Perform a diff between two binaries if --base_binary_path is specified.') - - argparser.add_argument('-r', '--readelf_path', type=str, - help='Path to readelf executable.') - argparser.add_argument('-b', '--base_binary_path', type=os.path.realpath, - default=None, help='Path to base binary if performing a diff between two binaries.') - argparser.add_argument('-w', '--write_to', type=str, default=None, - help='Path to write output to. Writes to stdout if not provided.') - argparser.add_argument('binary_path', type=os.path.realpath, - help='Shared library to analyze.') + description="Analyze sections in a binary using readelf. " + "Perform a diff between two binaries if --base_binary_path is specified." + ) + + argparser.add_argument("-r", "--readelf_path", type=str, help="Path to readelf executable.") + argparser.add_argument( + "-b", + "--base_binary_path", + type=os.path.realpath, + default=None, + help="Path to base binary if performing a diff between two binaries.", + ) + argparser.add_argument( + "-w", "--write_to", type=str, default=None, help="Path to write output to. Writes to stdout if not provided." + ) + argparser.add_argument("binary_path", type=os.path.realpath, help="Shared library to analyze.") args = argparser.parse_args() out_file = sys.stdout if args.write_to: - out_file = open(args.write_to, 'w') + out_file = open(args.write_to, "w") if args.base_binary_path: diffs = diff_sections_total_size(args.base_binary_path, args.binary_path, args.readelf_path) for key, value in diffs.items(): - print('{}:{}'.format(key, value), file=out_file) + print("{}:{}".format(key, value), file=out_file) else: section_sizes = get_section_sizes(args.binary_path, args.readelf_path, out_file) filesize = os.path.getsize(args.binary_path) - print('Sections total:{}'.format(sum(section_sizes.values())), file=out_file) - print('File size:{}'.format(filesize), file=out_file) + print("Sections total:{}".format(sum(section_sizes.values())), file=out_file) + print("File size:{}".format(filesize), file=out_file) if args.write_to: out_file.close() -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py b/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py index 6468366d68c6d..6d003858e39cf 100644 --- a/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py +++ b/tools/ci_build/github/windows/post_binary_sizes_to_dashboard.py @@ -4,34 +4,34 @@ import argparse -import sys -import os import datetime +import os +import sys + # ingest from dataframe import pandas -from azure.kusto.data import ( - DataFormat, - KustoConnectionStringBuilder, -) -from azure.kusto.ingest import ( - IngestionProperties, - ReportLevel, - QueuedIngestClient, -) +from azure.kusto.data import DataFormat, KustoConnectionStringBuilder +from azure.kusto.ingest import IngestionProperties, QueuedIngestClient, ReportLevel def parse_arguments(): parser = argparse.ArgumentParser(description="ONNXRuntime binary size uploader for dashboard") parser.add_argument("--commit_hash", help="Full Git commit hash") - parser.add_argument("--build_project", default='Lotus', choices=['Lotus', 'onnxruntime'], - help="Lotus or onnxruntime build project, to construct the build URL") + parser.add_argument( + "--build_project", + default="Lotus", + choices=["Lotus", "onnxruntime"], + help="Lotus or onnxruntime build project, to construct the build URL", + ) parser.add_argument("--build_id", help="Build Id") parser.add_argument("--size_data_file", help="Path to file that contains the binary size data") - parser.add_argument("--ignore_db_error", action='store_true', - help="Ignore database errors while executing this script") + parser.add_argument( + "--ignore_db_error", action="store_true", help="Ignore database errors while executing this script" + ) return parser.parse_args() + # Assumes size_data_file is a csv file with a header line, containing binary sizes and other attributes # CSV fields are: # os,arch,build_config,size @@ -40,17 +40,17 @@ def parse_arguments(): def get_binary_sizes(size_data_file): binary_size = [] - with open(size_data_file, 'r') as f: + with open(size_data_file, "r") as f: line = f.readline() - headers = line.strip().split(',') + headers = line.strip().split(",") while line: line = f.readline() if not line: break - linedata = line.strip().split(',') + linedata = line.strip().split(",") tablerow = {} for i in range(0, len(headers)): - if headers[i] == 'size': + if headers[i] == "size": tablerow[headers[i]] = int(linedata[i]) else: tablerow[headers[i]] = linedata[i] @@ -66,23 +66,27 @@ def write_to_db(binary_size_data, args): client = QueuedIngestClient(kcsb) fields = ["build_time", "build_id", "build_project", "commit_id", "os", "arch", "build_config", "size", "Branch"] now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") - branch_name = os.environ.get('BUILD_SOURCEBRANCHNAME', 'master') + branch_name = os.environ.get("BUILD_SOURCEBRANCHNAME", "master") rows = [] for row in binary_size_data: - rows.append([now_str, - args.build_id, - args.build_project, - args.commit_hash, - row['os'], - row['arch'], - row['build_config'], - row['size'], - branch_name.lower()]) + rows.append( + [ + now_str, + args.build_id, + args.build_project, + args.commit_hash, + row["os"], + row["arch"], + row["build_config"], + row["size"], + branch_name.lower(), + ] + ) ingestion_props = IngestionProperties( - database="powerbi", - table="binary_size", - data_format=DataFormat.CSV, - report_level=ReportLevel.FailuresAndSuccesses + database="powerbi", + table="binary_size", + data_format=DataFormat.CSV, + report_level=ReportLevel.FailuresAndSuccesses, ) df = pandas.DataFrame(data=rows, columns=fields) client.ingest_from_dataframe(df, ingestion_properties=ingestion_props) diff --git a/tools/ci_build/github/windows/post_code_coverage_to_dashboard.py b/tools/ci_build/github/windows/post_code_coverage_to_dashboard.py index 182c2160336ae..424910e18dc73 100755 --- a/tools/ci_build/github/windows/post_code_coverage_to_dashboard.py +++ b/tools/ci_build/github/windows/post_code_coverage_to_dashboard.py @@ -9,28 +9,20 @@ # --commit_hash= import argparse +import datetime import json import sys -import datetime + # ingest from dataframe import pandas -from azure.kusto.data import ( - DataFormat, - KustoConnectionStringBuilder, -) -from azure.kusto.ingest import ( - IngestionProperties, - ReportLevel, - QueuedIngestClient, -) +from azure.kusto.data import DataFormat, KustoConnectionStringBuilder +from azure.kusto.ingest import IngestionProperties, QueuedIngestClient, ReportLevel def parse_arguments(): - parser = argparse.ArgumentParser( - description="ONNXRuntime test coverage report uploader for dashboard") + parser = argparse.ArgumentParser(description="ONNXRuntime test coverage report uploader for dashboard") parser.add_argument("--report_url", type=str, help="URL to the LLVM json report") - parser.add_argument( - "--report_file", type=str, help="Path to the local JSON/TXT report", required=True) + parser.add_argument("--report_file", type=str, help="Path to the local JSON/TXT report", required=True) parser.add_argument("--commit_hash", type=str, help="Full Git commit hash", required=True) parser.add_argument("--branch", type=str, help="Source code branch") parser.add_argument("--os", type=str, help="Build configuration:os") @@ -41,13 +33,13 @@ def parse_arguments(): def parse_txt_report(report_file): data = {} - with open(report_file, 'r') as report: + with open(report_file, "r") as report: for line in reversed(report.readlines()): - if 'TOTAL' in line: + if "TOTAL" in line: fields = line.strip().split() - data['lines_valid'] = int(fields[1]) - data['lines_covered'] = int(fields[2]) - data['coverage'] = float(fields[3].strip('%'))/100 + data["lines_valid"] = int(fields[1]) + data["lines_covered"] = int(fields[2]) + data["coverage"] = float(fields[3].strip("%")) / 100 break return data @@ -57,10 +49,10 @@ def parse_json_report(report_file): with open(report_file) as json_file: data = json.load(json_file) - linestat = data['data'][0]['totals']['lines'] - result['coverage'] = float(linestat['percent']/100.0) - result['lines_covered'] = int(linestat['covered']) - result['lines_valid'] = int(linestat['count']) + linestat = data["data"][0]["totals"]["lines"] + result["coverage"] = float(linestat["percent"] / 100.0) + result["lines_covered"] = int(linestat["covered"]) + result["lines_valid"] = int(linestat["count"]) return result @@ -70,21 +62,38 @@ def write_to_db(coverage_data, args): kcsb = KustoConnectionStringBuilder.with_az_cli_authentication(cluster) # The authentication method will be taken from the chosen KustoConnectionStringBuilder. client = QueuedIngestClient(kcsb) - fields = ["UploadTime", "CommitId", "Coverage", "LinesCovered", "TotalLines", "OS", "Arch", "BuildConfig", - "ReportURL", "Branch"] - now_str = datetime.datetime.now() .strftime("%Y-%m-%d %H:%M:%S") - rows = [[now_str, args.commit_hash, coverage_data['coverage'], - coverage_data['lines_covered'], - coverage_data['lines_valid'], args.os.lower(), - args.arch.lower(), - args.build_config.lower(), - args.report_url.lower(), - args.branch.lower()]] + fields = [ + "UploadTime", + "CommitId", + "Coverage", + "LinesCovered", + "TotalLines", + "OS", + "Arch", + "BuildConfig", + "ReportURL", + "Branch", + ] + now_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") + rows = [ + [ + now_str, + args.commit_hash, + coverage_data["coverage"], + coverage_data["lines_covered"], + coverage_data["lines_valid"], + args.os.lower(), + args.arch.lower(), + args.build_config.lower(), + args.report_url.lower(), + args.branch.lower(), + ] + ] ingestion_props = IngestionProperties( - database="powerbi", - table="test_coverage", - data_format=DataFormat.CSV, - report_level=ReportLevel.FailuresAndSuccesses + database="powerbi", + table="test_coverage", + data_format=DataFormat.CSV, + report_level=ReportLevel.FailuresAndSuccesses, ) df = pandas.DataFrame(data=rows, columns=fields) client.ingest_from_dataframe(df, ingestion_properties=ingestion_props) diff --git a/tools/ci_build/logger.py b/tools/ci_build/logger.py index c15fad76e329e..9deb4475721ee 100644 --- a/tools/ci_build/logger.py +++ b/tools/ci_build/logger.py @@ -5,8 +5,6 @@ def get_logger(name): - logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) + logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) return logging.getLogger(name) diff --git a/tools/ci_build/op_registration_utils.py b/tools/ci_build/op_registration_utils.py index 97d551f163411..a0c27a39594e1 100644 --- a/tools/ci_build/op_registration_utils.py +++ b/tools/ci_build/op_registration_utils.py @@ -1,9 +1,9 @@ # !/usr/bin/env python3 # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -''' +""" Utilities to help process files containing kernel registrations. -''' +""" import os import sys @@ -15,52 +15,56 @@ def map_ort_constant_to_domain(ort_constant_name: str): - ''' + """ Map the name of the internal ONNX Runtime constant used in operator kernel registrations to the domain name used in ONNX models and configuration files. :param ort_constant_name: ONNX Runtime constant name for the domain from a kernel registration entry. :return: String with public domain name. - ''' + """ # constants are defined in /include/onnxruntime/core/graph/constants.h - constant_to_domain_map = {'kOnnxDomain': 'ai.onnx', - 'kMLDomain': 'ai.onnx.ml', - 'kMSDomain': 'com.microsoft', - 'kMSExperimentalDomain': 'com.microsoft.experimental', - 'kMSNchwcDomain': 'com.microsoft.nchwc', - 'kMSDmlDomain': 'com.microsoft.dml', - 'kNGraphDomain': 'com.intel.ai', - 'kVitisAIDomain': 'com.xilinx'} + constant_to_domain_map = { + "kOnnxDomain": "ai.onnx", + "kMLDomain": "ai.onnx.ml", + "kMSDomain": "com.microsoft", + "kMSExperimentalDomain": "com.microsoft.experimental", + "kMSNchwcDomain": "com.microsoft.nchwc", + "kMSDmlDomain": "com.microsoft.dml", + "kNGraphDomain": "com.intel.ai", + "kVitisAIDomain": "com.xilinx", + } if ort_constant_name in constant_to_domain_map: return constant_to_domain_map[ort_constant_name] else: - log.warning('Unknown domain for ONNX Runtime constant of {}.'.format(ort_constant_name)) + log.warning("Unknown domain for ONNX Runtime constant of {}.".format(ort_constant_name)) return None def get_kernel_registration_files(ort_root=None, include_cuda=False): - ''' + """ Return paths to files containing kernel registrations for CPU and CUDA providers. :param ort_root: ORT repository root directory. Inferred from the location of this script if not provided. :param include_cuda: Include the CUDA registrations in the list of files. :return: list[str] containing the kernel registration filenames. - ''' + """ if not ort_root: - ort_root = os.path.dirname(os.path.abspath(__file__)) + '/../..' + ort_root = os.path.dirname(os.path.abspath(__file__)) + "/../.." - provider_path = ort_root + '/onnxruntime/core/providers/{ep}/{ep}_execution_provider.cc' - contrib_provider_path = ort_root + '/onnxruntime/contrib_ops/{ep}/{ep}_contrib_kernels.cc' - training_provider_path = ort_root + '/orttraining/orttraining/training_ops/{ep}/{ep}_training_kernels.cc' - provider_paths = [provider_path.format(ep='cpu'), - contrib_provider_path.format(ep='cpu'), - training_provider_path.format(ep='cpu')] + provider_path = ort_root + "/onnxruntime/core/providers/{ep}/{ep}_execution_provider.cc" + contrib_provider_path = ort_root + "/onnxruntime/contrib_ops/{ep}/{ep}_contrib_kernels.cc" + training_provider_path = ort_root + "/orttraining/orttraining/training_ops/{ep}/{ep}_training_kernels.cc" + provider_paths = [ + provider_path.format(ep="cpu"), + contrib_provider_path.format(ep="cpu"), + training_provider_path.format(ep="cpu"), + ] if include_cuda: - provider_paths.append(provider_path.format(ep='cuda')) - provider_paths.append(contrib_provider_path.format(ep='cuda')) - provider_paths.append(training_provider_path.format(ep='cuda')) + provider_paths.append(provider_path.format(ep="cuda")) + provider_paths.append(contrib_provider_path.format(ep="cuda")) + provider_paths.append(training_provider_path.format(ep="cuda")) provider_paths = [os.path.abspath(p) for p in provider_paths] @@ -68,16 +72,22 @@ def get_kernel_registration_files(ort_root=None, include_cuda=False): class RegistrationProcessor: - ''' + """ Class to process lines that are extracted from a kernel registration file. For each kernel registration, process_registration is called. For all other lines, process_other_line is called. - ''' - - def process_registration(self, lines: typing.List[str], domain: str, operator: str, - start_version: int, end_version: typing.Optional[int] = None, - type: typing.Optional[str] = None): - ''' + """ + + def process_registration( + self, + lines: typing.List[str], + domain: str, + operator: str, + start_version: int, + end_version: typing.Optional[int] = None, + type: typing.Optional[str] = None, + ): + """ Process lines that contain a kernel registration. :param lines: Array containing the original lines containing the kernel registration. :param domain: Domain for the operator @@ -85,43 +95,43 @@ def process_registration(self, lines: typing.List[str], domain: str, operator: s :param start_version: Start version :param end_version: End version or None if unversioned registration :param type: Type used in registration, if this is a typed registration - ''' + """ pass def process_other_line(self, line): - ''' + """ Process a line that does not contain a kernel registration :param line: Original line - ''' + """ pass def ok(self): - ''' + """ Get overall status for processing :return: True if successful. False if not. Error will be logged as the registrations are processed. - ''' + """ return False # return False as the derived class must override to report the real status def _process_lines(lines: typing.List[str], offset: int, registration_processor: RegistrationProcessor): - ''' + """ Process one or more lines that contain a kernel registration. Merge lines if split over multiple, and call registration_processor.process_registration with the original lines and the registration information. :return: Offset for first line that was not consumed. - ''' + """ - onnx_op = 'ONNX_OPERATOR_KERNEL_CLASS_NAME' + onnx_op = "ONNX_OPERATOR_KERNEL_CLASS_NAME" onnx_op_len = len(onnx_op) - onnx_typed_op = 'ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME' + onnx_typed_op = "ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME" onnx_typed_op_len = len(onnx_typed_op) - onnx_versioned_op = 'ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME' + onnx_versioned_op = "ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME" onnx_versioned_op_len = len(onnx_versioned_op) - onnx_versioned_typed_op = 'ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME' + onnx_versioned_typed_op = "ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME" onnx_versioned_typed_op_len = len(onnx_versioned_typed_op) - end_marks = tuple([');', ')>', ')>,', ')>,};', ')>};']) + end_marks = tuple([");", ")>", ")>,", ")>,};", ")>};"]) - end_mark = '' + end_mark = "" lines_to_process = [] # merge line if split over multiple. @@ -142,47 +152,49 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor: offset += 1 if offset > len(lines): - log.error('Past end of input lines looking for line terminator.') + log.error("Past end of input lines looking for line terminator.") sys.exit(-1) - code_line = ''.join([line.strip() for line in lines_to_process]) + code_line = "".join([line.strip() for line in lines_to_process]) if onnx_op in code_line: # e.g. BuildKernelCreateInfo, trim_at = code_line.index(onnx_op) + onnx_op_len + 1 - *_, domain, start_version, op_type = \ - [arg.strip() for arg in code_line[trim_at: -len(end_mark)].split(',')] + *_, domain, start_version, op_type = [arg.strip() for arg in code_line[trim_at : -len(end_mark)].split(",")] - registration_processor.process_registration(lines_to_process, domain, op_type, - int(start_version), None, None) + registration_processor.process_registration(lines_to_process, domain, op_type, int(start_version), None, None) elif onnx_typed_op in code_line: # e.g. BuildKernelCreateInfo, trim_at = code_line.index(onnx_typed_op) + onnx_typed_op_len + 1 - *_, domain, start_version, type, op_type = \ - [arg.strip() for arg in code_line[trim_at: -len(end_mark)].split(',')] - registration_processor.process_registration(lines_to_process, domain, op_type, - int(start_version), None, type) + *_, domain, start_version, type, op_type = [ + arg.strip() for arg in code_line[trim_at : -len(end_mark)].split(",") + ] + registration_processor.process_registration(lines_to_process, domain, op_type, int(start_version), None, type) elif onnx_versioned_op in code_line: # e.g. BuildKernelCreateInfo, trim_at = code_line.index(onnx_versioned_op) + onnx_versioned_op_len + 1 - *_, domain, start_version, end_version, op_type = \ - [arg.strip() for arg in code_line[trim_at: -len(end_mark)].split(',')] - registration_processor.process_registration(lines_to_process, domain, op_type, - int(start_version), int(end_version), None) + *_, domain, start_version, end_version, op_type = [ + arg.strip() for arg in code_line[trim_at : -len(end_mark)].split(",") + ] + registration_processor.process_registration( + lines_to_process, domain, op_type, int(start_version), int(end_version), None + ) elif onnx_versioned_typed_op in code_line: # e.g. BuildKernelCreateInfo, trim_at = code_line.index(onnx_versioned_typed_op) + onnx_versioned_typed_op_len + 1 - *_, domain, start_version, end_version, type, op_type = \ - [arg.strip() for arg in code_line[trim_at: -len(end_mark)].split(',')] - registration_processor.process_registration(lines_to_process, domain, op_type, - int(start_version), int(end_version), type) + *_, domain, start_version, end_version, type, op_type = [ + arg.strip() for arg in code_line[trim_at : -len(end_mark)].split(",") + ] + registration_processor.process_registration( + lines_to_process, domain, op_type, int(start_version), int(end_version), type + ) else: log.warning("Ignoring unhandled kernel registration variant: {}".format(code_line)) @@ -193,19 +205,19 @@ def _process_lines(lines: typing.List[str], offset: int, registration_processor: def process_kernel_registration_file(filename: str, registration_processor: RegistrationProcessor): - ''' + """ Process a kernel registration file using registration_processor. :param filename: Path to file containing kernel registrations. :param registration_processor: Processor to be used. :return True if processing was successful. - ''' + """ if not os.path.isfile(filename): - log.error('File not found: {}'.format(filename)) + log.error("File not found: {}".format(filename)) return False lines = [] - with open(filename, 'r') as file_to_read: + with open(filename, "r") as file_to_read: lines = file_to_read.readlines() offset = 0 @@ -214,7 +226,7 @@ def process_kernel_registration_file(filename: str, registration_processor: Regi line = lines[offset] stripped = line.strip() - if stripped.startswith('BuildKernelCreateInfo typing.Tuple[bool, str]: - '''See if an op is required.''' + def _is_op_required( + self, domain: str, operator: str, start_version: int, end_version: typing.Optional[int] + ) -> typing.Tuple[bool, str]: + """See if an op is required.""" if self._required_ops is None: return True @@ -50,11 +55,18 @@ def _is_op_required(self, domain: str, operator: str, return False - def process_registration(self, lines: typing.List[str], constant_for_domain: str, operator: str, - start_version: int, end_version: typing.Optional[int] = None, - type: typing.Optional[str] = None): - registration_identifier = '{}:{}({}){}'.format(constant_for_domain, operator, start_version, - '<{}>'.format(type) if type else '') + def process_registration( + self, + lines: typing.List[str], + constant_for_domain: str, + operator: str, + start_version: int, + end_version: typing.Optional[int] = None, + type: typing.Optional[str] = None, + ): + registration_identifier = "{}:{}({}){}".format( + constant_for_domain, operator, start_version, "<{}>".format(type) if type else "" + ) # convert from the ORT constant name to the domain string used in the config domain = op_registration_utils.map_ort_constant_to_domain(constant_for_domain) @@ -72,17 +84,18 @@ def process_registration(self, lines: typing.List[str], constant_for_domain: str exclude = True reason = "Specific typed registration is not required." else: - log.warning('Keeping {} registration from unknown domain: {}' - .format(registration_identifier, constant_for_domain)) + log.warning( + "Keeping {} registration from unknown domain: {}".format(registration_identifier, constant_for_domain) + ) if exclude: - log.info('Disabling {} registration: {}'.format(registration_identifier, reason)) + log.info("Disabling {} registration: {}".format(registration_identifier, reason)) for line in lines: - self._output_file.write('// ' + line) + self._output_file.write("// " + line) # edge case of last entry in table where we still need the terminating }; to not be commented out - if lines[-1].rstrip().endswith('};'): - self._output_file.write('};\n') + if lines[-1].rstrip().endswith("};"): + self._output_file.write("};\n") else: for line in lines: self._output_file.write(line) @@ -95,27 +108,30 @@ def ok(self): def _get_op_reduction_file_path(ort_root: Path, build_dir: Path, original_path: typing.Optional[Path] = None): - ''' + """ Return the op reduction file path corresponding to `original_path` or the op reduction file root if unspecified. Op reduction files are in a subdirectory of `build_dir` but otherwise share the same components of `original_path` relative to `ort_root`. - ''' + """ op_reduction_root = Path(build_dir, OP_REDUCTION_DIR) - return (op_reduction_root / original_path.relative_to(ort_root)) if original_path is not None \ - else op_reduction_root + return (op_reduction_root / original_path.relative_to(ort_root)) if original_path is not None else op_reduction_root def _generate_provider_registrations( - ort_root: Path, build_dir: Path, use_cuda: bool, - required_ops: typing.Optional[dict], - op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface]): - '''Generate provider registration files.''' - kernel_registration_files = [Path(f) for f in - op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda)] + ort_root: Path, + build_dir: Path, + use_cuda: bool, + required_ops: typing.Optional[dict], + op_type_impl_filter: typing.Optional[OpTypeImplFilterInterface], +): + """Generate provider registration files.""" + kernel_registration_files = [ + Path(f) for f in op_registration_utils.get_kernel_registration_files(str(ort_root), use_cuda) + ] for kernel_registration_file in kernel_registration_files: if not kernel_registration_file.is_file(): - raise ValueError(f'Kernel registration file does not exist: {kernel_registration_file}') + raise ValueError(f"Kernel registration file does not exist: {kernel_registration_file}") log.info("Processing {}".format(kernel_registration_file)) @@ -125,7 +141,7 @@ def _generate_provider_registrations( # read from original and create the reduced kernel def file with commented out lines for any kernels that are # not required - with open(reduced_path, 'w') as file_to_write: + with open(reduced_path, "w") as file_to_write: processor = _ExcludingRegistrationProcessor(required_ops, op_type_impl_filter, file_to_write) op_registration_utils.process_kernel_registration_file(kernel_registration_file, processor) @@ -136,13 +152,13 @@ def _generate_provider_registrations( def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: typing.Sequence[str]): - ''' + """ Generate type control overrides. Insert applicable C++ code to specify operator type requirements. :param ort_root: Root of the ONNX Runtime repository :param build_dir: Path to the build directory :param cpp_lines: The C++ code to insert - ''' - src = Path(ort_root, 'onnxruntime', 'core', 'providers', 'op_kernel_type_control_overrides.inc') + """ + src = Path(ort_root, "onnxruntime", "core", "providers", "op_kernel_type_control_overrides.inc") if not src.is_file(): raise ValueError(f"Op kernel type control overrides file does not exist: {src}") @@ -157,17 +173,17 @@ def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: if cpp_lines: # find the insertion block and replace any existing content in it inserted = False - with open(src, 'r') as input, open(target, 'w') as output: + with open(src, "r") as input, open(target, "w") as output: inside_insertion_block = False for line in input.readlines(): - if '@@insertion_point_begin(allowed_types)@@' in line: + if "@@insertion_point_begin(allowed_types)@@" in line: inside_insertion_block = True output.write(line) - [output.write('{}\n'.format(code_line)) for code_line in cpp_lines] + [output.write("{}\n".format(code_line)) for code_line in cpp_lines] inserted = True continue elif inside_insertion_block: - if '@@insertion_point_end(allowed_types)@@' in line: + if "@@insertion_point_end(allowed_types)@@" in line: inside_insertion_block = False else: # we ignore any old lines within the insertion block @@ -176,17 +192,17 @@ def _generate_type_control_overrides(ort_root: Path, build_dir: Path, cpp_lines: output.write(line) if not inserted: - raise RuntimeError('Insertion point was not found in {}'.format(target)) + raise RuntimeError("Insertion point was not found in {}".format(target)) def reduce_ops(config_path: str, build_dir: str, enable_type_reduction: bool = False, use_cuda: bool = True): - ''' + """ Reduce op kernel implementations. :param config_path: Path to configuration file that specifies the ops to include :param build_dir: Path to the build directory. The op reduction files will be generated under the build directory. :param enable_type_reduction: Whether per operator type reduction is enabled :param use_cuda: Whether to reduce op kernels for the CUDA provider - ''' + """ build_dir = Path(build_dir).resolve() build_dir.mkdir(parents=True, exist_ok=True) @@ -207,26 +223,35 @@ def reduce_ops(config_path: str, build_dir: str, enable_type_reduction: bool = F if __name__ == "__main__": parser = argparse.ArgumentParser( description="Reduces operator kernel implementations in ONNX Runtime. " - "Entire op implementations or op implementations for specific types may be pruned.") - - parser.add_argument("config_path", type=str, - help="Path to configuration file. " - "Create with /tools/python/create_reduced_build_config.py and edit if needed. " - "See /docs/ONNX_Runtime_Format_Model_Usage.md for more information.") - - parser.add_argument("--cmake_build_dir", type=str, required=True, - help="Path to the build directory. " - "The op reduction files will be generated under the build directory.") - - parser.add_argument("--enable_type_reduction", action="store_true", - help="Whether per operator type reduction is enabled.") - - parser.add_argument("--use_cuda", action="store_true", - help="Whether to reduce op kernels for the CUDA provider.") + "Entire op implementations or op implementations for specific types may be pruned." + ) + + parser.add_argument( + "config_path", + type=str, + help="Path to configuration file. " + "Create with /tools/python/create_reduced_build_config.py and edit if needed. " + "See /docs/ONNX_Runtime_Format_Model_Usage.md for more information.", + ) + + parser.add_argument( + "--cmake_build_dir", + type=str, + required=True, + help="Path to the build directory. " "The op reduction files will be generated under the build directory.", + ) + + parser.add_argument( + "--enable_type_reduction", action="store_true", help="Whether per operator type reduction is enabled." + ) + + parser.add_argument("--use_cuda", action="store_true", help="Whether to reduce op kernels for the CUDA provider.") args = parser.parse_args() - reduce_ops(config_path=args.config_path, - build_dir=args.cmake_build_dir, - enable_type_reduction=args.enable_type_reduction, - use_cuda=args.use_cuda) + reduce_ops( + config_path=args.config_path, + build_dir=args.cmake_build_dir, + enable_type_reduction=args.enable_type_reduction, + use_cuda=args.use_cuda, + ) diff --git a/tools/ci_build/upload_python_package_to_azure_storage.py b/tools/ci_build/upload_python_package_to_azure_storage.py index c30a1d8330346..05ec9df8dfa9b 100755 --- a/tools/ci_build/upload_python_package_to_azure_storage.py +++ b/tools/ci_build/upload_python_package_to_azure_storage.py @@ -2,17 +2,17 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os import argparse -import warnings -import subprocess import logging +import os +import subprocess +import warnings log = logging.getLogger("Build") def parse_nightly_and_local_version_from_whl_name(blob_name): - night_build = 'nightly' if blob_name.find(".dev") > 0 else 'stable' + night_build = "nightly" if blob_name.find(".dev") > 0 else "stable" start = blob_name.find("+") if start == -1: @@ -32,20 +32,26 @@ def run_subprocess(args, cwd=None): def upload_whl(python_wheel_path, final_storage=False): storage_account_name = "onnxruntimepackages" if final_storage else "onnxruntimepackagesint" blob_name = os.path.basename(python_wheel_path) - run_subprocess(['azcopy', 'cp', python_wheel_path, f'https://{storage_account_name}.blob.core.windows.net/$web/']) + run_subprocess(["azcopy", "cp", python_wheel_path, f"https://{storage_account_name}.blob.core.windows.net/$web/"]) nightly_build, local_version = parse_nightly_and_local_version_from_whl_name(blob_name) if local_version: - html_blob_name = 'onnxruntime_{}_{}.html'.format(nightly_build, local_version) + html_blob_name = "onnxruntime_{}_{}.html".format(nightly_build, local_version) else: - html_blob_name = 'onnxruntime_{}.html'.format(nightly_build) + html_blob_name = "onnxruntime_{}.html".format(nightly_build) download_path_to_html = "./onnxruntime_{}.html".format(nightly_build) - run_subprocess(['azcopy', 'cp', f'https://{storage_account_name}.blob.core.windows.net/$web/'+html_blob_name, - download_path_to_html]) + run_subprocess( + [ + "azcopy", + "cp", + f"https://{storage_account_name}.blob.core.windows.net/$web/" + html_blob_name, + download_path_to_html, + ] + ) - blob_name_plus_replaced = blob_name.replace('+', '%2B') + blob_name_plus_replaced = blob_name.replace("+", "%2B") with open(download_path_to_html) as f: lines = f.read().splitlines() @@ -54,21 +60,30 @@ def upload_whl(python_wheel_path, final_storage=False): lines.append(new_line) lines.sort() - with open(download_path_to_html, 'w') as f: + with open(download_path_to_html, "w") as f: for item in lines: f.write("%s\n" % item) else: warnings.warn("'{}' exists in {}. The html file is not updated.".format(new_line, download_path_to_html)) - run_subprocess(['azcopy', 'cp', download_path_to_html, - f'https://{storage_account_name}.blob.core.windows.net/$web/'+html_blob_name, - '--content-type', 'text/html', '--overwrite', 'true']) + run_subprocess( + [ + "azcopy", + "cp", + download_path_to_html, + f"https://{storage_account_name}.blob.core.windows.net/$web/" + html_blob_name, + "--content-type", + "text/html", + "--overwrite", + "true", + ] + ) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Upload python whl to azure storage.") parser.add_argument("--python_wheel_path", type=str, help="path to python wheel") - parser.add_argument("--final_storage", action='store_true', help="upload to final storage") + parser.add_argument("--final_storage", action="store_true", help="upload to final storage") args = parser.parse_args() diff --git a/tools/ci_build/upload_python_package_to_azure_storage_with_python.py b/tools/ci_build/upload_python_package_to_azure_storage_with_python.py index da7e4a8ea5e48..4486dac1baf53 100644 --- a/tools/ci_build/upload_python_package_to_azure_storage_with_python.py +++ b/tools/ci_build/upload_python_package_to_azure_storage_with_python.py @@ -2,14 +2,15 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import os import argparse +import os import warnings + from azure.storage.blob import BlockBlobService, ContentSettings def parse_nightly_and_local_version_from_whl_name(blob_name): - night_build = 'nightly' if blob_name.find(".dev") > 0 else 'stable' + night_build = "nightly" if blob_name.find(".dev") > 0 else "stable" start = blob_name.find("+") if start == -1: @@ -22,25 +23,22 @@ def parse_nightly_and_local_version_from_whl_name(blob_name): def upload_whl(python_wheel_path, account_name, account_key, container_name): - block_blob_service = BlockBlobService( - account_name=account_name, - account_key=account_key - ) + block_blob_service = BlockBlobService(account_name=account_name, account_key=account_key) blob_name = os.path.basename(python_wheel_path) block_blob_service.create_blob_from_path(container_name, blob_name, python_wheel_path) nightly_build, local_version = parse_nightly_and_local_version_from_whl_name(blob_name) if local_version: - html_blob_name = 'onnxruntime_{}_{}.html'.format(nightly_build, local_version) + html_blob_name = "onnxruntime_{}_{}.html".format(nightly_build, local_version) else: - html_blob_name = 'onnxruntime_{}.html'.format(nightly_build) + html_blob_name = "onnxruntime_{}.html".format(nightly_build) download_path_to_html = "./onnxruntime_{}.html".format(nightly_build) block_blob_service.get_blob_to_path(container_name, html_blob_name, download_path_to_html) - blob_name_plus_replaced = blob_name.replace('+', '%2B') + blob_name_plus_replaced = blob_name.replace("+", "%2B") with open(download_path_to_html) as f: lines = f.read().splitlines() @@ -49,18 +47,16 @@ def upload_whl(python_wheel_path, account_name, account_key, container_name): lines.append(new_line) lines.sort() - with open(download_path_to_html, 'w') as f: + with open(download_path_to_html, "w") as f: for item in lines: f.write("%s\n" % item) else: warnings.warn("'{}' exists in {}. The html file is not updated.".format(new_line, download_path_to_html)) - content_settings = ContentSettings(content_type='text/html') + content_settings = ContentSettings(content_type="text/html") block_blob_service.create_blob_from_path( - container_name, - html_blob_name, - download_path_to_html, - content_settings=content_settings) + container_name, html_blob_name, download_path_to_html, content_settings=content_settings + ) if __name__ == "__main__": diff --git a/tools/doc/rename_folders.py b/tools/doc/rename_folders.py index d19fb86482132..d65b8a350eed1 100644 --- a/tools/doc/rename_folders.py +++ b/tools/doc/rename_folders.py @@ -15,11 +15,11 @@ def rename_folder(root): found = [] for r, dirs, files in os.walk(root): for name in dirs: - if name.startswith('_'): + if name.startswith("_"): found.append((r, name)) renamed = [] for r, name in found: - into = name.lstrip('_') + into = name.lstrip("_") renamed.append((r, name, into)) full_src = os.path.join(r, name) full_into = os.path.join(r, into) @@ -33,13 +33,11 @@ def rename_folder(root): def replace_files(root, renamed): subs = {r[1]: r[2] for r in renamed} - reg = re.compile( - "(\\\"[a-zA-Z0-9\\.\\/\\?\\:@\\-_=#]+\\.([a-zA-Z]){2,6}" - "([a-zA-Z0-9\\.\\&\\/\\?\\:@\\-_=#])*\\\")") + reg = re.compile('(\\"[a-zA-Z0-9\\.\\/\\?\\:@\\-_=#]+\\.([a-zA-Z]){2,6}' '([a-zA-Z0-9\\.\\&\\/\\?\\:@\\-_=#])*\\")') for r, dirs, files in os.walk(root): for name in files: - if os.path.splitext(name)[-1] != '.html': + if os.path.splitext(name)[-1] != ".html": continue full = os.path.join(r, name) with open(full, "r", encoding="utf-8") as f: @@ -54,8 +52,8 @@ def replace_files(root, renamed): raise ValueError("%r == %r" % (k, v)) if ('"%s' % k) in f[0]: repl.append((f[0], f[0].replace('"%s' % k, '"%s' % v))) - if ('/%s' % k) in f[0]: - repl.append((f[0], f[0].replace('/%s' % k, '/%s' % v))) + if ("/%s" % k) in f[0]: + repl.append((f[0], f[0].replace("/%s" % k, "/%s" % v))) if len(repl) == 0: continue print("update %r" % full) @@ -67,17 +65,20 @@ def replace_files(root, renamed): if __name__ == "__main__": import sys + if len(sys.argv) > 1: root = sys.argv[-1] else: root = "../../build/docs/html" - print('look into %r' % root) + print("look into %r" % root) ren = rename_folder(root) if len(ren) == 0: - ren = [('', '_static', 'static'), - ('', '_images', 'images'), - ('', '_downloads', 'downloads'), - ('', '_sources', 'sources'), - ('', '_modules', 'modules')] + ren = [ + ("", "_static", "static"), + ("", "_images", "images"), + ("", "_downloads", "downloads"), + ("", "_sources", "sources"), + ("", "_modules", "modules"), + ] replace_files(root, ren) - print('done.') + print("done.") diff --git a/tools/nuget/generate_nuspec_for_native_nuget.py b/tools/nuget/generate_nuspec_for_native_nuget.py index a0f3bbf042ed3..08f691496bf6e 100644 --- a/tools/nuget/generate_nuspec_for_native_nuget.py +++ b/tools/nuget/generate_nuspec_for_native_nuget.py @@ -2,9 +2,9 @@ # Licensed under the MIT License. import argparse -import sys import os import re +import sys from pathlib import Path @@ -13,21 +13,21 @@ # ep: cuda, tensorrt, None def get_package_name(os, cpu_arch, ep): pkg_name = None - if os == 'win': + if os == "win": pkg_name = "onnxruntime-win-" pkg_name += cpu_arch - if ep == 'cuda': + if ep == "cuda": pkg_name += "-cuda" - elif ep == 'tensorrt': + elif ep == "tensorrt": pkg_name += "-tensorrt" - elif os == 'linux': + elif os == "linux": pkg_name = "onnxruntime-linux-" pkg_name += cpu_arch - if ep == 'cuda': + if ep == "cuda": pkg_name += "-cuda" - elif ep == 'tensorrt': + elif ep == "tensorrt": pkg_name += "-tensorrt" - elif os == 'osx': + elif os == "osx": pkg_name = "onnxruntime-osx-" + cpu_arch return pkg_name @@ -36,7 +36,7 @@ def get_package_name(os, cpu_arch, ep): # And onnxruntime, onnxruntime_providers_shared and # onnxruntime_providers_tensorrt from tensorrt build def is_this_file_needed(ep, filename): - return (ep != 'cuda' or 'cuda' in filename) and (ep != 'tensorrt' or 'cuda' not in filename) + return (ep != "cuda" or "cuda" in filename) and (ep != "tensorrt" or "cuda" not in filename) # nuget_artifacts_dir: the directory with uncompressed C API tarball/zip files @@ -48,54 +48,55 @@ def generate_file_list_for_ep(nuget_artifacts_dir, ep, files_list, include_pdbs) if not child.is_dir(): continue - for cpu_arch in ['x86', 'x64', 'arm', 'arm64']: - if child.name == get_package_name('win', cpu_arch, ep): - child = child / 'lib' + for cpu_arch in ["x86", "x64", "arm", "arm64"]: + if child.name == get_package_name("win", cpu_arch, ep): + child = child / "lib" for child_file in child.iterdir(): - suffixes = ['.dll', '.lib', '.pdb'] if include_pdbs else ['.dll', '.lib'] + suffixes = [".dll", ".lib", ".pdb"] if include_pdbs else [".dll", ".lib"] if child_file.suffix in suffixes and is_this_file_needed(ep, child_file.name): - files_list.append('' % cpu_arch) - for cpu_arch in ['x86_64', 'arm64']: - if child.name == get_package_name('osx', cpu_arch, ep): - child = child / 'lib' - if cpu_arch == 'x86_64': - cpu_arch = 'x64' + files_list.append( + '' % cpu_arch + ) + for cpu_arch in ["x86_64", "arm64"]: + if child.name == get_package_name("osx", cpu_arch, ep): + child = child / "lib" + if cpu_arch == "x86_64": + cpu_arch = "x64" for child_file in child.iterdir(): # Check if the file has digits like onnxruntime.1.8.0.dylib. We can skip such things - is_versioned_dylib = re.match(r'.*[\.\d+]+\.dylib$', child_file.name) - if child_file.is_file() and child_file.suffix == '.dylib' and not is_versioned_dylib: - files_list.append('' % cpu_arch) - for cpu_arch in ['x64', 'aarch64']: - if child.name == get_package_name('linux', cpu_arch, ep): - child = child / 'lib' - if cpu_arch == 'x86_64': - cpu_arch = 'x64' - elif cpu_arch == 'aarch64': - cpu_arch = 'arm64' + is_versioned_dylib = re.match(r".*[\.\d+]+\.dylib$", child_file.name) + if child_file.is_file() and child_file.suffix == ".dylib" and not is_versioned_dylib: + files_list.append( + '' % cpu_arch + ) + for cpu_arch in ["x64", "aarch64"]: + if child.name == get_package_name("linux", cpu_arch, ep): + child = child / "lib" + if cpu_arch == "x86_64": + cpu_arch = "x64" + elif cpu_arch == "aarch64": + cpu_arch = "arm64" for child_file in child.iterdir(): if not child_file.is_file(): continue - if child_file.suffix == '.so' and is_this_file_needed(ep, child_file.name): - files_list.append('' % cpu_arch) + if child_file.suffix == ".so" and is_this_file_needed(ep, child_file.name): + files_list.append( + '' % cpu_arch + ) - if child.name == 'onnxruntime-android': + if child.name == "onnxruntime-android": for child_file in child.iterdir(): - if child_file.suffix in ['.aar']: - files_list.append('') + if child_file.suffix in [".aar"]: + files_list.append('') - if child.name == 'onnxruntime-ios-xcframework': - files_list.append('') + if child.name == "onnxruntime-ios-xcframework": + files_list.append('') def parse_arguments(): - parser = argparse.ArgumentParser(description="ONNX Runtime create nuget spec script " - "(for hosting native shared library artifacts)", - usage='') + parser = argparse.ArgumentParser( + description="ONNX Runtime create nuget spec script " "(for hosting native shared library artifacts)", usage="" + ) # Main arguments parser.add_argument("--package_name", required=True, help="ORT package name. Eg: Microsoft.ML.OnnxRuntime.Gpu") parser.add_argument("--package_version", required=True, help="ORT package version. Eg: 1.0.0") @@ -106,53 +107,64 @@ def parse_arguments(): parser.add_argument("--packages_path", required=True, help="Nuget packages output directory.") parser.add_argument("--sources_path", required=True, help="OnnxRuntime source code root.") parser.add_argument("--commit_id", required=True, help="The last commit id included in this package.") - parser.add_argument("--is_release_build", required=False, default=None, type=str, - help="Flag indicating if the build is a release build. Accepted values: true/false.") - parser.add_argument("--execution_provider", required=False, default='None', type=str, - choices=['cuda', 'dnnl', 'openvino', 'tensorrt', 'None'], - help="The selected execution provider for this build.") + parser.add_argument( + "--is_release_build", + required=False, + default=None, + type=str, + help="Flag indicating if the build is a release build. Accepted values: true/false.", + ) + parser.add_argument( + "--execution_provider", + required=False, + default="None", + type=str, + choices=["cuda", "dnnl", "openvino", "tensorrt", "None"], + help="The selected execution provider for this build.", + ) return parser.parse_args() def generate_id(list, package_name): - list.append('' + package_name + '') + list.append("" + package_name + "") def generate_version(list, package_version): - list.append('' + package_version + '') + list.append("" + package_version + "") def generate_authors(list, authors): - list.append('' + authors + '') + list.append("" + authors + "") def generate_owners(list, owners): - list.append('' + owners + '') + list.append("" + owners + "") def generate_description(list, package_name): - description = '' + description = "" - if package_name == 'Microsoft.AI.MachineLearning': - description = 'This package contains Windows ML binaries.' - elif 'Microsoft.ML.OnnxRuntime' in package_name: # This is a Microsoft.ML.OnnxRuntime.* package - description = 'This package contains native shared library artifacts ' \ - 'for all supported platforms of ONNX Runtime.' + if package_name == "Microsoft.AI.MachineLearning": + description = "This package contains Windows ML binaries." + elif "Microsoft.ML.OnnxRuntime" in package_name: # This is a Microsoft.ML.OnnxRuntime.* package + description = ( + "This package contains native shared library artifacts " "for all supported platforms of ONNX Runtime." + ) - list.append('' + description + '') + list.append("" + description + "") def generate_copyright(list, copyright): - list.append('' + copyright + '') + list.append("" + copyright + "") def generate_tags(list, tags): - list.append('' + tags + '') + list.append("" + tags + "") def generate_icon(list, icon_file): - list.append('' + icon_file + '') + list.append("" + icon_file + "") def generate_license(list): @@ -160,7 +172,7 @@ def generate_license(list): def generate_project_url(list, project_url): - list.append('' + project_url + '') + list.append("" + project_url + "") def generate_repo_url(list, repo_url, commit_id): @@ -170,61 +182,61 @@ def generate_repo_url(list, repo_url, commit_id): def generate_dependencies(list, package_name, version): dml_dependency = '' - if (package_name == 'Microsoft.AI.MachineLearning'): - list.append('') + if package_name == "Microsoft.AI.MachineLearning": + list.append("") # Support .Net Core list.append('') list.append(dml_dependency) - list.append('') + list.append("") # UAP10.0.16299, This is the earliest release of the OS that supports .NET Standard apps list.append('') list.append(dml_dependency) - list.append('') + list.append("") # Support Native C++ list.append('') list.append(dml_dependency) - list.append('') + list.append("") - list.append('') + list.append("") else: - include_dml = package_name == 'Microsoft.ML.OnnxRuntime.DirectML' + include_dml = package_name == "Microsoft.ML.OnnxRuntime.DirectML" - list.append('') + list.append("") # Support .Net Core list.append('') list.append('') if include_dml: list.append(dml_dependency) - list.append('') + list.append("") # Support .Net Standard list.append('') list.append('') if include_dml: list.append(dml_dependency) - list.append('') + list.append("") # Support .Net Framework list.append('') list.append('') if include_dml: list.append(dml_dependency) - list.append('') - if package_name == 'Microsoft.ML.OnnxRuntime': + list.append("") + if package_name == "Microsoft.ML.OnnxRuntime": # Support monoandroid11.0 list.append('') list.append('') - list.append('') + list.append("") # Support xamarinios10 list.append('') list.append('') - list.append('') + list.append("") # Support Native C++ if include_dml: list.append('') list.append(dml_dependency) - list.append('') + list.append("") - list.append('') + list.append("") def get_env_var(key): @@ -232,179 +244,231 @@ def get_env_var(key): def generate_release_notes(list): - list.append('') - list.append('Release Def:') + list.append("") + list.append("Release Def:") - branch = get_env_var('BUILD_SOURCEBRANCH') - list.append('\t' + 'Branch: ' + (branch if branch is not None else '')) + branch = get_env_var("BUILD_SOURCEBRANCH") + list.append("\t" + "Branch: " + (branch if branch is not None else "")) - version = get_env_var('BUILD_SOURCEVERSION') - list.append('\t' + 'Commit: ' + (version if version is not None else '')) + version = get_env_var("BUILD_SOURCEVERSION") + list.append("\t" + "Commit: " + (version if version is not None else "")) - build_id = get_env_var('BUILD_BUILDID') - list.append('\t' + 'Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=' + - (build_id if build_id is not None else '')) + build_id = get_env_var("BUILD_BUILDID") + list.append( + "\t" + + "Build: https://aiinfra.visualstudio.com/Lotus/_build/results?buildId=" + + (build_id if build_id is not None else "") + ) - list.append('') + list.append("") def generate_metadata(list, args): - metadata_list = [''] + metadata_list = [""] generate_id(metadata_list, args.package_name) generate_version(metadata_list, args.package_version) - generate_authors(metadata_list, 'Microsoft') - generate_owners(metadata_list, 'Microsoft') + generate_authors(metadata_list, "Microsoft") + generate_owners(metadata_list, "Microsoft") generate_description(metadata_list, args.package_name) - generate_copyright(metadata_list, '\xc2\xa9 ' + 'Microsoft Corporation. All rights reserved.') - generate_tags(metadata_list, 'ONNX ONNX Runtime Machine Learning') - generate_icon(metadata_list, 'ORT_icon_for_light_bg.png') + generate_copyright(metadata_list, "\xc2\xa9 " + "Microsoft Corporation. All rights reserved.") + generate_tags(metadata_list, "ONNX ONNX Runtime Machine Learning") + generate_icon(metadata_list, "ORT_icon_for_light_bg.png") generate_license(metadata_list) - generate_project_url(metadata_list, 'https://github.com/Microsoft/onnxruntime') - generate_repo_url(metadata_list, 'https://github.com/Microsoft/onnxruntime.git', args.commit_id) + generate_project_url(metadata_list, "https://github.com/Microsoft/onnxruntime") + generate_repo_url(metadata_list, "https://github.com/Microsoft/onnxruntime.git", args.commit_id) generate_dependencies(metadata_list, args.package_name, args.package_version) generate_release_notes(metadata_list) - metadata_list.append('') + metadata_list.append("") list += metadata_list def generate_files(list, args): - files_list = [''] + files_list = [""] - is_cpu_package = args.package_name in ['Microsoft.ML.OnnxRuntime', 'Microsoft.ML.OnnxRuntime.OpenMP'] - is_mklml_package = args.package_name == 'Microsoft.ML.OnnxRuntime.MKLML' - is_cuda_gpu_package = args.package_name == 'Microsoft.ML.OnnxRuntime.Gpu' - is_dml_package = args.package_name == 'Microsoft.ML.OnnxRuntime.DirectML' - is_windowsai_package = args.package_name == 'Microsoft.AI.MachineLearning' + is_cpu_package = args.package_name in ["Microsoft.ML.OnnxRuntime", "Microsoft.ML.OnnxRuntime.OpenMP"] + is_mklml_package = args.package_name == "Microsoft.ML.OnnxRuntime.MKLML" + is_cuda_gpu_package = args.package_name == "Microsoft.ML.OnnxRuntime.Gpu" + is_dml_package = args.package_name == "Microsoft.ML.OnnxRuntime.DirectML" + is_windowsai_package = args.package_name == "Microsoft.AI.MachineLearning" includes_winml = is_windowsai_package includes_directml = (is_dml_package or is_windowsai_package) and ( - args.target_architecture == 'x64' or args.target_architecture == 'x86') + args.target_architecture == "x64" or args.target_architecture == "x86" + ) is_windows_build = is_windows() nuget_dependencies = {} if is_windows_build: - nuget_dependencies = {'mklml': 'mklml.dll', - 'openmp': 'libiomp5md.dll', - 'dnnl': 'dnnl.dll', - 'tvm': 'tvm.dll', - 'providers_shared_lib': 'onnxruntime_providers_shared.dll', - 'dnnl_ep_shared_lib': 'onnxruntime_providers_dnnl.dll', - 'tensorrt_ep_shared_lib': 'onnxruntime_providers_tensorrt.dll', - 'openvino_ep_shared_lib': 'onnxruntime_providers_openvino.dll', - 'cuda_ep_shared_lib': 'onnxruntime_providers_cuda.dll', - 'onnxruntime_perf_test': 'onnxruntime_perf_test.exe', - 'onnx_test_runner': 'onnx_test_runner.exe'} + nuget_dependencies = { + "mklml": "mklml.dll", + "openmp": "libiomp5md.dll", + "dnnl": "dnnl.dll", + "tvm": "tvm.dll", + "providers_shared_lib": "onnxruntime_providers_shared.dll", + "dnnl_ep_shared_lib": "onnxruntime_providers_dnnl.dll", + "tensorrt_ep_shared_lib": "onnxruntime_providers_tensorrt.dll", + "openvino_ep_shared_lib": "onnxruntime_providers_openvino.dll", + "cuda_ep_shared_lib": "onnxruntime_providers_cuda.dll", + "onnxruntime_perf_test": "onnxruntime_perf_test.exe", + "onnx_test_runner": "onnx_test_runner.exe", + } copy_command = "copy" runtimes_target = '" target="runtimes\\win-' else: - nuget_dependencies = {'mklml': 'libmklml_intel.so', - 'mklml_1': 'libmklml_gnu.so', - 'openmp': 'libiomp5.so', - 'dnnl': 'libdnnl.so.1', - 'tvm': 'libtvm.so.0.5.1', - 'providers_shared_lib': 'libonnxruntime_providers_shared.so', - 'dnnl_ep_shared_lib': 'libonnxruntime_providers_dnnl.so', - 'tensorrt_ep_shared_lib': 'libonnxruntime_providers_tensorrt.so', - 'openvino_ep_shared_lib': 'libonnxruntime_providers_openvino.so', - 'cuda_ep_shared_lib': 'libonnxruntime_providers_cuda.so', - 'onnxruntime_perf_test': 'onnxruntime_perf_test', - 'onnx_test_runner': 'onnx_test_runner'} + nuget_dependencies = { + "mklml": "libmklml_intel.so", + "mklml_1": "libmklml_gnu.so", + "openmp": "libiomp5.so", + "dnnl": "libdnnl.so.1", + "tvm": "libtvm.so.0.5.1", + "providers_shared_lib": "libonnxruntime_providers_shared.so", + "dnnl_ep_shared_lib": "libonnxruntime_providers_dnnl.so", + "tensorrt_ep_shared_lib": "libonnxruntime_providers_tensorrt.so", + "openvino_ep_shared_lib": "libonnxruntime_providers_openvino.so", + "cuda_ep_shared_lib": "libonnxruntime_providers_cuda.so", + "onnxruntime_perf_test": "onnxruntime_perf_test", + "onnx_test_runner": "onnx_test_runner", + } copy_command = "cp" runtimes_target = '" target="runtimes\\linux-' if is_windowsai_package: - runtimes_native_folder = '_native' + runtimes_native_folder = "_native" else: - runtimes_native_folder = 'native' + runtimes_native_folder = "native" runtimes = '{}{}\\{}"'.format(runtimes_target, args.target_architecture, runtimes_native_folder) # Process headers - files_list.append('') - files_list.append('') - files_list.append('') - - if args.execution_provider == 'openvino': - files_list.append('') - - if args.execution_provider == 'tensorrt': - files_list.append('') - - if args.execution_provider == 'dnnl': - files_list.append('') + files_list.append( + "' + ) + files_list.append( + "' + ) + files_list.append( + "' + ) + + if args.execution_provider == "openvino": + files_list.append( + "' + ) + + if args.execution_provider == "tensorrt": + files_list.append( + "' + ) + + if args.execution_provider == "dnnl": + files_list.append( + "' + ) if includes_directml: - files_list.append('') + files_list.append( + "' + ) if includes_winml: # Add microsoft.ai.machinelearning headers - files_list.append('') - files_list.append('') - files_list.append('') + files_list.append( + "' + ) + files_list.append( + "' + ) + files_list.append( + "' + ) # Add custom operator headers - mlop_path = 'onnxruntime\\core\\providers\\dml\\dmlexecutionprovider\\inc\\mloperatorauthor.h' - files_list.append('') + mlop_path = "onnxruntime\\core\\providers\\dml\\dmlexecutionprovider\\inc\\mloperatorauthor.h" + files_list.append( + "' + ) # Process microsoft.ai.machinelearning.winmd - files_list.append('') + files_list.append( + "' + ) # Process microsoft.ai.machinelearning.experimental.winmd - files_list.append('') - if args.target_architecture == 'x64': - interop_dll_path = 'Microsoft.AI.MachineLearning.Interop\\net5.0-windows10.0.17763.0' - interop_dll = interop_dll_path + '\\Microsoft.AI.MachineLearning.Interop.dll' - files_list.append('') - interop_pdb_path = 'Microsoft.AI.MachineLearning.Interop\\net5.0-windows10.0.17763.0' - interop_pdb = interop_pdb_path + '\\Microsoft.AI.MachineLearning.Interop.pdb' - files_list.append('') + files_list.append( + "' + ) + if args.target_architecture == "x64": + interop_dll_path = "Microsoft.AI.MachineLearning.Interop\\net5.0-windows10.0.17763.0" + interop_dll = interop_dll_path + "\\Microsoft.AI.MachineLearning.Interop.dll" + files_list.append( + "' + ) + interop_pdb_path = "Microsoft.AI.MachineLearning.Interop\\net5.0-windows10.0.17763.0" + interop_pdb = interop_pdb_path + "\\Microsoft.AI.MachineLearning.Interop.pdb" + files_list.append( + "' + ) is_ado_packaging_build = False # Process runtimes # Process onnxruntime import lib, dll, and pdb if is_windows_build: - nuget_artifacts_dir = Path(args.native_build_path) / 'nuget-artifacts' + nuget_artifacts_dir = Path(args.native_build_path) / "nuget-artifacts" # the winml package includes pdbs. for other packages exclude them. include_pdbs = includes_winml if nuget_artifacts_dir.exists(): # Code path for ADO build pipeline, the files under 'nuget-artifacts' are # downloaded from other build jobs if is_cuda_gpu_package: - ep_list = ['tensorrt', 'cuda', None] + ep_list = ["tensorrt", "cuda", None] else: ep_list = [None] for ep in ep_list: @@ -412,225 +476,402 @@ def generate_files(list, args): is_ado_packaging_build = True else: # Code path for local dev build - files_list.append('') - files_list.append('') - if include_pdbs and os.path.exists(os.path.join(args.native_build_path, 'onnxruntime.pdb')): - files_list.append('') + files_list.append( + "" + ) + files_list.append( + "" + ) + if include_pdbs and os.path.exists(os.path.join(args.native_build_path, "onnxruntime.pdb")): + files_list.append( + "" + ) else: - files_list.append('') + files_list.append( + "' + ) if includes_winml: # Process microsoft.ai.machinelearning import lib, dll, and pdb - files_list.append('') - files_list.append('') - files_list.append('') + files_list.append( + "' + ) + files_list.append( + "' + ) + files_list.append( + "' + ) # Process execution providers which are built as shared libs if args.execution_provider == "tensorrt" and not is_ado_packaging_build: - files_list.append('') - files_list.append('') - files_list.append('') + files_list.append( + "' + ) + files_list.append( + "' + ) + files_list.append( + "' + ) if args.execution_provider == "dnnl": - files_list.append('') - files_list.append('') + files_list.append( + "' + ) + files_list.append( + "' + ) if args.execution_provider == "openvino": - openvino_path = get_env_var('INTEL_OPENVINO_DIR') - files_list.append('') - files_list.append('') + openvino_path = get_env_var("INTEL_OPENVINO_DIR") + files_list.append( + "' + ) + files_list.append( + "' + ) if is_windows(): if "2022" in openvino_path: - dll_list_path = os.path.join(openvino_path, 'runtime\\bin\\intel64\\Release\\') - tbb_list_path = os.path.join(openvino_path, 'runtime\\3rdparty\\tbb\\bin\\') + dll_list_path = os.path.join(openvino_path, "runtime\\bin\\intel64\\Release\\") + tbb_list_path = os.path.join(openvino_path, "runtime\\3rdparty\\tbb\\bin\\") else: dll_list_path = os.path.join( - openvino_path, 'deployment_tools\\inference_engine\\bin\\intel64\\Release\\') - tbb_list_path = os.path.join(openvino_path, 'deployment_tools\\inference_engine\\external\\tbb\\bin\\') - ngraph_list_path = os.path.join(openvino_path, 'deployment_tools\\ngraph\\lib\\') + openvino_path, "deployment_tools\\inference_engine\\bin\\intel64\\Release\\" + ) + tbb_list_path = os.path.join(openvino_path, "deployment_tools\\inference_engine\\external\\tbb\\bin\\") + ngraph_list_path = os.path.join(openvino_path, "deployment_tools\\ngraph\\lib\\") for ngraph_element in os.listdir(ngraph_list_path): - if ngraph_element.endswith('dll'): - files_list.append('') + if ngraph_element.endswith("dll"): + files_list.append( + "' + ) for dll_element in os.listdir(dll_list_path): - if dll_element.endswith('dll'): - files_list.append('') + if dll_element.endswith("dll"): + files_list.append( + "' + ) # plugins.xml - files_list.append('') + files_list.append( + "' + ) # usb-ma2x8x.mvcmd - files_list.append('') + files_list.append( + "' + ) for tbb_element in os.listdir(tbb_list_path): - if tbb_element.endswith('dll'): - files_list.append('') + if tbb_element.endswith("dll"): + files_list.append( + "' + ) if args.execution_provider == "cuda" or is_cuda_gpu_package and not is_ado_packaging_build: - files_list.append('') - files_list.append('') + files_list.append( + "' + ) + files_list.append( + "' + ) # process all other library dependencies if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package: # Process dnnl dependency - if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['dnnl'])): - files_list.append('') + if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["dnnl"])): + files_list.append( + "" + ) # Process mklml dependency - if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['mklml'])): - files_list.append('') - - if is_linux() and os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['mklml_1'])): - files_list.append('') + if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["mklml"])): + files_list.append( + "" + ) + + if is_linux() and os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["mklml_1"])): + files_list.append( + "" + ) # Process libiomp5md dependency - if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['openmp'])): - files_list.append('') + if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["openmp"])): + files_list.append( + "" + ) # Process tvm dependency - if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['tvm'])): - files_list.append('') + if os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["tvm"])): + files_list.append( + "" + ) # Some tools to be packaged in nightly build only, should not be released # These are copied to the runtimes folder for convenience of loading with the dlls - if args.is_release_build.lower() != 'true' and args.target_architecture == 'x64' and \ - os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['onnxruntime_perf_test'])): - files_list.append('') - - if args.is_release_build.lower() != 'true' and args.target_architecture == 'x64' and \ - os.path.exists(os.path.join(args.native_build_path, nuget_dependencies['onnx_test_runner'])): - files_list.append('') + if ( + args.is_release_build.lower() != "true" + and args.target_architecture == "x64" + and os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["onnxruntime_perf_test"])) + ): + files_list.append( + "" + ) + + if ( + args.is_release_build.lower() != "true" + and args.target_architecture == "x64" + and os.path.exists(os.path.join(args.native_build_path, nuget_dependencies["onnx_test_runner"])) + ): + files_list.append( + "" + ) # Process props and targets files if is_windowsai_package: - windowsai_src = 'Microsoft.AI.MachineLearning' - windowsai_props = 'Microsoft.AI.MachineLearning.props' - windowsai_targets = 'Microsoft.AI.MachineLearning.targets' - windowsai_native_props = os.path.join(args.sources_path, 'csharp', 'src', windowsai_src, windowsai_props) - windowsai_rules = 'Microsoft.AI.MachineLearning.Rules.Project.xml' - windowsai_native_rules = os.path.join(args.sources_path, 'csharp', 'src', windowsai_src, windowsai_rules) - windowsai_native_targets = os.path.join(args.sources_path, 'csharp', 'src', windowsai_src, windowsai_targets) - build = 'build\\native' - files_list.append('') + windowsai_src = "Microsoft.AI.MachineLearning" + windowsai_props = "Microsoft.AI.MachineLearning.props" + windowsai_targets = "Microsoft.AI.MachineLearning.targets" + windowsai_native_props = os.path.join(args.sources_path, "csharp", "src", windowsai_src, windowsai_props) + windowsai_rules = "Microsoft.AI.MachineLearning.Rules.Project.xml" + windowsai_native_rules = os.path.join(args.sources_path, "csharp", "src", windowsai_src, windowsai_rules) + windowsai_native_targets = os.path.join(args.sources_path, "csharp", "src", windowsai_src, windowsai_targets) + build = "build\\native" + files_list.append("') # Process native targets - files_list.append('') + files_list.append("') # Process rules - files_list.append('') + files_list.append("') # Process .net5.0 targets - if args.target_architecture == 'x64': - interop_src = 'Microsoft.AI.MachineLearning.Interop' - interop_props = 'Microsoft.AI.MachineLearning.props' - interop_targets = 'Microsoft.AI.MachineLearning.targets' - windowsai_net50_props = os.path.join(args.sources_path, 'csharp', 'src', interop_src, interop_props) - windowsai_net50_targets = os.path.join(args.sources_path, 'csharp', 'src', interop_src, interop_targets) - files_list.append('') - files_list.append('') + if args.target_architecture == "x64": + interop_src = "Microsoft.AI.MachineLearning.Interop" + interop_props = "Microsoft.AI.MachineLearning.props" + interop_targets = "Microsoft.AI.MachineLearning.targets" + windowsai_net50_props = os.path.join(args.sources_path, "csharp", "src", interop_src, interop_props) + windowsai_net50_targets = os.path.join(args.sources_path, "csharp", "src", interop_src, interop_targets) + files_list.append("') + files_list.append("') if is_cpu_package or is_cuda_gpu_package or is_dml_package or is_mklml_package: # Process props file - source_props = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', 'targets', - 'netstandard', 'props.xml') - target_props = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', 'targets', - 'netstandard', args.package_name + '.props') - os.system(copy_command + ' ' + source_props + ' ' + target_props) - files_list.append('') - files_list.append('') - files_list.append('') + source_props = os.path.join( + args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", "netstandard", "props.xml" + ) + target_props = os.path.join( + args.sources_path, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + args.package_name + ".props", + ) + os.system(copy_command + " " + source_props + " " + target_props) + files_list.append("') + files_list.append("') + files_list.append("') # Process targets file - source_targets = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', 'targets', - 'netstandard', 'targets.xml') - target_targets = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', 'targets', - 'netstandard', args.package_name + '.targets') - os.system(copy_command + ' ' + source_targets + ' ' + target_targets) - files_list.append('') - files_list.append('') - files_list.append('') + source_targets = os.path.join( + args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", "netstandard", "targets.xml" + ) + target_targets = os.path.join( + args.sources_path, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "netstandard", + args.package_name + ".targets", + ) + os.system(copy_command + " " + source_targets + " " + target_targets) + files_list.append("') + files_list.append("') + files_list.append("') # Process xamarin targets files - if args.package_name == 'Microsoft.ML.OnnxRuntime': - monoandroid_source_targets = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', - 'targets', 'monoandroid11.0', 'targets.xml') - monoandroid_target_targets = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', - 'targets', 'monoandroid11.0', args.package_name + '.targets') - os.system(copy_command + ' ' + monoandroid_source_targets + ' ' + monoandroid_target_targets) - - xamarinios_source_targets = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', - 'targets', 'xamarinios10', 'targets.xml') - xamarinios_target_targets = os.path.join(args.sources_path, 'csharp', 'src', 'Microsoft.ML.OnnxRuntime', - 'targets', 'xamarinios10', args.package_name + '.targets') - os.system(copy_command + ' ' + xamarinios_source_targets + ' ' + xamarinios_target_targets) - - files_list.append('') - files_list.append('') - files_list.append('') - files_list.append('') + if args.package_name == "Microsoft.ML.OnnxRuntime": + monoandroid_source_targets = os.path.join( + args.sources_path, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "monoandroid11.0", + "targets.xml", + ) + monoandroid_target_targets = os.path.join( + args.sources_path, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "monoandroid11.0", + args.package_name + ".targets", + ) + os.system(copy_command + " " + monoandroid_source_targets + " " + monoandroid_target_targets) + + xamarinios_source_targets = os.path.join( + args.sources_path, "csharp", "src", "Microsoft.ML.OnnxRuntime", "targets", "xamarinios10", "targets.xml" + ) + xamarinios_target_targets = os.path.join( + args.sources_path, + "csharp", + "src", + "Microsoft.ML.OnnxRuntime", + "targets", + "xamarinios10", + args.package_name + ".targets", + ) + os.system(copy_command + " " + xamarinios_source_targets + " " + xamarinios_target_targets) + + files_list.append("') + files_list.append( + "' + ) + files_list.append("') + files_list.append( + "' + ) # Process License, ThirdPartyNotices, Privacy - files_list.append('') - files_list.append('') - files_list.append('') - files_list.append('') - files_list.append('') + files_list.append("') + files_list.append( + "' + ) + files_list.append( + "' + ) + files_list.append( + "' + ) + files_list.append("") list += files_list def generate_nuspec(args): lines = [''] - lines.append('') + lines.append("") generate_metadata(lines, args) generate_files(lines, args) - lines.append('') + lines.append("") return lines @@ -647,16 +888,23 @@ def is_macos(): def validate_platform(): - if not(is_windows() or is_linux() or is_macos()): - raise Exception('Native Nuget generation is currently supported only on Windows, Linux, and MacOS') + if not (is_windows() or is_linux() or is_macos()): + raise Exception("Native Nuget generation is currently supported only on Windows, Linux, and MacOS") def validate_execution_provider(execution_provider): if is_linux(): - if not (execution_provider == 'None' or execution_provider == 'dnnl' or execution_provider == 'cuda' - or execution_provider == 'tensorrt' or execution_provider == 'openvino'): - raise Exception('On Linux platform nuget generation is supported only ' - 'for cpu|cuda|dnnl|tensorrt|openvino execution providers.') + if not ( + execution_provider == "None" + or execution_provider == "dnnl" + or execution_provider == "cuda" + or execution_provider == "tensorrt" + or execution_provider == "openvino" + ): + raise Exception( + "On Linux platform nuget generation is supported only " + "for cpu|cuda|dnnl|tensorrt|openvino execution providers." + ) def main(): @@ -667,19 +915,19 @@ def main(): validate_execution_provider(args.execution_provider) - if (args.is_release_build.lower() != 'true' and args.is_release_build.lower() != 'false'): - raise Exception('Only valid options for IsReleaseBuild are: true and false') + if args.is_release_build.lower() != "true" and args.is_release_build.lower() != "false": + raise Exception("Only valid options for IsReleaseBuild are: true and false") # Generate nuspec lines = generate_nuspec(args) # Create the nuspec needed to generate the Nuget - with open(os.path.join(args.native_build_path, 'NativeNuget.nuspec'), 'w') as f: + with open(os.path.join(args.native_build_path, "NativeNuget.nuspec"), "w") as f: for line in lines: # Uncomment the printing of the line if you need to debug what's produced on a CI machine # print(line) f.write(line) - f.write('\n') + f.write("\n") if __name__ == "__main__": diff --git a/tools/nuget/validate_package.py b/tools/nuget/validate_package.py index 85bf04fc7c569..36b8da005c072 100644 --- a/tools/nuget/validate_package.py +++ b/tools/nuget/validate_package.py @@ -2,35 +2,54 @@ # Licensed under the MIT License. import argparse -import sys -import os -import zipfile # Available Python 3.2 or higher import glob +import os import re +import sys +import zipfile # Available Python 3.2 or higher -linux_gpu_package_libraries = ["libonnxruntime_providers_shared.so", "libonnxruntime_providers_cuda.so", - "libonnxruntime_providers_tensorrt.so"] -win_gpu_package_libraries = ["onnxruntime_providers_shared.lib", "onnxruntime_providers_shared.dll", - "onnxruntime_providers_cuda.lib", "onnxruntime_providers_cuda.dll", - "onnxruntime_providers_tensorrt.lib", "onnxruntime_providers_tensorrt.dll"] -gpu_related_header_files = ["cpu_provider_factory.h", "tensorrt_provider_factory.h", "onnxruntime_c_api.h", - "onnxruntime_cxx_api.h", "onnxruntime_cxx_inline.h"] +linux_gpu_package_libraries = [ + "libonnxruntime_providers_shared.so", + "libonnxruntime_providers_cuda.so", + "libonnxruntime_providers_tensorrt.so", +] +win_gpu_package_libraries = [ + "onnxruntime_providers_shared.lib", + "onnxruntime_providers_shared.dll", + "onnxruntime_providers_cuda.lib", + "onnxruntime_providers_cuda.dll", + "onnxruntime_providers_tensorrt.lib", + "onnxruntime_providers_tensorrt.dll", +] +gpu_related_header_files = [ + "cpu_provider_factory.h", + "tensorrt_provider_factory.h", + "onnxruntime_c_api.h", + "onnxruntime_cxx_api.h", + "onnxruntime_cxx_inline.h", +] def parse_arguments(): parser = argparse.ArgumentParser( description="Validate ONNX Runtime native nuget containing native shared library artifacts spec script", - usage='') + usage="", + ) # Main arguments parser.add_argument("--package_type", required=True, help="Specify nuget, tarball or zip.") parser.add_argument("--package_name", required=True, help="Package name to be validated.") - parser.add_argument("--package_path", required=True, help="Path containing the package to be validated." + - "Must only contain only one package within this.") - parser.add_argument("--platforms_supported", required=True, - help="Comma separated list (no space). Ex: linux-x64,win-x86,osx-x64") - parser.add_argument("--verify_nuget_signing", - help="Flag indicating if Nuget package signing is to be verified. " - "Only accepts 'true' or 'false'") + parser.add_argument( + "--package_path", + required=True, + help="Path containing the package to be validated." + "Must only contain only one package within this.", + ) + parser.add_argument( + "--platforms_supported", required=True, help="Comma separated list (no space). Ex: linux-x64,win-x86,osx-x64" + ) + parser.add_argument( + "--verify_nuget_signing", + help="Flag indicating if Nuget package signing is to be verified. " "Only accepts 'true' or 'false'", + ) return parser.parse_args() @@ -43,8 +62,9 @@ def is_windows(): return sys.platform.startswith("win") -def check_if_dlls_are_present(package_type, is_windows_ai_package, is_gpu_package, platforms_supported, - zip_file, package_path): +def check_if_dlls_are_present( + package_type, is_windows_ai_package, is_gpu_package, platforms_supported, zip_file, package_path +): platforms = platforms_supported.strip().split(",") if package_type == "tarball": file_list_in_package = list() @@ -55,7 +75,7 @@ def check_if_dlls_are_present(package_type, is_windows_ai_package, is_gpu_packag for platform in platforms: if platform.startswith("win"): - native_folder = '_native' if is_windows_ai_package else 'native' + native_folder = "_native" if is_windows_ai_package else "native" if package_type == "nuget": folder = "runtimes/" + platform + "/" + native_folder @@ -65,22 +85,22 @@ def check_if_dlls_are_present(package_type, is_windows_ai_package, is_gpu_packag header_folder = package_path + "/include" path = folder + "/" + "onnxruntime.dll" - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print("onnxruntime.dll not found for " + platform) raise Exception("onnxruntime.dll not found for " + platform) if is_gpu_package: for dll in win_gpu_package_libraries: path = folder + "/" + dll - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print(dll + " not found for " + platform) raise Exception(dll + " not found for " + platform) for header in gpu_related_header_files: path = header_folder + "/" + header - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print(header + " not found for " + platform) raise Exception(header + " not found for " + platform) @@ -93,29 +113,29 @@ def check_if_dlls_are_present(package_type, is_windows_ai_package, is_gpu_packag header_folder = package_path + "/include" path = folder + "/" + "libonnxruntime.so" - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print("libonnxruntime.so not found for " + platform) raise Exception("libonnxruntime.so not found for " + platform) if is_gpu_package: for so in linux_gpu_package_libraries: path = folder + "/" + so - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print(so + " not found for " + platform) raise Exception(so + " not found for " + platform) for header in gpu_related_header_files: path = header_folder + "/" + header - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print(header + " not found for " + platform) raise Exception(header + " not found for " + platform) elif platform.startswith("osx"): path = "runtimes/" + platform + "/native/libonnxruntime.dylib" - print('Checking path: ' + path) - if (path not in file_list_in_package): + print("Checking path: " + path) + if path not in file_list_in_package: print("libonnxruntime.dylib not found for " + platform) raise Exception("libonnxruntime.dylib not found for " + platform) @@ -124,13 +144,13 @@ def check_if_dlls_are_present(package_type, is_windows_ai_package, is_gpu_packag def check_if_nuget_is_signed(nuget_path): - code_sign_summary_file = glob.glob(os.path.join(nuget_path, '*.md')) - if (len(code_sign_summary_file) != 1): - print('CodeSignSummary files found in path: ') + code_sign_summary_file = glob.glob(os.path.join(nuget_path, "*.md")) + if len(code_sign_summary_file) != 1: + print("CodeSignSummary files found in path: ") print(code_sign_summary_file) - raise Exception('No CodeSignSummary files / more than one CodeSignSummary files found in the given path.') + raise Exception("No CodeSignSummary files / more than one CodeSignSummary files found in the given path.") - print('CodeSignSummary file: ' + code_sign_summary_file[0]) + print("CodeSignSummary file: " + code_sign_summary_file[0]) with open(code_sign_summary_file[0]) as f: contents = f.read() @@ -141,10 +161,10 @@ def check_if_nuget_is_signed(nuget_path): def validate_tarball(args): files = glob.glob(os.path.join(args.package_path, args.package_name)) - if (len(files) != 1): - print('packages found in path: ') + if len(files) != 1: + print("packages found in path: ") print(files) - raise Exception('No packages / more than one packages found in the given path.') + raise Exception("No packages / more than one packages found in the given path.") package_name = args.package_name if "-gpu-" in package_name.lower(): @@ -152,23 +172,24 @@ def validate_tarball(args): else: is_gpu_package = False - package_folder = re.search('(.*)[.].*', package_name).group(1) + package_folder = re.search("(.*)[.].*", package_name).group(1) - print('tar zxvf ' + package_name) + print("tar zxvf " + package_name) os.system("tar zxvf " + package_name) is_windows_ai_package = False zip_file = None - check_if_dlls_are_present(args.package_type, is_windows_ai_package, is_gpu_package, - args.platforms_supported, zip_file, package_folder) + check_if_dlls_are_present( + args.package_type, is_windows_ai_package, is_gpu_package, args.platforms_supported, zip_file, package_folder + ) def validate_zip(args): files = glob.glob(os.path.join(args.package_path, args.package_name)) - if (len(files) != 1): - print('packages found in path: ') + if len(files) != 1: + print("packages found in path: ") print(files) - raise Exception('No packages / more than one packages found in the given path.') + raise Exception("No packages / more than one packages found in the given path.") package_name = args.package_name if "-gpu-" in package_name.lower(): @@ -176,21 +197,22 @@ def validate_zip(args): else: is_gpu_package = False - package_folder = re.search('(.*)[.].*', package_name).group(1) + package_folder = re.search("(.*)[.].*", package_name).group(1) is_windows_ai_package = False zip_file = zipfile.ZipFile(package_name) - check_if_dlls_are_present(args.package_type, is_windows_ai_package, is_gpu_package, - args.platforms_supported, zip_file, package_folder) + check_if_dlls_are_present( + args.package_type, is_windows_ai_package, is_gpu_package, args.platforms_supported, zip_file, package_folder + ) def validate_nuget(args): files = glob.glob(os.path.join(args.package_path, args.package_name)) - nuget_packages_found_in_path = [i for i in files if i.endswith('.nupkg') and "Managed" not in i] - if (len(nuget_packages_found_in_path) != 1): - print('Nuget packages found in path: ') + nuget_packages_found_in_path = [i for i in files if i.endswith(".nupkg") and "Managed" not in i] + if len(nuget_packages_found_in_path) != 1: + print("Nuget packages found in path: ") print(nuget_packages_found_in_path) - raise Exception('No Nuget packages / more than one Nuget packages found in the given path.') + raise Exception("No Nuget packages / more than one Nuget packages found in the given path.") nuget_file_name = nuget_packages_found_in_path[0] full_nuget_path = os.path.join(args.package_path, nuget_file_name) @@ -215,10 +237,10 @@ def validate_nuget(args): # Do all validations here try: if not is_windows(): - raise Exception('Nuget validation is currently supported only on Windows') + raise Exception("Nuget validation is currently supported only on Windows") # Make a copy of the Nuget package - print('Copying [' + full_nuget_path + '] -> [' + nupkg_copy_name + '], and extracting its contents') + print("Copying [" + full_nuget_path + "] -> [" + nupkg_copy_name + "], and extracting its contents") os.system("copy " + full_nuget_path + " " + nupkg_copy_name) # Convert nupkg to zip @@ -226,27 +248,28 @@ def validate_nuget(args): zip_file = zipfile.ZipFile(zip_copy_name) # Check if the relevant dlls are present in the Nuget/Zip - print('Checking if the Nuget contains relevant dlls') - is_windows_ai_package = os.path.basename(full_nuget_path).startswith('Microsoft.AI.MachineLearning') - check_if_dlls_are_present(args.package_type, is_windows_ai_package, is_gpu_package, - args.platforms_supported, zip_file, None) + print("Checking if the Nuget contains relevant dlls") + is_windows_ai_package = os.path.basename(full_nuget_path).startswith("Microsoft.AI.MachineLearning") + check_if_dlls_are_present( + args.package_type, is_windows_ai_package, is_gpu_package, args.platforms_supported, zip_file, None + ) # Check if the Nuget has been signed - if (args.verify_nuget_signing != 'true' and args.verify_nuget_signing != 'false'): - raise Exception('Parameter verify_nuget_signing accepts only true or false as an argument') + if args.verify_nuget_signing != "true" and args.verify_nuget_signing != "false": + raise Exception("Parameter verify_nuget_signing accepts only true or false as an argument") - if (args.verify_nuget_signing == 'true'): - print('Verifying if Nuget has been signed') - if(not check_if_nuget_is_signed(args.package_path)): - print('Nuget signing verification failed') - raise Exception('Nuget signing verification failed') + if args.verify_nuget_signing == "true": + print("Verifying if Nuget has been signed") + if not check_if_nuget_is_signed(args.package_path): + print("Nuget signing verification failed") + raise Exception("Nuget signing verification failed") except Exception as e: print(e) exit_code = 1 finally: - print('Cleaning up after Nuget validation') + print("Cleaning up after Nuget validation") if zip_file is not None: zip_file.close() @@ -255,9 +278,9 @@ def validate_nuget(args): os.remove(zip_copy_name) if exit_code == 0: - print('Nuget validation was successful') + print("Nuget validation was successful") else: - raise Exception('Nuget validation was unsuccessful') + raise Exception("Nuget validation was unsuccessful") def main(): @@ -270,7 +293,7 @@ def main(): elif args.package_type == "zip": validate_zip(args) else: - print('Package type {} is not supported'.format(args.package_type)) + print("Package type {} is not supported".format(args.package_type)) if __name__ == "__main__": diff --git a/tools/python/FindOptimizerOpsetVersionUpdatesRequired.py b/tools/python/FindOptimizerOpsetVersionUpdatesRequired.py index 13f167166237b..d0c940ae9911c 100644 --- a/tools/python/FindOptimizerOpsetVersionUpdatesRequired.py +++ b/tools/python/FindOptimizerOpsetVersionUpdatesRequired.py @@ -14,10 +14,12 @@ def parse_args(): parser = argparse.ArgumentParser( - description='Find optimizers that involve operators which may need an update to the supported opset versions.') + description="Find optimizers that involve operators which may need an update to the supported opset versions." + ) - root_arg = parser.add_argument('--ort-root', '-o', required=True, type=str, - help='The root directory of the ONNX Runtime repository to search.') + root_arg = parser.add_argument( + "--ort-root", "-o", required=True, type=str, help="The root directory of the ONNX Runtime repository to search." + ) args = parser.parse_args() @@ -38,16 +40,18 @@ def get_call_args_from_file(filename, function_or_declaration): for match in re.finditer(function_or_declaration, line): # check we have both the opening and closing brackets for the function call/declaration. # if we do we have all the arguments - start = line.find('(', match.end()) - end = line.find(')', match.end()) + start = line.find("(", match.end()) + end = line.find(")", match.end()) have_all_args = start != -1 and end != -1 if have_all_args: - results.append(line[start + 1: end]) + results.append(line[start + 1 : end]) else: # TODO: handle automatically by merging lines - log.error("Call/Declaration is split over multiple lines. Please check manually." - "File:{} Line:{}".format(filename, line_num)) + log.error( + "Call/Declaration is split over multiple lines. Please check manually." + "File:{} Line:{}".format(filename, line_num) + ) continue line_num += 1 @@ -59,27 +63,29 @@ def get_latest_op_versions(root_dir): """Find the entries for the latest opset for each operator.""" op_to_opset = {} - files = [os.path.join(root_dir, "onnxruntime/core/providers/cpu/cpu_execution_provider.cc"), - os.path.join(root_dir, "onnxruntime/contrib_ops/cpu_contrib_kernels.cc")] + files = [ + os.path.join(root_dir, "onnxruntime/core/providers/cpu/cpu_execution_provider.cc"), + os.path.join(root_dir, "onnxruntime/contrib_ops/cpu_contrib_kernels.cc"), + ] for file in files: # e.g. class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, Clip); - calls = get_call_args_from_file(file, 'ONNX_OPERATOR_KERNEL_CLASS_NAME') + calls = get_call_args_from_file(file, "ONNX_OPERATOR_KERNEL_CLASS_NAME") for call in calls: - args = call.split(',') + args = call.split(",") domain = args[1].strip() opset = args[2].strip() op = args[3].strip() - op_to_opset[domain + '.' + op] = opset + op_to_opset[domain + "." + op] = opset # e.g. class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 11, float, ArgMax); - calls = get_call_args_from_file(file, 'ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME') + calls = get_call_args_from_file(file, "ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME") for call in calls: - args = call.split(',') + args = call.split(",") domain = args[1].strip() opset = args[2].strip() op = args[4].strip() - op_to_opset[domain + '.' + op] = opset + op_to_opset[domain + "." + op] = opset return op_to_opset @@ -88,31 +94,31 @@ def find_potential_issues(root_dir, op_to_opset): optimizer_dir = os.path.join(root_dir, "onnxruntime/core/optimizer") - files = glob.glob(optimizer_dir + '/**/*.cc', recursive=True) - files += glob.glob(optimizer_dir + '/**/*.h', recursive=True) + files = glob.glob(optimizer_dir + "/**/*.cc", recursive=True) + files += glob.glob(optimizer_dir + "/**/*.h", recursive=True) for file in files: - calls = get_call_args_from_file(file, 'graph_utils::IsSupportedOptypeVersionAndDomain') + calls = get_call_args_from_file(file, "graph_utils::IsSupportedOptypeVersionAndDomain") for call in calls: # Need to handle multiple comma separated version numbers, and the optional domain argument. # e.g. IsSupportedOptypeVersionAndDomain(node, "MaxPool", {1, 8, 10}) # IsSupportedOptypeVersionAndDomain(node, "FusedConv", {1}, kMSDomain) - args = call.split(',', 2) # first 2 args are simple, remainder need custom processing + args = call.split(",", 2) # first 2 args are simple, remainder need custom processing op = args[1].strip() versions_and_domain_arg = args[2] - v1 = versions_and_domain_arg.find('{') - v2 = versions_and_domain_arg.find('}') - versions = versions_and_domain_arg[v1 + 1: v2].split(',') + v1 = versions_and_domain_arg.find("{") + v2 = versions_and_domain_arg.find("}") + versions = versions_and_domain_arg[v1 + 1 : v2].split(",") last_version = versions[-1].strip() - domain_arg_start = versions_and_domain_arg.find(',', v2) + domain_arg_start = versions_and_domain_arg.find(",", v2) if domain_arg_start != -1: - domain = versions_and_domain_arg[domain_arg_start + 1:].strip() + domain = versions_and_domain_arg[domain_arg_start + 1 :].strip() else: domain = "kOnnxDomain" if op.startswith('"') and op.endswith('"'): - op = domain + '.' + op[1:-1] + op = domain + "." + op[1:-1] else: log.error("Symbolic name of '{}' found for op. Please check manually. File:{}".format(op, file)) continue @@ -120,13 +126,16 @@ def find_potential_issues(root_dir, op_to_opset): if op in op_to_opset: latest = op_to_opset[op] if int(latest) != int(last_version): - log.warning("Newer opset found for {}. Latest:{} Optimizer support ends at {}. File:{}" - .format(op, latest, last_version, file)) + log.warning( + "Newer opset found for {}. Latest:{} Optimizer support ends at {}. File:{}".format( + op, latest, last_version, file + ) + ) else: log.error("Failed to find version information for {}. File:{}".format(op, file)) -if __name__ == '__main__': +if __name__ == "__main__": arguments = parse_args() op_to_opset_map = get_latest_op_versions(arguments.ort_root) find_potential_issues(arguments.ort_root, op_to_opset_map) diff --git a/tools/python/check_onnx_model_mobile_usability.py b/tools/python/check_onnx_model_mobile_usability.py index a7ba2ab4cdaa8..7b8a09632398e 100644 --- a/tools/python/check_onnx_model_mobile_usability.py +++ b/tools/python/check_onnx_model_mobile_usability.py @@ -7,5 +7,5 @@ # in the ORT python package (where it must use relative imports) from util.check_onnx_model_mobile_usability import check_usability -if __name__ == '__main__': +if __name__ == "__main__": check_usability() diff --git a/tools/python/convert_onnx_models_to_ort.py b/tools/python/convert_onnx_models_to_ort.py index 7e22813957da0..131cbcbe70be5 100644 --- a/tools/python/convert_onnx_models_to_ort.py +++ b/tools/python/convert_onnx_models_to_ort.py @@ -7,6 +7,5 @@ # in the ORT python package (where it must use relative imports) from util.convert_onnx_models_to_ort import convert_onnx_models_to_ort - -if __name__ == '__main__': +if __name__ == "__main__": convert_onnx_models_to_ort() diff --git a/tools/python/create_reduced_build_config.py b/tools/python/create_reduced_build_config.py index 9062c977658f9..f7bbe5001c685 100644 --- a/tools/python/create_reduced_build_config.py +++ b/tools/python/create_reduced_build_config.py @@ -3,26 +3,27 @@ # Licensed under the MIT License. import argparse -import onnx import pathlib import sys import typing +import onnx from util.file_utils import files_from_file_or_dir, path_match_suffix_ignore_case def _get_suffix_match_predicate(suffix: str): def predicate(file_path: pathlib.Path): return path_match_suffix_ignore_case(file_path, suffix) + return predicate def _extract_ops_from_onnx_graph(graph, operators, domain_opset_map): - '''Extract ops from an ONNX graph and all subgraphs''' + """Extract ops from an ONNX graph and all subgraphs""" for operator in graph.node: # empty domain is used as an alias for 'ai.onnx' - domain = operator.domain if operator.domain else 'ai.onnx' + domain = operator.domain if operator.domain else "ai.onnx" if domain not in operators or domain not in domain_opset_map: continue @@ -35,7 +36,7 @@ def _extract_ops_from_onnx_graph(graph, operators, domain_opset_map): elif attr.type == onnx.AttributeProto.GRAPHS: # Currently no ONNX operators use GRAPHS. # Fail noisily if we encounter this so we can implement support - raise RuntimeError('Unexpected attribute proto of GRAPHS') + raise RuntimeError("Unexpected attribute proto of GRAPHS") def _process_onnx_model(model_path, required_ops): @@ -45,7 +46,7 @@ def _process_onnx_model(model_path, required_ops): domain_opset_map = {} for opset in model.opset_import: # empty domain == ai.onnx - domain = opset.domain if opset.domain else 'ai.onnx' + domain = opset.domain if opset.domain else "ai.onnx" domain_opset_map[domain] = opset.version if domain not in required_ops: @@ -60,7 +61,7 @@ def _process_onnx_model(model_path, required_ops): def _extract_ops_from_onnx_model(model_files: typing.Iterable[pathlib.Path]): - '''Extract ops from ONNX models''' + """Extract ops from ONNX models""" required_ops = {} @@ -78,7 +79,7 @@ def create_config_from_onnx_models(model_files: typing.Iterable[pathlib.Path], o output_file.parent.mkdir(parents=True, exist_ok=True) - with open(output_file, 'w') as out: + with open(output_file, "w") as out: out.write("# Generated from ONNX model/s:\n") for model_file in sorted(model_files): out.write(f"# - {model_file}\n") @@ -87,33 +88,46 @@ def create_config_from_onnx_models(model_files: typing.Iterable[pathlib.Path], o for opset in sorted(required_ops[domain].keys()): ops = required_ops[domain][opset] if ops: - out.write("{};{};{}\n".format(domain, opset, ','.join(sorted(ops)))) + out.write("{};{};{}\n".format(domain, opset, ",".join(sorted(ops)))) def main(): argparser = argparse.ArgumentParser( - 'Script to create a reduced build config file from either ONNX or ORT format model/s. ' - 'See /docs/Reduced_Operator_Kernel_build.md for more information on the configuration file format.', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - argparser.add_argument('-f', '--format', choices=['ONNX', 'ORT'], default='ONNX', - help='Format of model/s to process.') - argparser.add_argument('-t', '--enable_type_reduction', action='store_true', - help='Enable tracking of the specific types that individual operators require. ' - 'Operator implementations MAY support limiting the type support included in the build ' - 'to these types. Only possible with ORT format models.') - argparser.add_argument('model_path_or_dir', type=pathlib.Path, - help='Path to a single model, or a directory that will be recursively searched ' - 'for models to process.') - - argparser.add_argument('config_path', nargs='?', type=pathlib.Path, default=None, - help='Path to write configuration file to. Default is to write to required_operators.config ' - 'or required_operators_and_types.config in the same directory as the models.') + "Script to create a reduced build config file from either ONNX or ORT format model/s. " + "See /docs/Reduced_Operator_Kernel_build.md for more information on the configuration file format.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + argparser.add_argument( + "-f", "--format", choices=["ONNX", "ORT"], default="ONNX", help="Format of model/s to process." + ) + argparser.add_argument( + "-t", + "--enable_type_reduction", + action="store_true", + help="Enable tracking of the specific types that individual operators require. " + "Operator implementations MAY support limiting the type support included in the build " + "to these types. Only possible with ORT format models.", + ) + argparser.add_argument( + "model_path_or_dir", + type=pathlib.Path, + help="Path to a single model, or a directory that will be recursively searched " "for models to process.", + ) + + argparser.add_argument( + "config_path", + nargs="?", + type=pathlib.Path, + default=None, + help="Path to write configuration file to. Default is to write to required_operators.config " + "or required_operators_and_types.config in the same directory as the models.", + ) args = argparser.parse_args() - if args.enable_type_reduction and args.format == 'ONNX': - print('Type reduction requires model format to be ORT.', file=sys.stderr) + if args.enable_type_reduction and args.format == "ONNX": + print("Type reduction requires model format to be ORT.", file=sys.stderr) sys.exit(-1) model_path_or_dir = args.model_path_or_dir.resolve() @@ -123,10 +137,10 @@ def main(): config_path = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent if config_path.is_dir(): - filename = 'required_operators_and_types.config' if args.enable_type_reduction else 'required_operators.config' + filename = "required_operators_and_types.config" if args.enable_type_reduction else "required_operators.config" config_path = config_path.joinpath(filename) - if args.format == 'ONNX': + if args.format == "ONNX": model_files = files_from_file_or_dir(model_path_or_dir, _get_suffix_match_predicate(".onnx")) create_config_from_onnx_models(model_files, config_path) else: diff --git a/tools/python/dump_ort_model.py b/tools/python/dump_ort_model.py index b975172e3269e..ced54ae73adb1 100644 --- a/tools/python/dump_ort_model.py +++ b/tools/python/dump_ort_model.py @@ -6,38 +6,38 @@ import sys import typing -from util.ort_format_model.types import FbsTypeInfo # the import of FbsTypeInfo sets up the path so we can import ort_flatbuffers_py import ort_flatbuffers_py.fbs as fbs +from util.ort_format_model.types import FbsTypeInfo class OrtFormatModelDumper: - 'Class to dump an ORT format model.' + "Class to dump an ORT format model." def __init__(self, model_path: str): - ''' + """ Initialize ORT format model dumper :param model_path: Path to model - ''' - self._file = open(model_path, 'rb').read() + """ + self._file = open(model_path, "rb").read() self._buffer = bytearray(self._file) if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0): raise RuntimeError("File does not appear to be a valid ORT format model: '{}'".format(model_path)) self._model = fbs.InferenceSession.InferenceSession.GetRootAsInferenceSession(self._buffer, 0).Model() def _dump_initializers(self, graph: fbs.Graph): - print('Initializers:') + print("Initializers:") for idx in range(0, graph.InitializersLength()): tensor = graph.Initializers(idx) dims = [] for dim in range(0, tensor.DimsLength()): dims.append(tensor.Dims(dim)) - print(f'{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}') - print('--------') + print(f"{tensor.Name().decode()} data_type={tensor.DataType()} dims={dims}") + print("--------") def _dump_nodeargs(self, graph: fbs.Graph): - print('NodeArgs:') + print("NodeArgs:") for idx in range(0, graph.NodeArgsLength()): node_arg = graph.NodeArgs(idx) type = node_arg.Type() @@ -62,30 +62,32 @@ def _dump_nodeargs(self, graph: fbs.Graph): elif d.DimType() == fbs.DimensionValueType.DimensionValueType.PARAM: dims.append(d.DimParam().decode()) else: - dims.append('?') + dims.append("?") else: dims = None - print(f'{node_arg.Name().decode()} type={type_str} dims={dims}') - print('--------') + print(f"{node_arg.Name().decode()} type={type_str} dims={dims}") + print("--------") def _dump_node(self, node: fbs.Node): optype = node.OpType().decode() - domain = node.Domain().decode() or 'ai.onnx' # empty domain defaults to ai.onnx + domain = node.Domain().decode() or "ai.onnx" # empty domain defaults to ai.onnx inputs = [node.Inputs(i).decode() for i in range(0, node.InputsLength())] outputs = [node.Outputs(i).decode() for i in range(0, node.OutputsLength())] - print(f'{node.Index()}:{node.Name().decode()}({domain}:{optype}) ' - f'inputs=[{",".join(inputs)} outputs=[{",".join(outputs)}]') + print( + f"{node.Index()}:{node.Name().decode()}({domain}:{optype}) " + f'inputs=[{",".join(inputs)} outputs=[{",".join(outputs)}]' + ) def _dump_graph(self, graph: fbs.Graph): - ''' + """ Process one level of the Graph, descending into any subgraphs when they are found - ''' + """ self._dump_initializers(graph) self._dump_nodeargs(graph) - print('Nodes:') + print("Nodes:") for i in range(0, graph.NodesLength()): node = graph.Nodes(i) self._dump_node(node) @@ -95,17 +97,17 @@ def _dump_graph(self, graph: fbs.Graph): attr = node.Attributes(j) attr_type = attr.Type() if attr_type == fbs.AttributeType.AttributeType.GRAPH: - print(f'## Subgraph for {node.OpType().decode()}.{attr.Name().decode()} ##') + print(f"## Subgraph for {node.OpType().decode()}.{attr.Name().decode()} ##") self._dump_graph(attr.G()) - print(f'## End {node.OpType().decode()}.{attr.Name().decode()} Subgraph ##') + print(f"## End {node.OpType().decode()}.{attr.Name().decode()} Subgraph ##") elif attr_type == fbs.AttributeType.AttributeType.GRAPHS: # the ONNX spec doesn't currently define any operators that have multiple graphs in an attribute # so entering this 'elif' isn't currently possible - print(f'## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##') + print(f"## Subgraphs for {node.OpType().decode()}.{attr.Name().decode()} ##") for k in range(0, attr.GraphsLength()): - print(f'## Subgraph {k} ##') + print(f"## Subgraph {k} ##") self._dump_graph(attr.Graphs(k)) - print(f'## End Subgraph {k} ##') + print(f"## End Subgraph {k} ##") def dump(self, output: typing.IO): graph = self._model.Graph() @@ -117,14 +119,15 @@ def dump(self, output: typing.IO): def parse_args(): - parser = argparse.ArgumentParser(os.path.basename(__file__), - description='Dump an ORT format model. Output is to .txt') - parser.add_argument('--stdout', action='store_true', help='Dump to stdout instead of writing to file.') - parser.add_argument('model_path', help='Path to ORT format model') + parser = argparse.ArgumentParser( + os.path.basename(__file__), description="Dump an ORT format model. Output is to .txt" + ) + parser.add_argument("--stdout", action="store_true", help="Dump to stdout instead of writing to file.") + parser.add_argument("model_path", help="Path to ORT format model") args = parser.parse_args() if not os.path.isfile(args.model_path): - parser.error(f'{args.model_path} is not a file.') + parser.error(f"{args.model_path} is not a file.") return args @@ -141,5 +144,5 @@ def main(): d.dump(ofile) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/python/dump_subgraphs.py b/tools/python/dump_subgraphs.py index 52036eda45833..a1b9782374ca7 100644 --- a/tools/python/dump_subgraphs.py +++ b/tools/python/dump_subgraphs.py @@ -1,14 +1,15 @@ -import onnx import argparse import os +import onnx + def export_and_recurse(node, attribute, output_dir, level): name = node.name - name = name.replace('/', '_') + name = name.replace("/", "_") sub_model = onnx.ModelProto() sub_model.graph.MergeFrom(attribute.g) - filename = 'L' + str(level) + '_' + node.op_type + '_' + attribute.name + '_' + name + '.onnx' + filename = "L" + str(level) + "_" + node.op_type + "_" + attribute.name + "_" + name + ".onnx" onnx.save_model(sub_model, os.path.join(output_dir, filename)) dump_subgraph(sub_model, output_dir, level + 1) @@ -18,20 +19,21 @@ def dump_subgraph(model, output_dir, level=0): for node in graph.node: if node.op_type == "Scan" or node.op_type == "Loop": - body_attribute = list(filter(lambda attr: attr.name == 'body', node.attribute))[0] + body_attribute = list(filter(lambda attr: attr.name == "body", node.attribute))[0] export_and_recurse(node, body_attribute, output_dir, level) if node.op_type == "If": - then_attribute = list(filter(lambda attr: attr.name == 'then_branch', node.attribute))[0] - else_attribute = list(filter(lambda attr: attr.name == 'else_branch', node.attribute))[0] + then_attribute = list(filter(lambda attr: attr.name == "then_branch", node.attribute))[0] + else_attribute = list(filter(lambda attr: attr.name == "else_branch", node.attribute))[0] export_and_recurse(node, then_attribute, output_dir, level) export_and_recurse(node, else_attribute, output_dir, level) def parse_args(): - parser = argparse.ArgumentParser(os.path.basename(__file__), - description='Dump all subgraphs from an ONNX model into separate onnx files.') - parser.add_argument('-m', '--model', required=True, help='model file') - parser.add_argument('-o', '--out', required=True, help='output directory') + parser = argparse.ArgumentParser( + os.path.basename(__file__), description="Dump all subgraphs from an ONNX model into separate onnx files." + ) + parser.add_argument("-m", "--model", required=True, help="model file") + parser.add_argument("-o", "--out", required=True, help="output directory") return parser.parse_args() @@ -48,5 +50,5 @@ def main(): dump_subgraph(model, out) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/python/example_operator_perf_test.py b/tools/python/example_operator_perf_test.py index 6bcca0f8210a7..50a3edd5c9b27 100644 --- a/tools/python/example_operator_perf_test.py +++ b/tools/python/example_operator_perf_test.py @@ -3,18 +3,19 @@ input combinations. """ -import onnx -from onnx import helper -from onnx import TensorProto -import numpy as np import time import timeit -import onnxruntime as rt + +import numpy as np +import onnx # if you copy this script elsewhere you may need to add the tools\python dir to the sys.path for this # import to work. # e.g. sys.path.append(r'\tools\python') import ort_test_dir_utils +from onnx import TensorProto, helper + +import onnxruntime as rt # make input deterministic np.random.seed(123) @@ -26,22 +27,26 @@ def create_model(model_name): graph_def = helper.make_graph( nodes=[ - helper.make_node(op_type="TopK", inputs=['X', 'K'], outputs=['Values', 'Indices'], name='topk', - # attributes are also key-value pairs using the attribute name and appropriate type - largest=1), + helper.make_node( + op_type="TopK", + inputs=["X", "K"], + outputs=["Values", "Indices"], + name="topk", + # attributes are also key-value pairs using the attribute name and appropriate type + largest=1, + ), ], - name='test-model', + name="test-model", inputs=[ # create inputs with symbolic dims so we can use any input sizes - helper.make_tensor_value_info("X", TensorProto.FLOAT, ['batch', 'items']), + helper.make_tensor_value_info("X", TensorProto.FLOAT, ["batch", "items"]), helper.make_tensor_value_info("K", TensorProto.INT64, [1]), ], outputs=[ - helper.make_tensor_value_info("Values", TensorProto.FLOAT, ['batch', 'k']), - helper.make_tensor_value_info("Indices", TensorProto.INT64, ['batch', 'k']), + helper.make_tensor_value_info("Values", TensorProto.FLOAT, ["batch", "k"]), + helper.make_tensor_value_info("Indices", TensorProto.INT64, ["batch", "k"]), ], - initializer=[ - ] + initializer=[], ) model = helper.make_model(graph_def, opset_imports=[helper.make_operatorsetid("", 11)]) @@ -56,7 +61,7 @@ def create_model(model_name): def create_test_input(n, num_items, k): x = np.random.randn(n, num_items).astype(np.float32) k_in = np.asarray([k]).astype(np.int64) - inputs = {'X': x, 'K': k_in} + inputs = {"X": x, "K": k_in} return inputs @@ -98,12 +103,12 @@ def run_test(): # ignore the outputs as we're not validating them in a performance test sess.run(None, inputs) end = time.time_ns() - assert (end - start > 0) + assert end - start > 0 total += end - start total_iters += iters # Adjust the output you want as needed - print('n={},items={},k={},avg:{:.4f}'.format(n, num_items, k, total / total_iters)) + print("n={},items={},k={},avg:{:.4f}".format(n, num_items, k, total / total_iters)) # combine the various input parameters and create input for each test for n in batches: @@ -126,21 +131,21 @@ def create_example_test_directory(): # fill in the inputs that we want to use specific values for input_data = {} - input_data['K'] = np.asarray([64]).astype(np.int64) + input_data["K"] = np.asarray([64]).astype(np.int64) # provide symbolic dim values as needed - symbolic_dim_values = {'batch': 25, 'items': 256} + symbolic_dim_values = {"batch": 25, "items": 256} # create the directory. random input will be created for any missing inputs. # the model will be run and the output will be saved as expected output for future runs - ort_test_dir_utils.create_test_dir('topk.onnx', 'PerfTests', 'test1', input_data, symbolic_dim_values) + ort_test_dir_utils.create_test_dir("topk.onnx", "PerfTests", "test1", input_data, symbolic_dim_values) # this will create the model file in the current directory -create_model('topk.onnx') +create_model("topk.onnx") # this will create a test directory that can be used with onnx_test_runner or onnxruntime_perf_test create_example_test_directory() # this can loop over various combinations of input, using the specified number of threads -run_perf_tests('topk.onnx', 1) +run_perf_tests("topk.onnx", 1) diff --git a/tools/python/gen_contrib_doc.py b/tools/python/gen_contrib_doc.py index ec59f1e21673d..bad97451cf6c0 100644 --- a/tools/python/gen_contrib_doc.py +++ b/tools/python/gen_contrib_doc.py @@ -2,44 +2,41 @@ # This file is copied and adapted from https://github.com/onnx/onnx repository. # There was no copyright statement on the file at the time of copying. -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -from __future__ import unicode_literals +from __future__ import absolute_import, division, print_function, unicode_literals -from collections import defaultdict +import argparse import io import os import pathlib import sys -import argparse +from collections import defaultdict +from typing import Any, Dict, List, Sequence, Set, Text, Tuple import numpy as np # type: ignore +from onnx import AttributeProto, FunctionProto import onnxruntime.capi.onnxruntime_pybind11_state as rtpy from onnxruntime.capi.onnxruntime_pybind11_state import schemadef # noqa: F401 from onnxruntime.capi.onnxruntime_pybind11_state.schemadef import OpSchema # noqa: F401 -from typing import Any, Text, Sequence, Dict, List, Set, Tuple -from onnx import AttributeProto, FunctionProto -ONNX_ML = not bool(os.getenv('ONNX_ML') == '0') +ONNX_ML = not bool(os.getenv("ONNX_ML") == "0") ONNX_DOMAIN = "onnx" ONNX_ML_DOMAIN = "onnx-ml" if ONNX_ML: - ext = '-ml.md' + ext = "-ml.md" else: - ext = '.md' + ext = ".md" def display_number(v): # type: (int) -> Text if OpSchema.is_infinite(v): - return '∞' + return "∞" return Text(v) def should_render_domain(domain, domain_filter): # type: (Text) -> bool - if domain == ONNX_DOMAIN or domain == '' or domain == ONNX_ML_DOMAIN or domain == 'ai.onnx.ml': + if domain == ONNX_DOMAIN or domain == "" or domain == ONNX_ML_DOMAIN or domain == "ai.onnx.ml": return False if domain_filter and domain not in domain_filter: @@ -50,21 +47,21 @@ def should_render_domain(domain, domain_filter): # type: (Text) -> bool def format_name_with_domain(domain, schema_name): # type: (Text, Text) -> Text if domain: - return '{}.{}'.format(domain, schema_name) + return "{}.{}".format(domain, schema_name) else: return schema_name def format_name_with_version(schema_name, version): # type: (Text, Text) -> Text - return '{}-{}'.format(schema_name, version) + return "{}-{}".format(schema_name, version) def display_attr_type(v): # type: (OpSchema.AttrType) -> Text assert isinstance(v, OpSchema.AttrType) s = Text(v) - s = s[s.rfind('.') + 1:].lower() - if s[-1] == 's': - s = 'list of ' + s + s = s[s.rfind(".") + 1 :].lower() + if s[-1] == "s": + s = "list of " + s return s @@ -79,31 +76,31 @@ def display_domain_short(domain): # type: (Text) -> Text if domain: return domain else: - return 'ai.onnx (default)' + return "ai.onnx (default)" def display_version_link(name, version): # type: (Text, int) -> Text - changelog_md = 'Changelog' + ext - name_with_ver = '{}-{}'.format(name, version) + changelog_md = "Changelog" + ext + name_with_ver = "{}-{}".format(name, version) return '{}'.format(changelog_md, name_with_ver, name_with_ver) def display_function_version_link(name, version): # type: (Text, int) -> Text - changelog_md = 'FunctionsChangelog' + ext - name_with_ver = '{}-{}'.format(name, version) + changelog_md = "FunctionsChangelog" + ext + name_with_ver = "{}-{}".format(name, version) return '{}'.format(changelog_md, name_with_ver, name_with_ver) def get_attribute_value(attr): # type: (AttributeProto) -> Any - if attr.HasField('f'): + if attr.HasField("f"): return attr.f - elif attr.HasField('i'): + elif attr.HasField("i"): return attr.i - elif attr.HasField('s'): + elif attr.HasField("s"): return attr.s - elif attr.HasField('t'): + elif attr.HasField("t"): return attr.t - elif attr.HasField('g'): + elif attr.HasField("g"): return attr.g elif len(attr.floats): return list(attr.floats) @@ -120,30 +117,34 @@ def get_attribute_value(attr): # type: (AttributeProto) -> Any def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> Text - s = '' + s = "" # doc schemadoc = schema.doc if schemadoc: - s += '\n' - s += '\n'.join(' ' + line - for line in schemadoc.lstrip().splitlines()) - s += '\n' + s += "\n" + s += "\n".join(" " + line for line in schemadoc.lstrip().splitlines()) + s += "\n" # since version - s += '\n#### Version\n' + s += "\n#### Version\n" if schema.support_level == OpSchema.SupportType.EXPERIMENTAL: - s += '\nNo versioning maintained for experimental ops.' + s += "\nNo versioning maintained for experimental ops." else: - s += '\nThis version of the operator has been ' + ('deprecated' if schema.deprecated else 'available') + \ - ' since version {}'.format(schema.since_version) - s += ' of {}.\n'.format(display_domain(schema.domain)) + s += ( + "\nThis version of the operator has been " + + ("deprecated" if schema.deprecated else "available") + + " since version {}".format(schema.since_version) + ) + s += " of {}.\n".format(display_domain(schema.domain)) if len(versions) > 1: # TODO: link to the Changelog.md - s += '\nOther versions of this operator: {}\n'.format( - ', '.join(format_name_with_version( - format_name_with_domain(v.domain, v.name), v.since_version) - for v in versions[:-1])) + s += "\nOther versions of this operator: {}\n".format( + ", ".join( + format_name_with_version(format_name_with_domain(v.domain, v.name), v.since_version) + for v in versions[:-1] + ) + ) # If this schema is deprecated, don't display any of the following sections if schema.deprecated: @@ -152,46 +153,44 @@ def display_schema(schema, versions): # type: (OpSchema, Sequence[OpSchema]) -> # attributes attribs = schema.attributes if attribs: - s += '\n#### Attributes\n\n' - s += '
\n' + s += "\n#### Attributes\n\n" + s += "
\n" for _, attr in sorted(attribs.items()): # option holds either required or default value - opt = '' + opt = "" if attr.required: - opt = 'required' - elif hasattr(attr, 'default_value') and attr.default_value.name: + opt = "required" + elif hasattr(attr, "default_value") and attr.default_value.name: default_value = get_attribute_value(attr.default_value) def format_value(value): # type: (Any) -> Text if isinstance(value, float): value = np.round(value, 5) if isinstance(value, (bytes, bytearray)) and sys.version_info[0] == 3: - value = value.decode('utf-8') + value = value.decode("utf-8") return str(value) if isinstance(default_value, list): default_value = [format_value(val) for val in default_value] else: default_value = format_value(default_value) - opt = 'default is {}'.format(default_value) + opt = "default is {}".format(default_value) - s += '
{} : {}{}
\n'.format( - attr.name, - display_attr_type(attr.type), - ' ({})'.format(opt) if opt else '') - s += '
{}
\n'.format(attr.description) - s += '
\n' + s += "
{} : {}{}
\n".format( + attr.name, display_attr_type(attr.type), " ({})".format(opt) if opt else "" + ) + s += "
{}
\n".format(attr.description) + s += "
\n" # inputs - s += '\n#### Inputs' + s += "\n#### Inputs" if schema.min_input != schema.max_input: - s += ' ({} - {})'.format(display_number(schema.min_input), - display_number(schema.max_input)) - s += '\n\n' + s += " ({} - {})".format(display_number(schema.min_input), display_number(schema.max_input)) + s += "\n\n" inputs = schema.inputs if inputs: - s += '
\n' + s += "
\n" for inp in inputs: option_str = "" if OpSchema.FormalParameterOption.Optional == inp.option: @@ -201,20 +200,19 @@ def format_value(value): # type: (Any) -> Text option_str = " (variadic)" else: option_str = " (variadic, heterogeneous)" - s += '
{}{} : {}
\n'.format(inp.name, option_str, inp.typeStr) - s += '
{}
\n'.format(inp.description) + s += "
{}{} : {}
\n".format(inp.name, option_str, inp.typeStr) + s += "
{}
\n".format(inp.description) - s += '
\n' + s += "
\n" # outputs - s += '\n#### Outputs' + s += "\n#### Outputs" if schema.min_output != schema.max_output: - s += ' ({} - {})'.format(display_number(schema.min_output), - display_number(schema.max_output)) - s += '\n\n' + s += " ({} - {})".format(display_number(schema.min_output), display_number(schema.max_output)) + s += "\n\n" outputs = schema.outputs if outputs: - s += '
\n' + s += "
\n" for output in outputs: option_str = "" if OpSchema.FormalParameterOption.Optional == output.option: @@ -224,88 +222,89 @@ def format_value(value): # type: (Any) -> Text option_str = " (variadic)" else: option_str = " (variadic, heterogeneous)" - s += '
{}{} : {}
\n'.format(output.name, option_str, output.typeStr) - s += '
{}
\n'.format(output.description) + s += "
{}{} : {}
\n".format(output.name, option_str, output.typeStr) + s += "
{}
\n".format(output.description) - s += '
\n' + s += "
\n" # type constraints - s += '\n#### Type Constraints' - s += '\n\n' + s += "\n#### Type Constraints" + s += "\n\n" typecons = schema.type_constraints if typecons: - s += '
\n' + s += "
\n" for type_constraint in typecons: allowed_types = type_constraint.allowed_type_strs - allowed_type_str = '' - if (len(allowed_types) > 0): + allowed_type_str = "" + if len(allowed_types) > 0: allowed_type_str = allowed_types[0] for allowedType in allowed_types[1:]: - allowed_type_str += ', ' + allowedType - s += '
{} : {}
\n'.format( - type_constraint.type_param_str, allowed_type_str) - s += '
{}
\n'.format(type_constraint.description) - s += '
\n' + allowed_type_str += ", " + allowedType + s += "
{} : {}
\n".format(type_constraint.type_param_str, allowed_type_str) + s += "
{}
\n".format(type_constraint.description) + s += "
\n" return s def display_function(function, versions, domain=ONNX_DOMAIN): # type: (FunctionProto, List[int], Text) -> Text - s = '' + s = "" if domain: - domain_prefix = '{}.'.format(ONNX_ML_DOMAIN) + domain_prefix = "{}.".format(ONNX_ML_DOMAIN) else: - domain_prefix = '' + domain_prefix = "" # doc if function.doc_string: - s += '\n' - s += '\n'.join(' ' + line - for line in function.doc_string.lstrip().splitlines()) - s += '\n' + s += "\n" + s += "\n".join(" " + line for line in function.doc_string.lstrip().splitlines()) + s += "\n" # since version - s += '\n#### Version\n' - s += '\nThis version of the function has been available since version {}'.format(function.since_version) - s += ' of {}.\n'.format(display_domain(domain_prefix)) + s += "\n#### Version\n" + s += "\nThis version of the function has been available since version {}".format(function.since_version) + s += " of {}.\n".format(display_domain(domain_prefix)) if len(versions) > 1: - s += '\nOther versions of this function: {}\n'.format( - ', '.join(display_function_version_link(domain_prefix + function.name, v) - for v in versions if v != function.since_version)) + s += "\nOther versions of this function: {}\n".format( + ", ".join( + display_function_version_link(domain_prefix + function.name, v) + for v in versions + if v != function.since_version + ) + ) # inputs - s += '\n#### Inputs' - s += '\n\n' + s += "\n#### Inputs" + s += "\n\n" if function.input: - s += '
\n' + s += "
\n" for input in function.input: - s += '
{};
\n'.format(input) - s += '
\n' + s += "
{};
\n".format(input) + s += "
\n" # outputs - s += '\n#### Outputs' - s += '\n\n' + s += "\n#### Outputs" + s += "\n\n" if function.output: - s += '
\n' + s += "
\n" for output in function.output: - s += '
{};
\n'.format(output) - s += '
\n' + s += "
{};
\n".format(output) + s += "
\n" # attributes if function.attribute: - s += '\n#### Attributes\n\n' - s += '
\n' + s += "\n#### Attributes\n\n" + s += "
\n" for attr in function.attribute: - s += '
{};
\n'.format(attr) - s += '
\n' + s += "
{};
\n".format(attr) + s += "
\n" return s def support_level_str(level): # type: (OpSchema.SupportType) -> Text - return \ - "experimental " if level == OpSchema.SupportType.EXPERIMENTAL else "" + return "experimental " if level == OpSchema.SupportType.EXPERIMENTAL else "" # def function_status_str(status=OperatorStatus.Value("EXPERIMENTAL")): # type: ignore @@ -315,24 +314,29 @@ def support_level_str(level): # type: (OpSchema.SupportType) -> Text def main(output_path: str, domain_filter: [str]): - with io.open(output_path, 'w', newline='', encoding="utf-8") as fout: - fout.write('## Contrib Operator Schemas\n') + with io.open(output_path, "w", newline="", encoding="utf-8") as fout: + fout.write("## Contrib Operator Schemas\n") fout.write( "*This file is automatically generated from the registered contrib operator schemas by " "[this script](https://github.com/microsoft/onnxruntime/blob/master/tools/python/gen_contrib_doc.py).\n" - "Do not modify directly.*\n") + "Do not modify directly.*\n" + ) # domain -> support level -> name -> [schema] - index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]] # noqa: E501 + index = defaultdict( + lambda: defaultdict(lambda: defaultdict(list)) + ) # type: Dict[Text, Dict[int, Dict[Text, List[OpSchema]]]] # noqa: E501 for schema in rtpy.get_all_operator_schema(): index[schema.domain][int(schema.support_level)][schema.name].append(schema) - fout.write('\n') + fout.write("\n") # Preprocess the Operator Schemas # [(domain, [(support_level, [(schema name, current schema, all versions schemas)])])] - operator_schemas = list() # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]] # noqa: E501 + operator_schemas = ( + list() + ) # type: List[Tuple[Text, List[Tuple[int, List[Tuple[Text, OpSchema, List[OpSchema]]]]]]] # noqa: E501 exsting_ops = set() # type: Set[Text] for domain, _supportmap in sorted(index.items()): if not should_render_domain(domain, domain_filter): @@ -353,7 +357,7 @@ def main(output_path: str, domain_filter: [str]): # Table of contents for domain, supportmap in operator_schemas: - s = '* {}\n'.format(display_domain_short(domain)) + s = "* {}\n".format(display_domain_short(domain)) fout.write(s) for _, namemap in supportmap: @@ -361,39 +365,51 @@ def main(output_path: str, domain_filter: [str]): s = ' * {}{}\n'.format( support_level_str(schema.support_level), format_name_with_domain(domain, n), - format_name_with_domain(domain, n)) + format_name_with_domain(domain, n), + ) fout.write(s) - fout.write('\n') + fout.write("\n") for domain, supportmap in operator_schemas: - s = '## {}\n'.format(display_domain_short(domain)) + s = "## {}\n".format(display_domain_short(domain)) fout.write(s) for _, namemap in supportmap: for op_type, schema, versions in namemap: # op_type - s = ('### {}**{}**' + (' (deprecated)' if schema.deprecated else '') - + '\n').format( + s = ( + '### {}**{}**' + + (" (deprecated)" if schema.deprecated else "") + + "\n" + ).format( support_level_str(schema.support_level), format_name_with_domain(domain, op_type), format_name_with_domain(domain, op_type.lower()), - format_name_with_domain(domain, op_type)) + format_name_with_domain(domain, op_type), + ) s += display_schema(schema, versions) - s += '\n\n' + s += "\n\n" fout.write(s) -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ONNX Runtime Contrib Operator Documentation Generator') - parser.add_argument('--domains', nargs='+', - help="Filter to specified domains. " - "e.g. `--domains com.microsoft com.microsoft.nchwc`") - parser.add_argument('--output_path', help='output markdown file path', type=pathlib.Path, required=True, - default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'ContribOperators.md')) +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ONNX Runtime Contrib Operator Documentation Generator") + parser.add_argument( + "--domains", + nargs="+", + help="Filter to specified domains. " "e.g. `--domains com.microsoft com.microsoft.nchwc`", + ) + parser.add_argument( + "--output_path", + help="output markdown file path", + type=pathlib.Path, + required=True, + default=os.path.join(os.path.dirname(os.path.realpath(__file__)), "ContribOperators.md"), + ) args = parser.parse_args() output_path = args.output_path.resolve() diff --git a/tools/python/gen_opkernel_doc.py b/tools/python/gen_opkernel_doc.py index 992115eebd2c2..640bc22ba7255 100644 --- a/tools/python/gen_opkernel_doc.py +++ b/tools/python/gen_opkernel_doc.py @@ -12,38 +12,38 @@ def format_version_range(v): - if (v[1] >= 2147483647): - return str(v[0])+'+' + if v[1] >= 2147483647: + return str(v[0]) + "+" else: - if (v[0] == v[1]): + if v[0] == v[1]: return str(v[0]) else: - return '['+str(v[0])+', '+str(v[1])+']' + return "[" + str(v[0]) + ", " + str(v[1]) + "]" def format_type_constraints(tc): counter = 0 - tcstr = '' + tcstr = "" firsttcitem = True for tcitem in tc: counter += 1 if firsttcitem: firsttcitem = False else: - tcstr += ', ' + tcstr += ", " tcstr += tcitem return tcstr def format_param_strings(params): firstparam = True - s = '' + s = "" if params: for param in sorted(params): if firstparam: firstparam = False else: - s += '

or

' + s += "

or

" s += param return s @@ -53,8 +53,8 @@ def expand_providers(provider_filter: [str]): if provider_filter: for provider in provider_filter: p = provider.lower() - if not p.endswith('executionprovider'): - p += 'executionprovider' + if not p.endswith("executionprovider"): + p += "executionprovider" providers.add(p) return providers @@ -64,29 +64,30 @@ def main(output_path: pathlib.Path, provider_filter: [str]): providers = expand_providers(provider_filter) - with io.open(output_path, 'w', newline='', encoding="utf-8") as fout: - fout.write('## Supported Operators and Data Types\n') + with io.open(output_path, "w", newline="", encoding="utf-8") as fout: + fout.write("## Supported Operators and Data Types\n") fout.write( "*This file is automatically generated from the registered kernels by " "[this script](https://github.com/microsoft/onnxruntime/blob/master/tools/python/gen_opkernel_doc.py).\n" - "Do not modify directly.*\n\n") + "Do not modify directly.*\n\n" + ) opdef = rtpy.get_all_operator_schema() paramdict = {} for schema in opdef: inputs = schema.inputs domain = schema.domain - if (domain == ''): - domain = 'ai.onnx' - fullname = domain+'.'+schema.name - paramstr = '' + if domain == "": + domain = "ai.onnx" + fullname = domain + "." + schema.name + paramstr = "" firstinput = True if inputs: for inp in inputs: if firstinput: firstinput = False else: - paramstr += '
' - paramstr += '*in* {}:**{}**'.format(inp.name, inp.typeStr) + paramstr += "
" + paramstr += "*in* {}:**{}**".format(inp.name, inp.typeStr) outputs = schema.outputs if outputs: @@ -94,10 +95,10 @@ def main(output_path: pathlib.Path, provider_filter: [str]): if firstinput: firstinput = False else: - paramstr += '
' - paramstr += '*out* {}:**{}**'.format(outp.name, outp.typeStr) + paramstr += "
" + paramstr += "*out* {}:**{}**".format(outp.name, outp.typeStr) - paramstr += '' + paramstr += "" paramset = paramdict.get(fullname, None) if paramset is None: paramdict[fullname] = set() @@ -107,28 +108,28 @@ def main(output_path: pathlib.Path, provider_filter: [str]): index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) for op in rtpy.get_all_opkernel_def(): domain = op.domain - if (domain == ''): - domain = 'ai.onnx' + if domain == "": + domain = "ai.onnx" index[op.provider][domain][op.op_name].append(op) # TOC - fout.write('## Execution Providers\n\n') + fout.write("## Execution Providers\n\n") for provider in sorted(index.keys()): if providers and provider.lower() not in providers: continue - fout.write('- [{}](#{})\n'.format(provider, provider.lower())) - fout.write('\n---------------') + fout.write("- [{}](#{})\n".format(provider, provider.lower())) + fout.write("\n---------------") for provider, domainmap in sorted(index.items()): if providers and provider.lower() not in providers: continue fout.write('\n\n\n\n'.format(provider.lower())) - fout.write('## Operators implemented by {}\n\n'.format(provider)) - fout.write('| Op Name | Parameters | OpSet Version | Types Supported |\n') - fout.write('|---------|------------|---------------|-----------------|\n') + fout.write("## Operators implemented by {}\n\n".format(provider)) + fout.write("| Op Name | Parameters | OpSet Version | Types Supported |\n") + fout.write("|---------|------------|---------------|-----------------|\n") for domain, namemap in sorted(domainmap.items()): - fout.write('|**Operator Domain:** *'+domain+'*||||\n') + fout.write("|**Operator Domain:** *" + domain + "*||||\n") for name, ops in sorted(namemap.items()): version_type_index = defaultdict(lambda: defaultdict(set)) for op in ops: @@ -138,36 +139,44 @@ def main(output_path: pathlib.Path, provider_filter: [str]): namefirsttime = True for version_range, typemap in sorted(version_type_index.items(), key=lambda x: x[0], reverse=True): - if (namefirsttime): - params = paramdict.get(domain+'.'+name, None) - fout.write('|' + name + '|' + format_param_strings(params) + '|') + if namefirsttime: + params = paramdict.get(domain + "." + name, None) + fout.write("|" + name + "|" + format_param_strings(params) + "|") namefirsttime = False else: - fout.write('|||') - fout.write(format_version_range(version_range) + '|') + fout.write("|||") + fout.write(format_version_range(version_range) + "|") tnameindex = 0 for tname, tcset in sorted(typemap.items()): tnameindex += 1 tclist = [] for tc in sorted(tcset): tclist.append(tc) - fout.write('**'+tname+'** = '+format_type_constraints(tclist)) - if (tnameindex < len(typemap)): - fout.write('
') - fout.write('|\n') - - fout.write('| |\n| |\n') - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(description='ONNX Runtime Operator Kernel Documentation Generator') - parser.add_argument('--providers', nargs='+', - help="Filter to specified execution providers. Case-insensitive. " - "Matches provider names from /include/onnxruntime/core/graph/constants.h'. " - "'ExecutionProvider' is automatically appended as needed. " - "e.g. `--providers cpu cuda` will match CPUExecutionProvider and CUDAExecutionProvider.") - parser.add_argument('--output_path', help='output markdown file path', type=pathlib.Path, required=True, - default=os.path.join(os.path.dirname(os.path.realpath(__file__)), 'OperatorKernels.md')) + fout.write("**" + tname + "** = " + format_type_constraints(tclist)) + if tnameindex < len(typemap): + fout.write("
") + fout.write("|\n") + + fout.write("| |\n| |\n") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="ONNX Runtime Operator Kernel Documentation Generator") + parser.add_argument( + "--providers", + nargs="+", + help="Filter to specified execution providers. Case-insensitive. " + "Matches provider names from /include/onnxruntime/core/graph/constants.h'. " + "'ExecutionProvider' is automatically appended as needed. " + "e.g. `--providers cpu cuda` will match CPUExecutionProvider and CUDAExecutionProvider.", + ) + parser.add_argument( + "--output_path", + help="output markdown file path", + type=pathlib.Path, + required=True, + default=os.path.join(os.path.dirname(os.path.realpath(__file__)), "OperatorKernels.md"), + ) args = parser.parse_args() main(args.output_path, args.providers) diff --git a/tools/python/gen_ort_mobile_pkg_doc.py b/tools/python/gen_ort_mobile_pkg_doc.py index b8de140556402..5818c362d23b2 100644 --- a/tools/python/gen_ort_mobile_pkg_doc.py +++ b/tools/python/gen_ort_mobile_pkg_doc.py @@ -1,29 +1,33 @@ import argparse import os import pathlib + from util import reduced_build_config_parser from util.ort_format_model.operator_type_usage_processors import GloballyAllowedTypesOpTypeImplFilter def generate_docs(output_file, required_ops, op_type_impl_filter): - with open(output_file, 'w') as out: - out.write('# ONNX Runtime Mobile Pre-Built Package Operator and Type Support\n\n') + with open(output_file, "w") as out: + out.write("# ONNX Runtime Mobile Pre-Built Package Operator and Type Support\n\n") # Description - out.write('## Supported operators and types\n\n') - out.write('The supported operators and types are based on what is required to support float32 and quantized ' - 'versions of popular models. The full list of input models used to determine this list is available ' - '[here](https://github.com/microsoft/onnxruntime/blob/master/tools/ci_build/github/android/mobile_package.required_operators.readme.txt)') # noqa - out.write('\n\n') + out.write("## Supported operators and types\n\n") + out.write( + "The supported operators and types are based on what is required to support float32 and quantized " + "versions of popular models. The full list of input models used to determine this list is available " + "[here](https://github.com/microsoft/onnxruntime/blob/master/tools/ci_build/github/android/mobile_package" + ".required_operators.readme.txt)" + ) + out.write("\n\n") # Globally supported types - out.write('## Supported data input types\n\n') - assert(op_type_impl_filter.__class__ is GloballyAllowedTypesOpTypeImplFilter) + out.write("## Supported data input types\n\n") + assert op_type_impl_filter.__class__ is GloballyAllowedTypesOpTypeImplFilter global_types = op_type_impl_filter.global_type_list() for type in sorted(global_types): - out.write(' - {}\n'.format(type)) - out.write('\n') - out.write('NOTE: Operators used to manipulate dimensions and indices will support int32 and int64.\n\n') + out.write(" - {}\n".format(type)) + out.write("\n") + out.write("NOTE: Operators used to manipulate dimensions and indices will support int32 and int64.\n\n") domain_op_opsets = [] for domain in sorted(required_ops.keys()): @@ -32,41 +36,53 @@ def generate_docs(output_file, required_ops, op_type_impl_filter): for opset in sorted(required_ops[domain].keys()): str_opset = str(opset) for op in required_ops[domain][opset]: - op_with_domain = '{}:{}'.format(domain, op) + op_with_domain = "{}:{}".format(domain, op) if op_with_domain not in op_opsets: op_opsets[op_with_domain] = [] op_opsets[op_with_domain].append(str_opset) - out.write('## Supported Operators\n\n') - out.write('|Operator|Opsets|\n') - out.write('|--------|------|\n') + out.write("## Supported Operators\n\n") + out.write("|Operator|Opsets|\n") + out.write("|--------|------|\n") for domain, op_opsets in domain_op_opsets: - out.write('|**{}**||\n'.format(domain)) + out.write("|**{}**||\n".format(domain)) for op in sorted(op_opsets.keys()): - out.write('|{}|{}|\n'.format(op, ', '.join(op_opsets[op]))) - out.write('|||\n') + out.write("|{}|{}|\n".format(op, ", ".join(op_opsets[op]))) + out.write("|||\n") def main(): script_dir = os.path.dirname(os.path.realpath(__file__)) parser = argparse.ArgumentParser( - description='ONNX Runtime Mobile Pre-Built Package Operator and Type Support Documentation Generator', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - default_config_path = \ - pathlib.Path(os.path.join(script_dir, '../ci_build/github/android/mobile_package.required_operators.config') - ).resolve() - - default_output_path = \ - pathlib.Path(os.path.join(script_dir, '../../docs/ORTMobilePackageOperatorTypeSupport.md')).resolve() - - parser.add_argument('--config_path', help='Path to build configuration used to generate package.', required=False, - type=pathlib.Path, default=default_config_path) - - parser.add_argument('--output_path', help='output markdown file path', required=False, - type=pathlib.Path, default=default_output_path) + description="ONNX Runtime Mobile Pre-Built Package Operator and Type Support Documentation Generator", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + default_config_path = pathlib.Path( + os.path.join(script_dir, "../ci_build/github/android/mobile_package.required_operators.config") + ).resolve() + + default_output_path = pathlib.Path( + os.path.join(script_dir, "../../docs/ORTMobilePackageOperatorTypeSupport.md") + ).resolve() + + parser.add_argument( + "--config_path", + help="Path to build configuration used to generate package.", + required=False, + type=pathlib.Path, + default=default_config_path, + ) + + parser.add_argument( + "--output_path", + help="output markdown file path", + required=False, + type=pathlib.Path, + default=default_output_path, + ) args = parser.parse_args() config_file = args.config_path.resolve(strict=True) # must exist so strict=True @@ -77,5 +93,5 @@ def main(): generate_docs(output_path, required_ops, op_type_impl_filter) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/python/get_submodules.py b/tools/python/get_submodules.py index c2992c9aabb5f..88ac53a2a86c8 100644 --- a/tools/python/get_submodules.py +++ b/tools/python/get_submodules.py @@ -1,8 +1,8 @@ -from pathlib import Path import argparse import configparser import json import re +from pathlib import Path import pygit2 @@ -29,9 +29,9 @@ def lookup_submodule(repo, submodule_path): pass config = configparser.ConfigParser() - config.read(Path(repo.workdir, '.gitmodules')) + config.read(Path(repo.workdir, ".gitmodules")) for section in config.sections(): - if config[section]['path'] == submodule_path: + if config[section]["path"] == submodule_path: name = re.fullmatch('submodule "(.*)"', section).group(1) submodule = repo.lookup_submodule(name) return submodule @@ -56,7 +56,7 @@ def recursive_process(base_repo): def main(repo_path, output_file): repo = pygit2.Repository(repo_path) registrations = recursive_process(repo) - with open(output_file, 'w') as f: + with open(output_file, "w") as f: json.dump(registrations, f, indent=4, sort_keys=True) diff --git a/tools/python/onnx_test_data_utils.py b/tools/python/onnx_test_data_utils.py index e4e4c54aa12d6..c7670e4a84258 100644 --- a/tools/python/onnx_test_data_utils.py +++ b/tools/python/onnx_test_data_utils.py @@ -4,7 +4,6 @@ import sys import numpy as np - import onnx from onnx import numpy_helper @@ -31,7 +30,7 @@ def dump_pb(dir_or_filename): All files must contain a serialized TensorProto.""" if os.path.isdir(dir_or_filename): - for f in glob.glob(os.path.join(dir_or_filename, '*.pb')): + for f in glob.glob(os.path.join(dir_or_filename, "*.pb")): print(f) dump_tensorproto_pb_file(f) else: @@ -115,77 +114,93 @@ def get_arg_parser(): update_name_in_pb: Update the TensorProto.name value in a pb file. Updates the input file unless --output is specified. """, - formatter_class=argparse.RawDescriptionHelpFormatter + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "--action", + help="Action to perform", + choices=["dump_pb", "numpy_to_pb", "image_to_pb", "random_to_pb", "update_name_in_pb"], + required=True, + ) + + parser.add_argument("--input", help="The input filename or directory name") + parser.add_argument("--name", help="The value to set TensorProto.name to if creating/updating one.") + parser.add_argument("--output", help="Filename to serialize the TensorProto to.") + + image_to_pb_group = parser.add_argument_group("image_to_pb", "image_to_pb specific options") + image_to_pb_group.add_argument( + "--resize", + default=None, + type=lambda s: [int(item) for item in s.split(",")], + help="Provide the height and width to resize to as comma separated values." + " e.g. --shape 200,300 will resize to height 200 and width 300.", + ) + image_to_pb_group.add_argument( + "--channels_last", action="store_true", help="Transpose image from channels first to channels last." + ) + image_to_pb_group.add_argument( + "--add_batch_dim", + action="store_true", + help="Prepend a batch dimension with value of 1 to the shape. " "i.e. convert from CHW to NCHW", ) - parser.add_argument('--action', help='Action to perform', - choices=['dump_pb', 'numpy_to_pb', 'image_to_pb', 'random_to_pb', 'update_name_in_pb'], - required=True) - - parser.add_argument('--input', help='The input filename or directory name') - parser.add_argument('--name', help='The value to set TensorProto.name to if creating/updating one.') - parser.add_argument('--output', help='Filename to serialize the TensorProto to.') - - image_to_pb_group = parser.add_argument_group('image_to_pb', - 'image_to_pb specific options') - image_to_pb_group.add_argument('--resize', default=None, type=lambda s: [int(item) for item in s.split(',')], - help='Provide the height and width to resize to as comma separated values.' - ' e.g. --shape 200,300 will resize to height 200 and width 300.') - image_to_pb_group.add_argument('--channels_last', action='store_true', - help='Transpose image from channels first to channels last.') - image_to_pb_group.add_argument('--add_batch_dim', action='store_true', - help='Prepend a batch dimension with value of 1 to the shape. ' - 'i.e. convert from CHW to NCHW') - - random_to_pb_group = parser.add_argument_group('random_to_pb', - 'random_to_pb specific options') - random_to_pb_group.add_argument('--shape', type=lambda s: [int(item) for item in s.split(',')], - help='Provide the shape as comma separated values e.g. --shape 200,200') - random_to_pb_group.add_argument('--datatype', - help="numpy dtype value for the data type. e.g. f4=float32, i8=int64. " - "See: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html") - random_to_pb_group.add_argument('--min_value', default=0, type=int, - help="Limit the generated values to this minimum.") - random_to_pb_group.add_argument('--max_value', default=1, type=int, - help="Limit the generated values to this maximum.") - random_to_pb_group.add_argument('--seed', default=None, type=int, - help="seed to use for the random values so they're deterministic.") + random_to_pb_group = parser.add_argument_group("random_to_pb", "random_to_pb specific options") + random_to_pb_group.add_argument( + "--shape", + type=lambda s: [int(item) for item in s.split(",")], + help="Provide the shape as comma separated values e.g. --shape 200,200", + ) + random_to_pb_group.add_argument( + "--datatype", + help="numpy dtype value for the data type. e.g. f4=float32, i8=int64. " + "See: https://docs.scipy.org/doc/numpy/reference/arrays.dtypes.html", + ) + random_to_pb_group.add_argument( + "--min_value", default=0, type=int, help="Limit the generated values to this minimum." + ) + random_to_pb_group.add_argument( + "--max_value", default=1, type=int, help="Limit the generated values to this maximum." + ) + random_to_pb_group.add_argument( + "--seed", default=None, type=int, help="seed to use for the random values so they're deterministic." + ) return parser -if __name__ == '__main__': +if __name__ == "__main__": arg_parser = get_arg_parser() args = arg_parser.parse_args() - if args.action == 'dump_pb': + if args.action == "dump_pb": if not args.input: print("Missing argument. Need input to be specified.", file=sys.stderr) sys.exit(-1) np.set_printoptions(precision=10) dump_pb(args.input) - elif args.action == 'numpy_to_pb': + elif args.action == "numpy_to_pb": if not args.input or not args.output or not args.name: print("Missing argument. Need input, output and name to be specified.", file=sys.stderr) sys.exit(-1) # read data saved with numpy data = np.load(args.input) numpy_to_pb(args.name, data, args.output) - elif args.action == 'image_to_pb': + elif args.action == "image_to_pb": if not args.input or not args.output or not args.name: print("Missing argument. Need input, output, name to be specified.", file=sys.stderr) sys.exit(-1) img_np = image_to_numpy(args.input, args.resize, args.channels_last, args.add_batch_dim) numpy_to_pb(args.name, img_np, args.output) - elif args.action == 'random_to_pb': + elif args.action == "random_to_pb": if not args.output or not args.shape or not args.datatype or not args.name: print("Missing argument. Need output, shape, datatype and name to be specified.", file=sys.stderr) sys.exit(-1) data = create_random_data(args.shape, args.datatype, args.min_value, args.max_value, args.seed) numpy_to_pb(args.name, data, args.output) - elif args.action == 'update_name_in_pb': + elif args.action == "update_name_in_pb": if not args.input or not args.name: print("Missing argument. Need input and name to be specified.", file=sys.stderr) sys.exit(-1) diff --git a/tools/python/ort_test_dir_utils.py b/tools/python/ort_test_dir_utils.py index 22037673f0274..1b5525b9169d2 100644 --- a/tools/python/ort_test_dir_utils.py +++ b/tools/python/ort_test_dir_utils.py @@ -1,19 +1,20 @@ import glob -import numpy as np -import onnx -import onnx_test_data_utils -import onnxruntime as ort import os import shutil +import numpy as np +import onnx +import onnx_test_data_utils from onnx import numpy_helper +import onnxruntime as ort + def _get_numpy_type(model_info, name): for i in model_info: if i.name == name: - type_name = i.type.WhichOneof('value') - if type_name == 'tensor_type': + type_name = i.type.WhichOneof("value") + if type_name == "tensor_type": return onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[i.type.tensor_type.elem_type] else: raise ValueError("Type is not handled: {}".format(type_name)) @@ -36,17 +37,17 @@ def _create_missing_input_data(model_inputs, name_input_map, symbolic_dim_values # models whose ir_version < 4 can have input same as initializer; no need to create input data if input.name in initializer_set: continue - input_type = input.type.WhichOneof('value') - if input_type != 'tensor_type': - raise ValueError('Unsupported model. Need to handle input type of {}'.format(input_type)) + input_type = input.type.WhichOneof("value") + if input_type != "tensor_type": + raise ValueError("Unsupported model. Need to handle input type of {}".format(input_type)) shape = input.type.tensor_type.shape dims = [] for dim in shape.dim: - dim_type = dim.WhichOneof('value') - if dim_type == 'dim_value': + dim_type = dim.WhichOneof("value") + if dim_type == "dim_value": dims.append(dim.dim_value) - elif dim_type == 'dim_param': + elif dim_type == "dim_param": if dim.dim_param not in symbolic_dim_values_map: raise ValueError("Value for symbolic dim {} was not provided.".format(dim.dim_param)) @@ -63,9 +64,9 @@ def _create_missing_input_data(model_inputs, name_input_map, symbolic_dim_values name_input_map[input.name] = data -def create_test_dir(model_path, root_path, test_name, - name_input_map=None, symbolic_dim_values_map=None, - name_output_map=None): +def create_test_dir( + model_path, root_path, test_name, name_input_map=None, symbolic_dim_values_map=None, name_output_map=None +): """ Create a test directory that can be used with onnx_test_runner or onnxruntime_perf_test. Generates random input data for any missing inputs. @@ -119,7 +120,7 @@ def save_data(prefix, name_data_map, model_info): np_type = _get_numpy_type(model_info, name) tensor = numpy_helper.from_array(data.astype(np_type), name) filename = os.path.join(test_data_dir, "{}_{}.pb".format(prefix, idx)) - with open(filename, 'wb') as f: + with open(filename, "wb") as f: f.write(tensor.SerializeToString()) idx += 1 @@ -159,8 +160,8 @@ def read_test_dir(dir_name): inputs = {} outputs = {} - input_files = glob.glob(os.path.join(dir_name, 'input_*.pb')) - output_files = glob.glob(os.path.join(dir_name, 'output_*.pb')) + input_files = glob.glob(os.path.join(dir_name, "input_*.pb")) + output_files = glob.glob(os.path.join(dir_name, "output_*.pb")) for i in input_files: name, data = onnx_test_data_utils.read_tensorproto_pb_file(i) @@ -186,12 +187,14 @@ def run_test_dir(model_or_dir): if os.path.isdir(model_or_dir): model_dir = os.path.abspath(model_or_dir) # check there's only one onnx file - onnx_models = glob.glob(os.path.join(model_dir, '*.onnx')) - ort_models = glob.glob(os.path.join(model_dir, '*.ort')) + onnx_models = glob.glob(os.path.join(model_dir, "*.onnx")) + ort_models = glob.glob(os.path.join(model_dir, "*.ort")) models = onnx_models + ort_models if len(models) > 1: - raise ValueError("'Multiple .onnx and/or .ort files found in {}. '" - "'Please provide specific .onnx or .ort file as input.".format(model_dir)) + raise ValueError( + "'Multiple .onnx and/or .ort files found in {}. '" + "'Please provide specific .onnx or .ort file as input.".format(model_dir) + ) elif len(models) == 0: raise ValueError("'No .onnx or .ort files found in {}.".format(model_dir)) @@ -200,9 +203,9 @@ def run_test_dir(model_or_dir): model_path = os.path.abspath(model_or_dir) model_dir = os.path.dirname(model_path) - print('Running tests in {} for {}'.format(model_dir, model_path)) + print("Running tests in {} for {}".format(model_dir, model_path)) - test_dirs = [d for d in glob.glob(os.path.join(model_dir, 'test*')) if os.path.isdir(d)] + test_dirs = [d for d in glob.glob(os.path.join(model_dir, "test*")) if os.path.isdir(d)] if not test_dirs: raise ValueError("No directories with name starting with 'test' were found in {}.".format(model_dir)) @@ -216,11 +219,11 @@ def run_test_dir(model_or_dir): output_names = list(expected_outputs.keys()) # handle case where there's a single expected output file but no name in it (empty string for name) # e.g. ONNX test models 20190729\opset8\tf_mobilenet_v2_1.4_224 - if len(output_names) == 1 and output_names[0] == '': + if len(output_names) == 1 and output_names[0] == "": output_names = [o.name for o in sess.get_outputs()] - assert len(output_names) == 1, 'There should be single output_name.' - expected_outputs[output_names[0]] = expected_outputs[''] - expected_outputs.pop('') + assert len(output_names) == 1, "There should be single output_name." + expected_outputs[output_names[0]] = expected_outputs[""] + expected_outputs.pop("") else: output_names = [o.name for o in sess.get_outputs()] @@ -232,15 +235,15 @@ def run_test_dir(model_or_dir): expected = expected_outputs[output_names[idx]] actual = run_outputs[idx] - if expected.dtype.char in np.typecodes['AllFloat']: - if not np.isclose(expected, actual, rtol=1.e-3, atol=1.e-3).all(): - print('Mismatch for {}:\nExpected:{}\nGot:{}'.format(output_names[idx], expected, actual)) + if expected.dtype.char in np.typecodes["AllFloat"]: + if not np.isclose(expected, actual, rtol=1.0e-3, atol=1.0e-3).all(): + print("Mismatch for {}:\nExpected:{}\nGot:{}".format(output_names[idx], expected, actual)) failed = True else: if not np.equal(expected, actual).all(): - print('Mismatch for {}:\nExpected:{}\nGot:{}'.format(output_names[idx], expected, actual)) + print("Mismatch for {}:\nExpected:{}\nGot:{}".format(output_names[idx], expected, actual)) failed = True if failed: - raise ValueError('FAILED due to output mismatch.') + raise ValueError("FAILED due to output mismatch.") else: - print('PASS') + print("PASS") diff --git a/tools/python/remove_initializer_from_input.py b/tools/python/remove_initializer_from_input.py index 9111dd3cd5e40..935d5a44c75fe 100644 --- a/tools/python/remove_initializer_from_input.py +++ b/tools/python/remove_initializer_from_input.py @@ -1,6 +1,7 @@ -import onnx import argparse +import onnx + def get_args(): parser = argparse.ArgumentParser() @@ -15,9 +16,7 @@ def remove_initializer_from_input(): model = onnx.load(args.input) if model.ir_version < 4: - print( - 'Model with ir_version below 4 requires to include initilizer in graph input' - ) + print("Model with ir_version below 4 requires to include initilizer in graph input") return inputs = model.graph.input @@ -32,5 +31,5 @@ def remove_initializer_from_input(): onnx.save(model, args.output) -if __name__ == '__main__': +if __name__ == "__main__": remove_initializer_from_input() diff --git a/tools/python/run_android_emulator.py b/tools/python/run_android_emulator.py index f818fc5f4cc76..31793fbbc1340 100755 --- a/tools/python/run_android_emulator.py +++ b/tools/python/run_android_emulator.py @@ -7,9 +7,8 @@ import shlex import sys -from util import get_logger import util.android as android - +from util import get_logger log = get_logger("run_android_emulator") @@ -18,32 +17,29 @@ def parse_args(): parser = argparse.ArgumentParser( description="Manages the running of an Android emulator. " "Supported modes are to start and stop (default), only start, or only " - "stop the emulator.") + "stop the emulator." + ) - parser.add_argument( - "--create-avd", action="store_true", - help="Whether to create the Android virtual device.") + parser.add_argument("--create-avd", action="store_true", help="Whether to create the Android virtual device.") - parser.add_argument( - "--start", action="store_true", help="Start the emulator.") - parser.add_argument( - "--stop", action="store_true", help="Stop the emulator.") + parser.add_argument("--start", action="store_true", help="Start the emulator.") + parser.add_argument("--stop", action="store_true", help="Stop the emulator.") + parser.add_argument("--android-sdk-root", required=True, help="Path to the Android SDK root.") parser.add_argument( - "--android-sdk-root", required=True, help="Path to the Android SDK root.") - parser.add_argument( - "--system-image", default="system-images;android-29;google_apis;x86_64", - help="The Android system image package name.") - parser.add_argument( - "--avd-name", default="ort_android", - help="The Android virtual device name.") + "--system-image", + default="system-images;android-29;google_apis;x86_64", + help="The Android system image package name.", + ) + parser.add_argument("--avd-name", default="ort_android", help="The Android virtual device name.") parser.add_argument( - "--emulator-extra-args", default="", - help="A string of extra arguments to pass to the Android emulator.") + "--emulator-extra-args", default="", help="A string of extra arguments to pass to the Android emulator." + ) parser.add_argument( "--emulator-pid-file", help="Output/input file containing the PID of the emulator process. " - "This is only required if exactly one of --start or --stop is given.") + "This is only required if exactly one of --start or --stop is given.", + ) args = parser.parse_args() diff --git a/tools/python/sparsify_initializers.py b/tools/python/sparsify_initializers.py index df461f1a9219c..17bddae6bbe40 100644 --- a/tools/python/sparsify_initializers.py +++ b/tools/python/sparsify_initializers.py @@ -8,9 +8,10 @@ import argparse import logging -import numpy as np import sys -from typing import Tuple, List +from typing import List, Tuple + +import numpy as np import onnx from onnx import ModelProto, SparseTensorProto, TensorProto, numpy_helper @@ -21,16 +22,20 @@ def parse_arguments(): parser = argparse.ArgumentParser() - parser.add_argument('--input', required=True, type=str, help='input model path') - parser.add_argument('--output', required=True, type=str, help='output model path') - parser.add_argument('--exclude', required=False, type=str, - help='semicolon separated list of initializer names to exclude') - parser.add_argument('--tolerance', required=False, type=float, default=1e-6, - help='FP absolute tolerance.') - parser.add_argument('--sparsity_threshold', required=False, - type=float, default=0.5, - help='convert to sparse initializers if sparsity is at least this much') - parser.add_argument('--verbose', required=False, action='store_true') + parser.add_argument("--input", required=True, type=str, help="input model path") + parser.add_argument("--output", required=True, type=str, help="output model path") + parser.add_argument( + "--exclude", required=False, type=str, help="semicolon separated list of initializer names to exclude" + ) + parser.add_argument("--tolerance", required=False, type=float, default=1e-6, help="FP absolute tolerance.") + parser.add_argument( + "--sparsity_threshold", + required=False, + type=float, + default=0.5, + help="convert to sparse initializers if sparsity is at least this much", + ) + parser.add_argument("--verbose", required=False, action="store_true") parser.set_defaults(verbose=False) args = parser.parse_args() return args @@ -39,21 +44,20 @@ def parse_arguments(): def setup_logging(verbose): # type: (bool) -> None log_handler = logging.StreamHandler(sys.stdout) if verbose: - log_handler.setFormatter(logging.Formatter('[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s')) + log_handler.setFormatter(logging.Formatter("[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")) logging_level = logging.DEBUG else: - log_handler.setFormatter(logging.Formatter('%(filename)20s: %(message)s')) + log_handler.setFormatter(logging.Formatter("%(filename)20s: %(message)s")) logging_level = logging.INFO log_handler.setLevel(logging_level) logger.addHandler(log_handler) logger.setLevel(logging_level) -def convert_tensor_to_sparse(tensor, - sparsity_threshold, - tolerance): # type: (TensorProto, float, float) -> Tuple[SparseTensorProto, float] - """ returns a tuple of sparse_tensor and sparsity level - """ +def convert_tensor_to_sparse( + tensor, sparsity_threshold, tolerance +): # type: (TensorProto, float, float) -> Tuple[SparseTensorProto, float] + """returns a tuple of sparse_tensor and sparsity level""" values = [] indices = [] nnz_count = 0 @@ -74,7 +78,7 @@ def convert_tensor_to_sparse(tensor, indices.append(index) nnz_count += 1 - sparsity = float(1.) - float(nnz_count)/data_len + sparsity = float(1.0) - float(nnz_count) / data_len ind_data_type = TensorProto.INT8 ind_dtype = np.int8 @@ -95,9 +99,11 @@ def convert_tensor_to_sparse(tensor, ind_data_type = TensorProto.INT64 ind_dtype = np.int64 - logger.debug(f"initializer={tensor.name}, dtype={tensor_data.dtype}, \ + logger.debug( + f"initializer={tensor.name}, dtype={tensor_data.dtype}, \ data_len={data_len}, nnz={nnz_count}, sparsity={sparsity}, \ - max_indices_value={max_indices_value}, sparse_indices_type={ind_dtype}") + max_indices_value={max_indices_value}, sparse_indices_type={ind_dtype}" + ) if sparsity < sparsity_threshold: return (object(), sparsity) @@ -109,8 +115,10 @@ def convert_tensor_to_sparse(tensor, np_indices = np.array(indices).astype(ind_dtype) total_sparse_bytes = np_values.nbytes + np_indices.nbytes - logger.debug(f"initializer={tensor.name}, initializer_bytes={tensor_data_bytes}, \ - sparse_initializer_bytes={total_sparse_bytes}") + logger.debug( + f"initializer={tensor.name}, initializer_bytes={tensor_data_bytes}, \ + sparse_initializer_bytes={total_sparse_bytes}" + ) # This check is usually useful for sparsity_threshold=0.5 where much # depends on the size of the indices entries and the size of the original tensor. @@ -118,30 +126,23 @@ def convert_tensor_to_sparse(tensor, # int32 indices are often selected, thus we really want to guard against loosing # rather than winning. if tensor_data_bytes <= total_sparse_bytes: - sparsity = float(1.) - float(tensor_data_bytes)/total_sparse_bytes + sparsity = float(1.0) - float(tensor_data_bytes) / total_sparse_bytes logger.debug(f"initializer={tensor.name}, adjusted_sparsity={sparsity}") return (object(), sparsity) - values_tensor = onnx.helper.make_tensor(tensor.name, - tensor.data_type, - [len(values)], - np_values.tobytes(), - raw=True) + values_tensor = onnx.helper.make_tensor(tensor.name, tensor.data_type, [len(values)], np_values.tobytes(), raw=True) - indicies_tensor = onnx.helper.make_tensor(tensor.name + '_indicies', - ind_data_type, - [ind_len], - np_indices.tobytes(), - raw=True) + indicies_tensor = onnx.helper.make_tensor( + tensor.name + "_indicies", ind_data_type, [ind_len], np_indices.tobytes(), raw=True + ) sparse_tensor = onnx.helper.make_sparse_tensor(values_tensor, indicies_tensor, tensor.dims) return (sparse_tensor, sparsity) -def convert_initializers(model, - exclude_names, - sparsity_threshold, - tolerance): # type: (ModelProto, List[str], float, float) -> None +def convert_initializers( + model, exclude_names, sparsity_threshold, tolerance +): # type: (ModelProto, List[str], float, float) -> None graph = model.graph converted_sparse = [] remaining_initializers = [] @@ -170,7 +171,7 @@ def main(): args = parse_arguments() setup_logging(args.verbose) - exclude_names = set() if args.exclude is None else set(args.exclude.split(';')) + exclude_names = set() if args.exclude is None else set(args.exclude.split(";")) model = ModelProto() with open(args.input, "rb") as input_file: diff --git a/tools/python/update_version.py b/tools/python/update_version.py index 419f9e9013fb6..ca94bacd03b9a 100755 --- a/tools/python/update_version.py +++ b/tools/python/update_version.py @@ -2,121 +2,130 @@ def update_version(): - version = '' + version = "" cwd = os.path.dirname(os.path.realpath(__file__)) - with open(os.path.join(cwd, '..', '..', 'VERSION_NUMBER')) as f: + with open(os.path.join(cwd, "..", "..", "VERSION_NUMBER")) as f: version = f.readline().strip() lines = [] - current_version = '' - file_path = os.path.join(cwd, '..', '..', 'docs', 'Versioning.md') + current_version = "" + file_path = os.path.join(cwd, "..", "..", "docs", "Versioning.md") with open(file_path) as f: lines = f.readlines() for line in lines: - if line.startswith('|'): - sections = line.split('|') + if line.startswith("|"): + sections = line.split("|") if len(sections) == 8 and sections[1].strip()[0].isdigit(): current_version = sections[1].strip() break - print('Current version of ORT seems to be: ' + current_version) + print("Current version of ORT seems to be: " + current_version) if version != current_version: - with open(file_path, 'w') as f: + with open(file_path, "w") as f: for i, line in enumerate(lines): f.write(line) - if line.startswith('|--'): - sections = lines[i+1].split('|') + if line.startswith("|--"): + sections = lines[i + 1].split("|") # Make sure there are no 'False Positive' version additions # by making sure the line we are building a new line from # contains the current_version if len(sections) > 1 and sections[1].strip() == current_version: - sections[1] = ' ' + version + ' ' - new_line = '|'.join(sections) + sections[1] = " " + version + " " + new_line = "|".join(sections) f.write(new_line) lines = [] - current_version = '' - file_path = os.path.join(cwd, '..', '..', 'docs', 'python', 'README.rst') + current_version = "" + file_path = os.path.join(cwd, "..", "..", "docs", "python", "README.rst") with open(file_path) as f: lines = f.readlines() for line in lines: - sections = line.strip().split('.') + sections = line.strip().split(".") if len(sections) == 3 and sections[0].isdigit() and sections[1].isdigit() and sections[2].isdigit(): current_version = line.strip() break if version != current_version: inserted = False - with open(file_path, 'w') as f: + with open(file_path, "w") as f: for line in lines: - sections = line.strip().split('.') - if inserted is False and len(sections) == 3 and \ - sections[0].isdigit() and sections[1].isdigit() and sections[2].isdigit(): - f.write(version + '\n') - f.write('^'*len(version) + '\n\n') - f.write('Release Notes : https://github.com/Microsoft/onnxruntime/releases/tag/v' - + version.strip() + '\n\n') + sections = line.strip().split(".") + if ( + inserted is False + and len(sections) == 3 + and sections[0].isdigit() + and sections[1].isdigit() + and sections[2].isdigit() + ): + f.write(version + "\n") + f.write("^" * len(version) + "\n\n") + f.write( + "Release Notes : https://github.com/Microsoft/onnxruntime/releases/tag/v" + + version.strip() + + "\n\n" + ) inserted = True f.write(line) lines = [] - current_version = '' - file_path = os.path.join(cwd, '..', '..', 'package', 'rpm', 'onnxruntime.spec') + current_version = "" + file_path = os.path.join(cwd, "..", "..", "package", "rpm", "onnxruntime.spec") with open(file_path) as f: lines = f.readlines() for line in lines: - if line.startswith('Version:'): - current_version = line.split(':')[1].strip() + if line.startswith("Version:"): + current_version = line.split(":")[1].strip() break if version != current_version: - with open(file_path, 'w') as f: + with open(file_path, "w") as f: for line in lines: - if line.startswith('Version:'): - f.write('Version: ' + version + '\n') + if line.startswith("Version:"): + f.write("Version: " + version + "\n") continue f.write(line) lines = [] - current_version = '' - file_path = os.path.join(cwd, '..', '..', 'onnxruntime', '__init__.py') + current_version = "" + file_path = os.path.join(cwd, "..", "..", "onnxruntime", "__init__.py") with open(file_path) as f: lines = f.readlines() for line in lines: - if line.startswith('__version__'): - current_version = line.split('=')[1].strip()[1:-1] + if line.startswith("__version__"): + current_version = line.split("=")[1].strip()[1:-1] break if version != current_version: - with open(file_path, 'w') as f: + with open(file_path, "w") as f: for line in lines: - if line.startswith('__version__'): + if line.startswith("__version__"): f.write('__version__ = "' + version + '"\n') continue f.write(line) # update version for NPM packages - current_version = '' - js_root = os.path.join(cwd, '..', '..', 'js') + current_version = "" + js_root = os.path.join(cwd, "..", "..", "js") def run(args, cwd): - from util import run, is_windows + from util import is_windows, run + if is_windows(): - args = ['cmd', '/c'] + args + args = ["cmd", "/c"] + args run(*args, cwd=cwd) # check if node, npm and yarn are installed - run(['node', '--version'], cwd=js_root) - run(['npm', '--version'], cwd=js_root) - run(['yarn', '--version'], cwd=js_root) + run(["node", "--version"], cwd=js_root) + run(["npm", "--version"], cwd=js_root) + run(["yarn", "--version"], cwd=js_root) # upgrade version for onnxruntime-common - run(['npm', 'version', version], cwd=os.path.join(js_root, 'common')) - run(['npm', 'install', '--package-lock-only', '--ignore-scripts'], cwd=os.path.join(js_root, 'common')) + run(["npm", "version", version], cwd=os.path.join(js_root, "common")) + run(["npm", "install", "--package-lock-only", "--ignore-scripts"], cwd=os.path.join(js_root, "common")) # upgrade version for onnxruntime-node - run(['npm', 'version', version], cwd=os.path.join(js_root, 'node')) - run(['npm', 'install', '--package-lock-only', '--ignore-scripts'], cwd=os.path.join(js_root, 'node')) + run(["npm", "version", version], cwd=os.path.join(js_root, "node")) + run(["npm", "install", "--package-lock-only", "--ignore-scripts"], cwd=os.path.join(js_root, "node")) # upgrade version for onnxruntime-web - run(['npm', 'version', version], cwd=os.path.join(js_root, 'web')) - run(['npm', 'install', '--package-lock-only', '--ignore-scripts'], cwd=os.path.join(js_root, 'web')) + run(["npm", "version", version], cwd=os.path.join(js_root, "web")) + run(["npm", "install", "--package-lock-only", "--ignore-scripts"], cwd=os.path.join(js_root, "web")) # upgrade version for onnxruntime-react-native - run(['npm', 'version', version], cwd=os.path.join(js_root, 'react_native')) - run(['yarn', 'upgrade', 'onnxruntime-common'], cwd=os.path.join(js_root, 'react_native')) + run(["npm", "version", version], cwd=os.path.join(js_root, "react_native")) + run(["yarn", "upgrade", "onnxruntime-common"], cwd=os.path.join(js_root, "react_native")) if __name__ == "__main__": diff --git a/tools/python/util/__init__.py b/tools/python/util/__init__.py index 5dcc1bdd9fac4..92209c863f147 100644 --- a/tools/python/util/__init__.py +++ b/tools/python/util/__init__.py @@ -3,17 +3,19 @@ from .get_azcopy import get_azcopy from .logger import get_logger -from .platform_helpers import (is_windows, is_macOS, is_linux) +from .platform_helpers import is_linux, is_macOS, is_windows from .run import run try: import flatbuffers # noqa + from .reduced_build_config_parser import parse_config except ImportError: - get_logger('tools_python_utils').info('flatbuffers module is not installed. parse_config will not be available') + get_logger("tools_python_utils").info("flatbuffers module is not installed. parse_config will not be available") # see if we can make the pytorch helpers available. import importlib.util # noqa + have_torch = importlib.util.find_spec("torch") if have_torch: from .pytorch_export_helpers import infer_input_info diff --git a/tools/python/util/__init__append.py b/tools/python/util/__init__append.py index f1d318335ac5d..266fde0a75ae9 100644 --- a/tools/python/util/__init__append.py +++ b/tools/python/util/__init__append.py @@ -1,5 +1,6 @@ # appended to the __init__.py in the onnxruntime module's 'tools' folder from /tools/python/util/__init__append.py import importlib.util + have_torch = importlib.util.find_spec("torch") if have_torch: from .pytorch_export_helpers import infer_input_info # noqa diff --git a/tools/python/util/android/__init__.py b/tools/python/util/android/__init__.py index c4174952fbbcc..915cda58f5321 100644 --- a/tools/python/util/android/__init__.py +++ b/tools/python/util/android/__init__.py @@ -1,7 +1,4 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -from .android import ( - SdkToolPaths, get_sdk_tool_paths, - create_virtual_device, - start_emulator, stop_emulator) +from .android import SdkToolPaths, create_virtual_device, get_sdk_tool_paths, start_emulator, stop_emulator diff --git a/tools/python/util/android/android.py b/tools/python/util/android/android.py index ec7f8d99f4125..75d0f9cc5e2e2 100644 --- a/tools/python/util/android/android.py +++ b/tools/python/util/android/android.py @@ -11,15 +11,13 @@ import time import typing -from ..run import run from ..platform_helpers import is_windows - +from ..run import run _log = logging.getLogger("util.android") -SdkToolPaths = collections.namedtuple( - "SdkToolPaths", ["emulator", "adb", "sdkmanager", "avdmanager"]) +SdkToolPaths = collections.namedtuple("SdkToolPaths", ["emulator", "adb", "sdkmanager", "avdmanager"]) def get_sdk_tool_paths(sdk_root: str): @@ -41,34 +39,33 @@ def resolve_path(dirnames, basename): return None return SdkToolPaths( - emulator=resolve_path( - [os.path.join(sdk_root, "emulator")], - filename("emulator", "exe")), - adb=resolve_path( - [os.path.join(sdk_root, "platform-tools")], - filename("adb", "exe")), + emulator=resolve_path([os.path.join(sdk_root, "emulator")], filename("emulator", "exe")), + adb=resolve_path([os.path.join(sdk_root, "platform-tools")], filename("adb", "exe")), sdkmanager=resolve_path( - [os.path.join(sdk_root, "tools", "bin"), - os.path.join(sdk_root, "cmdline-tools", "tools", "bin")], - filename("sdkmanager", "bat")), + [os.path.join(sdk_root, "tools", "bin"), os.path.join(sdk_root, "cmdline-tools", "tools", "bin")], + filename("sdkmanager", "bat"), + ), avdmanager=resolve_path( - [os.path.join(sdk_root, "tools", "bin"), - os.path.join(sdk_root, "cmdline-tools", "tools", "bin")], - filename("avdmanager", "bat"))) - - -def create_virtual_device( - sdk_tool_paths: SdkToolPaths, - system_image_package_name: str, - avd_name: str): - run(sdk_tool_paths.sdkmanager, "--install", system_image_package_name, - input=b"y") - - run(sdk_tool_paths.avdmanager, "create", "avd", - "--name", avd_name, - "--package", system_image_package_name, + [os.path.join(sdk_root, "tools", "bin"), os.path.join(sdk_root, "cmdline-tools", "tools", "bin")], + filename("avdmanager", "bat"), + ), + ) + + +def create_virtual_device(sdk_tool_paths: SdkToolPaths, system_image_package_name: str, avd_name: str): + run(sdk_tool_paths.sdkmanager, "--install", system_image_package_name, input=b"y") + + run( + sdk_tool_paths.avdmanager, + "create", + "avd", + "--name", + avd_name, + "--package", + system_image_package_name, "--force", - input=b"no") + input=b"no", + ) _process_creationflags = subprocess.CREATE_NEW_PROCESS_GROUP if is_windows() else 0 @@ -100,30 +97,36 @@ def _stop_process_with_pid(pid: int): def start_emulator( - sdk_tool_paths: SdkToolPaths, - avd_name: str, - extra_args: typing.Optional[typing.Sequence[str]] = None) -> subprocess.Popen: - with contextlib.ExitStack() as emulator_stack, \ - contextlib.ExitStack() as waiter_stack: + sdk_tool_paths: SdkToolPaths, avd_name: str, extra_args: typing.Optional[typing.Sequence[str]] = None +) -> subprocess.Popen: + with contextlib.ExitStack() as emulator_stack, contextlib.ExitStack() as waiter_stack: emulator_args = [ - sdk_tool_paths.emulator, "-avd", avd_name, - "-memory", "4096", - "-timezone", "America/Los_Angeles", + sdk_tool_paths.emulator, + "-avd", + avd_name, + "-memory", + "4096", + "-timezone", + "America/Los_Angeles", "-no-snapshot", "-no-audio", "-no-boot-anim", - "-no-window"] + "-no-window", + ] if extra_args is not None: emulator_args += extra_args - emulator_process = emulator_stack.enter_context( - _start_process(*emulator_args)) + emulator_process = emulator_stack.enter_context(_start_process(*emulator_args)) emulator_stack.callback(_stop_process, emulator_process) waiter_process = waiter_stack.enter_context( _start_process( - sdk_tool_paths.adb, "wait-for-device", "shell", - "while [[ -z $(getprop sys.boot_completed) ]]; do sleep 1; done; input keyevent 82")) + sdk_tool_paths.adb, + "wait-for-device", + "shell", + "while [[ -z $(getprop sys.boot_completed) ]]; do sleep 1; done; input keyevent 82", + ) + ) waiter_stack.callback(_stop_process, waiter_process) # poll subprocesses diff --git a/tools/python/util/check_onnx_model_mobile_usability.py b/tools/python/util/check_onnx_model_mobile_usability.py index 7042d4cd2d018..f8095dd4d1bd4 100644 --- a/tools/python/util/check_onnx_model_mobile_usability.py +++ b/tools/python/util/check_onnx_model_mobile_usability.py @@ -6,35 +6,38 @@ import pathlib # need this before the mobile helper imports for some reason -logging.basicConfig(format='%(levelname)s: %(message)s') +logging.basicConfig(format="%(levelname)s: %(message)s") from .mobile_helpers import check_model_can_use_ort_mobile_pkg, usability_checker # noqa def check_usability(): parser = argparse.ArgumentParser( - description='''Analyze an ONNX model to determine how well it will work in mobile scenarios, and whether - it is likely to be able to use the pre-built ONNX Runtime Mobile Android or iOS package.''', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) - - parser.add_argument('--config_path', - help='Path to required operators and types configuration used to build ' - 'the pre-built ORT mobile package.', - required=False, - type=pathlib.Path, - default=check_model_can_use_ort_mobile_pkg.get_default_config_path()) - parser.add_argument('--log_level', choices=['debug', 'info', 'warning', 'error'], - default='info', help='Logging level') - parser.add_argument('model_path', help='Path to ONNX model to check', type=pathlib.Path) + description="""Analyze an ONNX model to determine how well it will work in mobile scenarios, and whether + it is likely to be able to use the pre-built ONNX Runtime Mobile Android or iOS package.""", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + + parser.add_argument( + "--config_path", + help="Path to required operators and types configuration used to build " "the pre-built ORT mobile package.", + required=False, + type=pathlib.Path, + default=check_model_can_use_ort_mobile_pkg.get_default_config_path(), + ) + parser.add_argument( + "--log_level", choices=["debug", "info", "warning", "error"], default="info", help="Logging level" + ) + parser.add_argument("model_path", help="Path to ONNX model to check", type=pathlib.Path) args = parser.parse_args() - logger = logging.getLogger('check_usability') + logger = logging.getLogger("check_usability") - if args.log_level == 'debug': + if args.log_level == "debug": logger.setLevel(logging.DEBUG) - elif args.log_level == 'info': + elif args.log_level == "info": logger.setLevel(logging.INFO) - elif args.log_level == 'warning': + elif args.log_level == "warning": logger.setLevel(logging.WARNING) else: logger.setLevel(logging.ERROR) @@ -42,19 +45,23 @@ def check_usability(): try_eps = usability_checker.analyze_model(args.model_path, skip_optimize=False, logger=logger) check_model_can_use_ort_mobile_pkg.run_check(args.model_path, args.config_path, logger) - logger.info("Run `python -m onnxruntime.tools.convert_onnx_models_to_ort ...` to convert the ONNX model to ORT " - "format. " - "By default, the conversion tool will create an ORT format model with saved optimizations which can " - "potentially be applied at runtime (with a .with_runtime_opt.ort file extension) for use with NNAPI " - "or CoreML, and a fully optimized ORT format model (with a .ort file extension) for use with the CPU " - "EP.") + logger.info( + "Run `python -m onnxruntime.tools.convert_onnx_models_to_ort ...` to convert the ONNX model to ORT " + "format. " + "By default, the conversion tool will create an ORT format model with saved optimizations which can " + "potentially be applied at runtime (with a .with_runtime_opt.ort file extension) for use with NNAPI " + "or CoreML, and a fully optimized ORT format model (with a .ort file extension) for use with the CPU " + "EP." + ) if try_eps: - logger.info("As NNAPI or CoreML may provide benefits with this model it is recommended to compare the " - "performance of the .with_runtime_opt.ort model using the NNAPI EP on Android, and the " - "CoreML EP on iOS, against the performance of the .ort model using the CPU EP.") + logger.info( + "As NNAPI or CoreML may provide benefits with this model it is recommended to compare the " + "performance of the .with_runtime_opt.ort model using the NNAPI EP on Android, and the " + "CoreML EP on iOS, against the performance of the .ort model using the CPU EP." + ) else: logger.info("For optimal performance the .ort model should be used with the CPU EP. ") -if __name__ == '__main__': +if __name__ == "__main__": check_usability() diff --git a/tools/python/util/convert_onnx_models_to_ort.py b/tools/python/util/convert_onnx_models_to_ort.py index 8e1776ebcadaa..073d3d02ae425 100644 --- a/tools/python/util/convert_onnx_models_to_ort.py +++ b/tools/python/util/convert_onnx_models_to_ort.py @@ -11,6 +11,7 @@ import typing import onnxruntime as ort + from .file_utils import files_from_file_or_dir, path_match_suffix_ignore_case from .onnx_model_utils import get_optimization_level from .ort_format_model import create_config_from_models @@ -22,26 +23,34 @@ class OptimizationStyle(enum.Enum): def _optimization_suffix(optimization_level_str: str, optimization_style: OptimizationStyle, suffix: str): - return "{}{}{}".format(f".{optimization_level_str}" if optimization_level_str != "all" else "", - ".with_runtime_opt" if optimization_style == OptimizationStyle.Runtime else "", - suffix) + return "{}{}{}".format( + f".{optimization_level_str}" if optimization_level_str != "all" else "", + ".with_runtime_opt" if optimization_style == OptimizationStyle.Runtime else "", + suffix, + ) -def _create_config_file_path(model_path_or_dir: pathlib.Path, - optimization_level_str: str, - optimization_style: OptimizationStyle, - enable_type_reduction: bool): - config_name = "{}{}".format('required_operators_and_types' if enable_type_reduction else 'required_operators', - _optimization_suffix(optimization_level_str, optimization_style, ".config")) +def _create_config_file_path( + model_path_or_dir: pathlib.Path, + optimization_level_str: str, + optimization_style: OptimizationStyle, + enable_type_reduction: bool, +): + config_name = "{}{}".format( + "required_operators_and_types" if enable_type_reduction else "required_operators", + _optimization_suffix(optimization_level_str, optimization_style, ".config"), + ) if model_path_or_dir.is_dir(): return model_path_or_dir / config_name return model_path_or_dir.with_suffix(f".{config_name}") -def _create_session_options(optimization_level: ort.GraphOptimizationLevel, - output_model_path: pathlib.Path, - custom_op_library: pathlib.Path, - session_options_config_entries: typing.Dict[str, str]): +def _create_session_options( + optimization_level: ort.GraphOptimizationLevel, + output_model_path: pathlib.Path, + custom_op_library: pathlib.Path, + session_options_config_entries: typing.Dict[str, str], +): so = ort.SessionOptions() so.optimized_model_filepath = str(output_model_path) so.graph_optimization_level = optimization_level @@ -55,11 +64,17 @@ def _create_session_options(optimization_level: ort.GraphOptimizationLevel, return so -def _convert(model_path_or_dir: pathlib.Path, output_dir: typing.Optional[pathlib.Path], - optimization_level_str: str, optimization_style: OptimizationStyle, - custom_op_library: pathlib.Path, create_optimized_onnx_model: bool, allow_conversion_failures: bool, - target_platform: str, session_options_config_entries: typing.Dict[str, str]) \ - -> typing.List[pathlib.Path]: +def _convert( + model_path_or_dir: pathlib.Path, + output_dir: typing.Optional[pathlib.Path], + optimization_level_str: str, + optimization_style: OptimizationStyle, + custom_op_library: pathlib.Path, + create_optimized_onnx_model: bool, + allow_conversion_failures: bool, + target_platform: str, + session_options_config_entries: typing.Dict[str, str], +) -> typing.List[pathlib.Path]: model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent output_dir = output_dir or model_dir @@ -81,15 +96,15 @@ def is_model_file_to_convert(file_path: pathlib.Path): if len(models) == 0: raise ValueError("No model files were found in '{}'".format(model_path_or_dir)) - providers = ['CPUExecutionProvider'] + providers = ["CPUExecutionProvider"] # if the optimization level is 'all' we manually exclude the NCHWc transformer. It's not applicable to ARM # devices, and creates a device specific model which won't run on all hardware. # If someone really really really wants to run it they could manually create an optimized onnx model first, # or they could comment out this code. optimizer_filter = None - if optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_ALL and target_platform != 'amd64': - optimizer_filter = ['NchwcTransformer'] + if optimization_level == ort.GraphOptimizationLevel.ORT_ENABLE_ALL and target_platform != "amd64": + optimizer_filter = ["NchwcTransformer"] converted_models = [] @@ -101,7 +116,8 @@ def is_model_file_to_convert(file_path: pathlib.Path): (output_dir / relative_model_path).parent.mkdir(parents=True, exist_ok=True) ort_target_path = (output_dir / relative_model_path).with_suffix( - _optimization_suffix(optimization_level_str, optimization_style, ".ort")) + _optimization_suffix(optimization_level_str, optimization_style, ".ort") + ) if create_optimized_onnx_model: # Create an ONNX file with the same optimization level that will be used for the ORT format file. @@ -109,27 +125,32 @@ def is_model_file_to_convert(file_path: pathlib.Path): # If runtime optimizations are saved in the ORT format model, there may be some difference in the # graphs at runtime between the ORT format model and this saved ONNX model. optimized_target_path = (output_dir / relative_model_path).with_suffix( - _optimization_suffix(optimization_level_str, optimization_style, ".optimized.onnx")) - so = _create_session_options(optimization_level, optimized_target_path, custom_op_library, - session_options_config_entries) + _optimization_suffix(optimization_level_str, optimization_style, ".optimized.onnx") + ) + so = _create_session_options( + optimization_level, optimized_target_path, custom_op_library, session_options_config_entries + ) if optimization_style == OptimizationStyle.Runtime: # Limit the optimizations to those that can run in a model with runtime optimizations. - so.add_session_config_entry('optimization.minimal_build_optimizations', 'apply') + so.add_session_config_entry("optimization.minimal_build_optimizations", "apply") print("Saving optimized ONNX model {} to {}".format(model, optimized_target_path)) - _ = ort.InferenceSession(str(model), sess_options=so, providers=providers, - disabled_optimizers=optimizer_filter) + _ = ort.InferenceSession( + str(model), sess_options=so, providers=providers, disabled_optimizers=optimizer_filter + ) # Load ONNX model, optimize, and save to ORT format - so = _create_session_options(optimization_level, ort_target_path, custom_op_library, - session_options_config_entries) - so.add_session_config_entry('session.save_model_format', 'ORT') + so = _create_session_options( + optimization_level, ort_target_path, custom_op_library, session_options_config_entries + ) + so.add_session_config_entry("session.save_model_format", "ORT") if optimization_style == OptimizationStyle.Runtime: - so.add_session_config_entry('optimization.minimal_build_optimizations', 'save') + so.add_session_config_entry("optimization.minimal_build_optimizations", "save") print("Converting optimized ONNX model {} to ORT format model {}".format(model, ort_target_path)) - _ = ort.InferenceSession(str(model), sess_options=so, providers=providers, - disabled_optimizers=optimizer_filter) + _ = ort.InferenceSession( + str(model), sess_options=so, providers=providers, disabled_optimizers=optimizer_filter + ) converted_models.append(ort_target_path) @@ -150,55 +171,78 @@ def is_model_file_to_convert(file_path: pathlib.Path): def parse_args(): parser = argparse.ArgumentParser( os.path.basename(__file__), - description='''Convert the ONNX format model/s in the provided directory to ORT format models. + description="""Convert the ONNX format model/s in the provided directory to ORT format models. All files with a `.onnx` extension will be processed. For each one, an ORT format model will be created in the same directory. A configuration file will also be created containing the list of required operators for all converted models. This configuration file should be used as input to the minimal build via the `--include_ops_by_config` parameter. - ''' + """, + ) + + parser.add_argument( + "--optimization_style", + nargs="+", + default=[OptimizationStyle.Fixed.name, OptimizationStyle.Runtime.name], + choices=[e.name for e in OptimizationStyle], + help="Style of optimization to perform on the ORT format model. " + "Multiple values may be provided. The conversion will run once for each value. " + "The general guidance is to use models optimized with " + f"'{OptimizationStyle.Runtime.name}' style when using NNAPI or CoreML and " + f"'{OptimizationStyle.Fixed.name}' style otherwise. " + f"'{OptimizationStyle.Fixed.name}': Run optimizations directly before saving the ORT " + "format model. This bakes in any platform-specific optimizations. " + f"'{OptimizationStyle.Runtime.name}': Run basic optimizations directly and save certain " + "other optimizations to be applied at runtime if possible. This is useful when using a " + "compiling EP like NNAPI or CoreML that may run an unknown (at model conversion time) " + "number of nodes. The saved optimizations can further optimize nodes not assigned to the " + "compiling EP at runtime.", + ) + + parser.add_argument( + "--enable_type_reduction", + action="store_true", + help="Add operator specific type information to the configuration file to potentially reduce " + "the types supported by individual operator implementations.", + ) + + parser.add_argument( + "--custom_op_library", + type=pathlib.Path, + default=None, + help="Provide path to shared library containing custom operator kernels to register.", + ) + + parser.add_argument( + "--save_optimized_onnx_model", + action="store_true", + help="Save the optimized version of each ONNX model. " + "This will have the same level of optimizations applied as the ORT format model.", + ) + + parser.add_argument( + "--allow_conversion_failures", + action="store_true", + help="Whether to proceed after encountering model conversion failures.", + ) + + parser.add_argument( + "--target_platform", + type=str, + default=None, + choices=["arm", "amd64"], + help="Specify the target platform where the exported model will be used. " + "This parameter can be used to choose between platform-specific options, " + "such as QDQIsInt8Allowed(arm), NCHWc (amd64) and NHWC (arm/amd64) format, different " + "optimizer level options, etc.", ) - parser.add_argument('--optimization_style', - nargs='+', - default=[OptimizationStyle.Fixed.name, OptimizationStyle.Runtime.name], - choices=[e.name for e in OptimizationStyle], - help="Style of optimization to perform on the ORT format model. " - "Multiple values may be provided. The conversion will run once for each value. " - "The general guidance is to use models optimized with " - f"'{OptimizationStyle.Runtime.name}' style when using NNAPI or CoreML and " - f"'{OptimizationStyle.Fixed.name}' style otherwise. " - f"'{OptimizationStyle.Fixed.name}': Run optimizations directly before saving the ORT " - "format model. This bakes in any platform-specific optimizations. " - f"'{OptimizationStyle.Runtime.name}': Run basic optimizations directly and save certain " - "other optimizations to be applied at runtime if possible. This is useful when using a " - "compiling EP like NNAPI or CoreML that may run an unknown (at model conversion time) " - "number of nodes. The saved optimizations can further optimize nodes not assigned to the " - "compiling EP at runtime.") - - parser.add_argument('--enable_type_reduction', action='store_true', - help='Add operator specific type information to the configuration file to potentially reduce ' - 'the types supported by individual operator implementations.') - - parser.add_argument('--custom_op_library', type=pathlib.Path, default=None, - help='Provide path to shared library containing custom operator kernels to register.') - - parser.add_argument('--save_optimized_onnx_model', action='store_true', - help='Save the optimized version of each ONNX model. ' - 'This will have the same level of optimizations applied as the ORT format model.') - - parser.add_argument('--allow_conversion_failures', action='store_true', - help='Whether to proceed after encountering model conversion failures.') - - parser.add_argument('--target_platform', type=str, default=None, choices=['arm', 'amd64'], - help='Specify the target platform where the exported model will be used. ' - 'This parameter can be used to choose between platform-specific options, ' - 'such as QDQIsInt8Allowed(arm), NCHWc (amd64) and NHWC (arm/amd64) format, different ' - 'optimizer level options, etc.') - - parser.add_argument('model_path_or_dir', type=pathlib.Path, - help='Provide path to ONNX model or directory containing ONNX model/s to convert. ' - 'All files with a .onnx extension, including those in subdirectories, will be ' - 'processed.') + parser.add_argument( + "model_path_or_dir", + type=pathlib.Path, + help="Provide path to ONNX model or directory containing ONNX model/s to convert. " + "All files with a .onnx extension, including those in subdirectories, will be " + "processed.", + ) return parser.parse_args() @@ -221,23 +265,29 @@ def convert_onnx_models_to_ort(): session_options_config_entries = {} - if args.target_platform == 'arm': + if args.target_platform == "arm": session_options_config_entries["session.qdqisint8allowed"] = "1" else: session_options_config_entries["session.qdqisint8allowed"] = "0" for optimization_style in optimization_styles: - print("Converting models with optimization style '{}' and level '{}'".format( - optimization_style.name, optimization_level_str)) + print( + "Converting models with optimization style '{}' and level '{}'".format( + optimization_style.name, optimization_level_str + ) + ) converted_models = _convert( - model_path_or_dir=model_path_or_dir, output_dir=None, - optimization_level_str=optimization_level_str, optimization_style=optimization_style, + model_path_or_dir=model_path_or_dir, + output_dir=None, + optimization_level_str=optimization_level_str, + optimization_style=optimization_style, custom_op_library=custom_op_library, create_optimized_onnx_model=args.save_optimized_onnx_model, allow_conversion_failures=args.allow_conversion_failures, target_platform=args.target_platform, - session_options_config_entries=session_options_config_entries) + session_options_config_entries=session_options_config_entries, + ) with contextlib.ExitStack() as context_stack: if optimization_style == OptimizationStyle.Runtime: @@ -246,31 +296,42 @@ def convert_onnx_models_to_ort(): # without runtime optimizations to get a complete set of ops that may be needed for the config file. model_dir = model_path_or_dir if model_path_or_dir.is_dir() else model_path_or_dir.parent temp_output_dir = context_stack.enter_context( - tempfile.TemporaryDirectory(dir=model_dir, suffix=".without_runtime_opt")) + tempfile.TemporaryDirectory(dir=model_dir, suffix=".without_runtime_opt") + ) session_options_config_entries_for_second_conversion = session_options_config_entries.copy() # Limit the optimizations to those that can run in a model with runtime optimizations. session_options_config_entries_for_second_conversion[ - "optimization.minimal_build_optimizations"] = "apply" + "optimization.minimal_build_optimizations" + ] = "apply" - print("Converting models again without runtime optimizations to generate a complete config file. " - "These converted models are temporary and will be deleted.") + print( + "Converting models again without runtime optimizations to generate a complete config file. " + "These converted models are temporary and will be deleted." + ) converted_models += _convert( - model_path_or_dir=model_path_or_dir, output_dir=temp_output_dir, - optimization_level_str=optimization_level_str, optimization_style=OptimizationStyle.Fixed, + model_path_or_dir=model_path_or_dir, + output_dir=temp_output_dir, + optimization_level_str=optimization_level_str, + optimization_style=OptimizationStyle.Fixed, custom_op_library=custom_op_library, create_optimized_onnx_model=False, # not useful as they would be created in a temp directory allow_conversion_failures=args.allow_conversion_failures, target_platform=args.target_platform, - session_options_config_entries=session_options_config_entries_for_second_conversion) + session_options_config_entries=session_options_config_entries_for_second_conversion, + ) - print("Generating config file from ORT format models with optimization style '{}' and level '{}'".format( - optimization_style.name, optimization_level_str)) + print( + "Generating config file from ORT format models with optimization style '{}' and level '{}'".format( + optimization_style.name, optimization_level_str + ) + ) - config_file = _create_config_file_path(model_path_or_dir, optimization_level_str, optimization_style, - args.enable_type_reduction) + config_file = _create_config_file_path( + model_path_or_dir, optimization_level_str, optimization_style, args.enable_type_reduction + ) create_config_from_models(converted_models, config_file, args.enable_type_reduction) -if __name__ == '__main__': +if __name__ == "__main__": convert_onnx_models_to_ort() diff --git a/tools/python/util/file_utils.py b/tools/python/util/file_utils.py index 73505b73369bb..0373ac171144f 100644 --- a/tools/python/util/file_utils.py +++ b/tools/python/util/file_utils.py @@ -1,31 +1,31 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. +import os import pathlib import typing -import os def path_match_suffix_ignore_case(path: typing.Union[pathlib.Path, str], suffix: str) -> bool: - ''' + """ Returns whether `path` ends in `suffix`, ignoring case. - ''' + """ if not isinstance(path, str): path = str(path) return path.casefold().endswith(suffix.casefold()) -def files_from_file_or_dir(file_or_dir_path: typing.Union[pathlib.Path, str], - predicate: typing.Callable[[pathlib.Path], bool] = lambda _: True) \ - -> typing.List[pathlib.Path]: - ''' +def files_from_file_or_dir( + file_or_dir_path: typing.Union[pathlib.Path, str], predicate: typing.Callable[[pathlib.Path], bool] = lambda _: True +) -> typing.List[pathlib.Path]: + """ Gets the files in `file_or_dir_path` satisfying `predicate`. If `file_or_dir_path` is a file, the single file is considered. Otherwise, all files in the directory are considered. :param file_or_dir_path: Path to a file or directory. :param predicate: Predicate to determine if a file is included. :return: A list of files. - ''' + """ if not isinstance(file_or_dir_path, pathlib.Path): file_or_dir_path = pathlib.Path(file_or_dir_path) diff --git a/tools/python/util/get_azcopy.py b/tools/python/util/get_azcopy.py index d3cb71431e61c..15aa9ca67e2de 100644 --- a/tools/python/util/get_azcopy.py +++ b/tools/python/util/get_azcopy.py @@ -27,9 +27,7 @@ def _check_version(azcopy_path): - proc = subprocess.run( - [azcopy_path, "--version"], - stdout=subprocess.PIPE, universal_newlines=True) + proc = subprocess.run([azcopy_path, "--version"], stdout=subprocess.PIPE, universal_newlines=True) match = re.search(r"\d+(?:\.\d+)+", proc.stdout) if not match: @@ -62,12 +60,10 @@ def get_azcopy(local_azcopy_path="azcopy"): azcopy_path = shutil.which(local_azcopy_path) if azcopy_path is None or not _check_version(azcopy_path): - temp_dir = context_stack.enter_context( - tempfile.TemporaryDirectory()) + temp_dir = context_stack.enter_context(tempfile.TemporaryDirectory()) download_url = _AZCOPY_DOWNLOAD_URLS[platform.system()] - download_basename = urllib.parse.urlsplit( - download_url).path.rsplit("/", 1)[-1] + download_basename = urllib.parse.urlsplit(download_url).path.rsplit("/", 1)[-1] assert len(download_basename) > 0 downloaded_path = os.path.join(temp_dir, download_basename) diff --git a/tools/python/util/logger.py b/tools/python/util/logger.py index c15fad76e329e..9deb4475721ee 100644 --- a/tools/python/util/logger.py +++ b/tools/python/util/logger.py @@ -5,8 +5,6 @@ def get_logger(name): - logging.basicConfig( - format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", - level=logging.DEBUG) + logging.basicConfig(format="%(asctime)s %(name)s [%(levelname)s] - %(message)s", level=logging.DEBUG) return logging.getLogger(name) diff --git a/tools/python/util/make_dynamic_shape_fixed.py b/tools/python/util/make_dynamic_shape_fixed.py index bc304bee400da..f4e09a8cc04a3 100644 --- a/tools/python/util/make_dynamic_shape_fixed.py +++ b/tools/python/util/make_dynamic_shape_fixed.py @@ -3,40 +3,55 @@ # Licensed under the MIT License. import argparse -import onnx import os import pathlib import sys -from .onnx_model_utils import make_dim_param_fixed, make_input_shape_fixed, fix_output_shapes +import onnx + +from .onnx_model_utils import fix_output_shapes, make_dim_param_fixed, make_input_shape_fixed def make_dynamic_shape_fixed_helper(): - parser = argparse.ArgumentParser(f'{os.path.basename(__file__)}:{make_dynamic_shape_fixed_helper.__name__}', - description=''' + parser = argparse.ArgumentParser( + f"{os.path.basename(__file__)}:{make_dynamic_shape_fixed_helper.__name__}", + description=""" Assign a fixed value to a dim_param or input shape - Provide either dim_param and dim_value or input_name and input_shape.''') + Provide either dim_param and dim_value or input_name and input_shape.""", + ) - parser.add_argument('--dim_param', type=str, required=False, - help="Symbolic parameter name. Provide dim_value if specified.") - parser.add_argument('--dim_value', type=int, required=False, - help="Value to replace dim_param with in the model. Must be > 0.") - parser.add_argument('--input_name', type=str, required=False, - help="Model input name to replace shape of. Provide input_shape if specified.") - parser.add_argument('--input_shape', type=lambda x: [int(i) for i in x.split(',')], required=False, - help="Shape to use for input_shape. Provide comma separated list for the shape. " - "All values must be > 0. e.g. --input_shape 1,3,256,256") + parser.add_argument( + "--dim_param", type=str, required=False, help="Symbolic parameter name. Provide dim_value if specified." + ) + parser.add_argument( + "--dim_value", type=int, required=False, help="Value to replace dim_param with in the model. Must be > 0." + ) + parser.add_argument( + "--input_name", + type=str, + required=False, + help="Model input name to replace shape of. Provide input_shape if specified.", + ) + parser.add_argument( + "--input_shape", + type=lambda x: [int(i) for i in x.split(",")], + required=False, + help="Shape to use for input_shape. Provide comma separated list for the shape. " + "All values must be > 0. e.g. --input_shape 1,3,256,256", + ) - parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.') - parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write updated ONNX model to.') + parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.") + parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.") args = parser.parse_args() - if (args.dim_param and args.input_name) or \ - (not args.dim_param and not args.input_name) or \ - (args.dim_param and (not args.dim_value or args.dim_value < 1)) or \ - (args.input_name and (not args.input_shape or any([value < 1 for value in args.input_shape]))): - print('Invalid usage.') + if ( + (args.dim_param and args.input_name) + or (not args.dim_param and not args.input_name) + or (args.dim_param and (not args.dim_value or args.dim_value < 1)) + or (args.input_name and (not args.input_shape or any([value < 1 for value in args.input_shape]))) + ): + print("Invalid usage.") parser.print_help() sys.exit(-1) @@ -53,5 +68,5 @@ def make_dynamic_shape_fixed_helper(): onnx.save(model, str(args.output_model.resolve())) -if __name__ == '__main__': +if __name__ == "__main__": make_dynamic_shape_fixed_helper() diff --git a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py index c72653ebc30cf..f58a25b64a505 100644 --- a/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py +++ b/tools/python/util/mobile_helpers/check_model_can_use_ort_mobile_pkg.py @@ -6,38 +6,39 @@ import argparse import logging -import onnx import pathlib import sys + +import onnx from onnx import shape_inference + from ..onnx_model_utils import get_opsets_imported from ..reduced_build_config_parser import parse_config - cpp_to_tensorproto_type = { - 'float': 1, - 'uint8_t': 2, - 'int8_t': 3, - 'uint16_t': 4, - 'int16_t': 5, - 'int32_t': 6, - 'int64_t': 7, - 'std::string': 8, - 'bool': 9, - 'MLFloat16': 10, - 'double': 11, - 'uint32_t': 12, - 'uint64_t': 13, - 'Complex64': 14, # not supported by ORT - 'Complex128': 15, # not supported by ORT - 'BFloat16': 16 + "float": 1, + "uint8_t": 2, + "int8_t": 3, + "uint16_t": 4, + "int16_t": 5, + "int32_t": 6, + "int64_t": 7, + "std::string": 8, + "bool": 9, + "MLFloat16": 10, + "double": 11, + "uint32_t": 12, + "uint64_t": 13, + "Complex64": 14, # not supported by ORT + "Complex128": 15, # not supported by ORT + "BFloat16": 16, } tensorproto_type_to_cpp = {v: k for k, v in cpp_to_tensorproto_type.items()} def check_graph(graph, opsets, required_ops, global_types, special_types, unsupported_ops, logger): - ''' + """ Check the graph and any subgraphs for usage of types or operators which we know are not supported. :param graph: Graph to process. :param opsets: Map of domain to opset version that the model imports. @@ -49,28 +50,28 @@ def check_graph(graph, opsets, required_ops, global_types, special_types, unsupp :param unsupported_ops: Set of unsupported operators that were found. :param logger: Logger for diagnostic output. :return: Returns whether the graph uses unsupported operators or types. - ''' + """ has_unsupported_types = False value_info_map = {vi.name: vi for vi in graph.value_info} def _is_type_supported(value_info, description): is_supported = True - type_name = value_info.type.WhichOneof('value') - if type_name == 'tensor_type': + type_name = value_info.type.WhichOneof("value") + if type_name == "tensor_type": t = value_info.type.tensor_type.elem_type if t not in global_types and t not in special_types: cpp_type = tensorproto_type_to_cpp[t] - logger.debug(f'Element type {cpp_type} of {description} is not supported.') + logger.debug(f"Element type {cpp_type} of {description} is not supported.") is_supported = False else: # we don't support sequences, map, sparse tensors, or optional types in the pre-built package - logger.debug(f'Data type {type_name} of {description} is not supported.') + logger.debug(f"Data type {type_name} of {description} is not supported.") is_supported = False return is_supported def _input_output_is_supported(value_info, input_output): - return _is_type_supported(value_info, f'graph {input_output} {value_info.name}') + return _is_type_supported(value_info, f"graph {input_output} {value_info.name}") # node outputs are simpler to check. # node inputs have a much wider mix of types, some of which come from initializers and most likely are always @@ -80,7 +81,7 @@ def _node_output_is_supported(name): is_supported = True if name in value_info_map: vi = value_info_map[name] - is_supported = _is_type_supported(vi, f'node output {name}') + is_supported = _is_type_supported(vi, f"node output {name}") else: # we don't have type info so ignore pass @@ -88,28 +89,30 @@ def _node_output_is_supported(name): return is_supported for i in graph.input: - if not _input_output_is_supported(i, 'input'): + if not _input_output_is_supported(i, "input"): has_unsupported_types = True for o in graph.output: - if not _input_output_is_supported(o, 'output'): + if not _input_output_is_supported(o, "output"): has_unsupported_types = True for node in graph.node: # required_ops are map of [domain][opset] to set of op_type names. '' == ai.onnx - domain = node.domain or 'ai.onnx' + domain = node.domain or "ai.onnx" # special case Constant as we will convert to an initializer during model load - if domain == 'ai.onnx' and node.op_type == 'Constant': + if domain == "ai.onnx" and node.op_type == "Constant": continue # some models don't have complete imports. use 1 as a default as that's valid for custom domains and should # result in an error for any others. not sure why ONNX or ORT validation allows this though. opset = opsets[domain] if domain in opsets else 1 - if domain not in required_ops or \ - opset not in required_ops[domain] or \ - node.op_type not in required_ops[domain][opset]: - unsupported_ops.add(f'{domain}:{opset}:{node.op_type}') + if ( + domain not in required_ops + or opset not in required_ops[domain] + or node.op_type not in required_ops[domain][opset] + ): + unsupported_ops.add(f"{domain}:{opset}:{node.op_type}") for output_name in node.output: if not _node_output_is_supported(output_name): @@ -117,14 +120,14 @@ def _node_output_is_supported(name): # recurse into subgraph for control flow nodes (Scan/Loop/If) for attr in node.attribute: - if attr.HasField('g'): + if attr.HasField("g"): check_graph(attr.g, opsets, required_ops, global_types, special_types, unsupported_ops, logger) return has_unsupported_types or unsupported_ops def _get_global_tensorproto_types(op_type_impl_filter, logger: logging.Logger): - ''' + """ Map the globally supported types (C++) to onnx.TensorProto.DataType values used in the model See https://github.com/onnx/onnx/blob/1faae95520649c93ae8d0b403816938a190f4fa7/onnx/onnx.proto#L485 @@ -132,7 +135,7 @@ def _get_global_tensorproto_types(op_type_impl_filter, logger: logging.Logger): :param op_type_impl_filter: type filter from reduced build configuration parser :param logger: Logger :return: tuple of globally enabled types and special cased types - ''' + """ global_cpp_types = op_type_impl_filter.global_type_list() global_onnx_tensorproto_types = set() @@ -148,9 +151,11 @@ def _get_global_tensorproto_types(op_type_impl_filter, logger: logging.Logger): # additionally we have a number of operators (e.g. Not, Where) that always require the use of bool. # this _may_ mean values involving these types can be processed, but without adding a lot more code we don't know # for sure. - special_types = [cpp_to_tensorproto_type['int32_t'], - cpp_to_tensorproto_type['int64_t'], - cpp_to_tensorproto_type['bool']] + special_types = [ + cpp_to_tensorproto_type["int32_t"], + cpp_to_tensorproto_type["int64_t"], + cpp_to_tensorproto_type["bool"], + ] return global_onnx_tensorproto_types, special_types @@ -158,7 +163,7 @@ def _get_global_tensorproto_types(op_type_impl_filter, logger: logging.Logger): def get_default_config_path(): # get default path to config that was used to create the pre-built package. script_dir = pathlib.Path(__file__).parent - local_config = script_dir / 'mobile_package.required_operators.config' + local_config = script_dir / "mobile_package.required_operators.config" # if we're running in the ORT python package the file should be local. otherwise assume we're running from the # ORT repo @@ -166,22 +171,23 @@ def get_default_config_path(): default_config_path = local_config else: ort_root = script_dir.parents[3] - default_config_path = \ - ort_root / 'tools' / 'ci_build' / 'github' / 'android' / 'mobile_package.required_operators.config' + default_config_path = ( + ort_root / "tools" / "ci_build" / "github" / "android" / "mobile_package.required_operators.config" + ) return default_config_path -def run_check_with_model(model_with_type_info: onnx.ModelProto, - mobile_pkg_build_config: pathlib.Path, - logger: logging.Logger): - ''' +def run_check_with_model( + model_with_type_info: onnx.ModelProto, mobile_pkg_build_config: pathlib.Path, logger: logging.Logger +): + """ Check if an ONNX model can be used with the ORT Mobile pre-built package. :param model_with_type_info: ONNX model that has had ONNX shape inferencing run on to add type/shape information. :param mobile_pkg_build_config: Configuration file used to build the ORT Mobile package. :param logger: Logger for output :return: True if supported - ''' + """ if not mobile_pkg_build_config: mobile_pkg_build_config = get_default_config_path() @@ -194,56 +200,69 @@ def run_check_with_model(model_with_type_info: onnx.ModelProto, opsets = get_opsets_imported(model_with_type_info) # If the ONNX opset of the model is not supported we can recommend using our tools to update that first. - supported_onnx_opsets = set(required_ops['ai.onnx'].keys()) + supported_onnx_opsets = set(required_ops["ai.onnx"].keys()) # we have a contrib op that is erroneously in the ai.onnx domain with opset 1. manually remove that incorrect value supported_onnx_opsets.remove(1) - onnx_opset_model_uses = opsets['ai.onnx'] + onnx_opset_model_uses = opsets["ai.onnx"] if onnx_opset_model_uses not in supported_onnx_opsets: - logger.info(f'Model uses ONNX opset {onnx_opset_model_uses}.') - logger.info(f'The pre-built package only supports ONNX opsets {sorted(supported_onnx_opsets)}.') - logger.info('Please try updating the ONNX model opset to a supported version using ' - 'python -m onnxruntime.tools.onnx_model_utils.update_onnx_opset ...') + logger.info(f"Model uses ONNX opset {onnx_opset_model_uses}.") + logger.info(f"The pre-built package only supports ONNX opsets {sorted(supported_onnx_opsets)}.") + logger.info( + "Please try updating the ONNX model opset to a supported version using " + "python -m onnxruntime.tools.onnx_model_utils.update_onnx_opset ..." + ) return False unsupported_ops = set() - logger.debug('Checking if the data types and operators used in the model are supported ' - 'in the pre-built ORT package...') - unsupported = check_graph(model_with_type_info.graph, opsets, required_ops, - global_onnx_tensorproto_types, special_types, - unsupported_ops, logger) + logger.debug( + "Checking if the data types and operators used in the model are supported " "in the pre-built ORT package..." + ) + unsupported = check_graph( + model_with_type_info.graph, + opsets, + required_ops, + global_onnx_tensorproto_types, + special_types, + unsupported_ops, + logger, + ) if unsupported_ops: - logger.info('Unsupported operators:') + logger.info("Unsupported operators:") for entry in sorted(unsupported_ops): - logger.info(' ' + entry) + logger.info(" " + entry) if unsupported: - logger.info('\nModel is not supported by the pre-built package due to unsupported types and/or operators.') - logger.info('Please see https://onnxruntime.ai/docs/reference/mobile/prebuilt-package/ for information ' - 'on what is supported in the pre-built package.') - logger.info('A custom build of ONNX Runtime will be required to run the model. Please see ' - 'https://onnxruntime.ai/docs/build/custom.html for details on performing that.') + logger.info("\nModel is not supported by the pre-built package due to unsupported types and/or operators.") + logger.info( + "Please see https://onnxruntime.ai/docs/reference/mobile/prebuilt-package/ for information " + "on what is supported in the pre-built package." + ) + logger.info( + "A custom build of ONNX Runtime will be required to run the model. Please see " + "https://onnxruntime.ai/docs/build/custom.html for details on performing that." + ) else: - logger.info('Model should work with the pre-built package.') + logger.info("Model should work with the pre-built package.") - logger.info('---------------\n') + logger.info("---------------\n") return not unsupported -def run_check(model_path: pathlib.Path, - mobile_pkg_build_config: pathlib.Path, - logger: logging.Logger): - ''' +def run_check(model_path: pathlib.Path, mobile_pkg_build_config: pathlib.Path, logger: logging.Logger): + """ Check if an ONNX model will be able to be used with the ORT Mobile pre-built package. :param model_path: Path to ONNX model. :param mobile_pkg_build_config: Configuration file used to build the ORT Mobile package. :param logger: Logger for output :return: True if supported - ''' - logger.info(f'Checking if pre-built ORT Mobile package can be used with {model_path} once model is ' - 'converted from ONNX to ORT format using onnxruntime.tools.convert_onnx_models_to_ort...') + """ + logger.info( + f"Checking if pre-built ORT Mobile package can be used with {model_path} once model is " + "converted from ONNX to ORT format using onnxruntime.tools.convert_onnx_models_to_ort..." + ) model_file = model_path.resolve(strict=True) model = onnx.load(str(model_file)) @@ -259,23 +278,26 @@ def run_check(model_path: pathlib.Path, def main(): parser = argparse.ArgumentParser( - description='Check if model can be run using the ONNX Runtime Mobile Pre-Built Package', - formatter_class=argparse.ArgumentDefaultsHelpFormatter) + description="Check if model can be run using the ONNX Runtime Mobile Pre-Built Package", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) - parser.add_argument('--config_path', - help='Path to required operators and types configuration used to build ' - 'the pre-built ORT mobile package.', - required=False, - type=pathlib.Path, default=get_default_config_path()) + parser.add_argument( + "--config_path", + help="Path to required operators and types configuration used to build " "the pre-built ORT mobile package.", + required=False, + type=pathlib.Path, + default=get_default_config_path(), + ) - parser.add_argument('model_path', help='Path to ONNX model to check', type=pathlib.Path) + parser.add_argument("model_path", help="Path to ONNX model to check", type=pathlib.Path) args = parser.parse_args() - logger = logging.getLogger('default') + logger = logging.getLogger("default") logger.setLevel(logging.INFO) run_check(args.model_path, args.config_path, logger) -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/python/util/mobile_helpers/test/test_check_model_can_use_ort_mobile_pkg.py b/tools/python/util/mobile_helpers/test/test_check_model_can_use_ort_mobile_pkg.py index 6c6d96d3022d2..98a8cd5b9d6c9 100644 --- a/tools/python/util/mobile_helpers/test/test_check_model_can_use_ort_mobile_pkg.py +++ b/tools/python/util/mobile_helpers/test/test_check_model_can_use_ort_mobile_pkg.py @@ -2,11 +2,12 @@ # Licensed under the MIT License. import logging -import onnx import pathlib import unittest +import onnx from testfixtures import LogCapture + from ..check_model_can_use_ort_mobile_pkg import run_check, run_check_with_model # example usage from /tools/python @@ -16,12 +17,13 @@ script_dir = pathlib.Path(__file__).parent ort_root = script_dir.parents[4] -ort_package_build_config_filename = \ - ort_root / 'tools' / 'ci_build' / 'github' / 'android' / 'mobile_package.required_operators.config' +ort_package_build_config_filename = ( + ort_root / "tools" / "ci_build" / "github" / "android" / "mobile_package.required_operators.config" +) def _create_logger(): - logger = logging.getLogger('default') + logger = logging.getLogger("default") logger.setLevel(logging.DEBUG) return logger @@ -30,32 +32,32 @@ class TestMobilePackageModelChecker(unittest.TestCase): def test_supported_model(self): with LogCapture() as log_capture: logger = _create_logger() - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'ort_github_issue_4031.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "ort_github_issue_4031.onnx" supported = run_check(model_path, ort_package_build_config_filename, logger) self.assertTrue(supported) # print(log_capture) log_capture.check_present( - ('default', 'INFO', 'Model should work with the pre-built package.'), + ("default", "INFO", "Model should work with the pre-built package."), ) def test_model_invalid_opset(self): with LogCapture() as log_capture: logger = _create_logger() - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'mnist.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "mnist.onnx" supported = run_check(model_path, ort_package_build_config_filename, logger) self.assertFalse(supported) # print(log_capture) log_capture.check_present( - ('default', 'INFO', 'Model uses ONNX opset 8.'), - ('default', 'INFO', 'The pre-built package only supports ONNX opsets [12, 13, 14, 15].') + ("default", "INFO", "Model uses ONNX opset 8."), + ("default", "INFO", "The pre-built package only supports ONNX opsets [12, 13, 14, 15]."), ) def test_model_unsupported_op_and_types(self): with LogCapture() as log_capture: logger = _create_logger() - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'sequence_insert.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "sequence_insert.onnx" # Model uses opset 11 which is not supported in the mobile package. Update to supported opset first # Note: Ideally this would use update_onnx_opset however the ONNX opset update tools isn't working with @@ -73,7 +75,7 @@ def test_model_unsupported_op_and_types(self): # print(log_capture) log_capture.check_present( - ('default', 'DEBUG', 'Data type sequence_type of graph input input_seq is not supported.'), - ('default', 'INFO', 'Unsupported operators:'), - ('default', 'INFO', ' ai.onnx:13:SequenceInsert'), + ("default", "DEBUG", "Data type sequence_type of graph input input_seq is not supported."), + ("default", "INFO", "Unsupported operators:"), + ("default", "INFO", " ai.onnx:13:SequenceInsert"), ) diff --git a/tools/python/util/mobile_helpers/test/test_usability_checker.py b/tools/python/util/mobile_helpers/test/test_usability_checker.py index dafeab389fa94..230f326207c2f 100644 --- a/tools/python/util/mobile_helpers/test/test_usability_checker.py +++ b/tools/python/util/mobile_helpers/test/test_usability_checker.py @@ -3,8 +3,10 @@ import logging import pathlib -from testfixtures import LogCapture import unittest + +from testfixtures import LogCapture + from ..usability_checker import analyze_model # example usage from /tools/python @@ -18,95 +20,103 @@ def _create_logger(): - logger = logging.getLogger('default') + logger = logging.getLogger("default") logger.setLevel(logging.DEBUG) return logger class TestAnalyzer(unittest.TestCase): def test_mnist(self): - ''' + """ Test MNIST which should be fully covered by both NNAPI and CoreML as is. :return: - ''' + """ with LogCapture() as log_capture: logger = _create_logger() - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'mnist.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "mnist.onnx" analyze_model(model_path, skip_optimize, logger) # print(log_capture) log_capture.check_present( - ('default', 'INFO', '1 partitions with a total of 8/8 nodes can be handled by the NNAPI EP.'), - ('default', 'INFO', 'Model should perform well with NNAPI as is: YES'), - ('default', 'INFO', '1 partitions with a total of 8/8 nodes can be handled by the CoreML EP.'), - ('default', 'INFO', 'Model should perform well with CoreML as is: YES'), + ("default", "INFO", "1 partitions with a total of 8/8 nodes can be handled by the NNAPI EP."), + ("default", "INFO", "Model should perform well with NNAPI as is: YES"), + ("default", "INFO", "1 partitions with a total of 8/8 nodes can be handled by the CoreML EP."), + ("default", "INFO", "Model should perform well with CoreML as is: YES"), ) def test_scan_model(self): - ''' + """ Test a Speech model where all the top level nodes are Scan. All the real operators are in subgraphs, so we don't use NNAPI/CoreML currently. We want to make sure nodes in subgraphs are counted. - ''' + """ with LogCapture() as log_capture: logger = _create_logger() # mnist - should have perfect coverage - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'scan_1.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "scan_1.onnx" analyze_model(model_path, skip_optimize, logger) # print(log_capture) log_capture.check_present( - ('default', 'INFO', '0 partitions with a total of 0/76 nodes can be handled by the NNAPI EP.'), - ('default', 'INFO', '72 nodes are in subgraphs, which are currently not handled.'), - ('default', 'INFO', 'Unsupported ops: ai.onnx:Scan'), - ('default', 'INFO', 'Model should perform well with NNAPI as is: NO'), - ('default', 'INFO', '0 partitions with a total of 0/76 nodes can be handled by the CoreML EP.'), - ('default', 'INFO', 'Model should perform well with CoreML as is: NO') + ("default", "INFO", "0 partitions with a total of 0/76 nodes can be handled by the NNAPI EP."), + ("default", "INFO", "72 nodes are in subgraphs, which are currently not handled."), + ("default", "INFO", "Unsupported ops: ai.onnx:Scan"), + ("default", "INFO", "Model should perform well with NNAPI as is: NO"), + ("default", "INFO", "0 partitions with a total of 0/76 nodes can be handled by the CoreML EP."), + ("default", "INFO", "Model should perform well with CoreML as is: NO"), ) def test_dynamic_shape(self): - ''' + """ Test a model with dynamic input shape and supported op. If we make the shape fixed it should report it will run well with NNAPI/CoreML. - ''' + """ with LogCapture() as log_capture: logger = _create_logger() - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'abs_free_dimensions.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "abs_free_dimensions.onnx" analyze_model(model_path, skip_optimize, logger) # print(log_capture) log_capture.check_present( - ('default', 'INFO', '0 partitions with a total of 0/1 nodes can be handled by the NNAPI EP.'), - ('default', 'INFO', 'Model should perform well with NNAPI as is: NO'), - ('default', 'INFO', 'Model should perform well with NNAPI if modified to have fixed input shapes: YES'), - ('default', 'INFO', '0 partitions with a total of 0/1 nodes can be handled by the CoreML EP.'), - ('default', 'INFO', 'CoreML cannot run any nodes in this model.'), - ('default', 'INFO', 'Model should perform well with CoreML as is: NO'), - ('default', 'INFO', 'Model should perform well with CoreML if modified to have fixed input shapes: NO') + ("default", "INFO", "0 partitions with a total of 0/1 nodes can be handled by the NNAPI EP."), + ("default", "INFO", "Model should perform well with NNAPI as is: NO"), + ("default", "INFO", "Model should perform well with NNAPI if modified to have fixed input shapes: YES"), + ("default", "INFO", "0 partitions with a total of 0/1 nodes can be handled by the CoreML EP."), + ("default", "INFO", "CoreML cannot run any nodes in this model."), + ("default", "INFO", "Model should perform well with CoreML as is: NO"), + ("default", "INFO", "Model should perform well with CoreML if modified to have fixed input shapes: NO"), ) def test_multi_partitions(self): - ''' + """ Test a model that breaks into too many partitions to be recommended for use with NNAPI/CoreML - ''' + """ with LogCapture() as log_capture: logger = _create_logger() - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'gh_issue_9671.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "gh_issue_9671.onnx" analyze_model(model_path, skip_optimize, logger) # print(log_capture) log_capture.check_present( - ('default', 'INFO', '3 partitions with a total of 17/46 nodes can be handled by the NNAPI EP.'), - ('default', 'INFO', 'Partition sizes: [5, 4, 8]'), - ('default', 'INFO', 'Unsupported ops: ai.onnx:Gather,ai.onnx:ReduceProd,ai.onnx:ReduceSum,' - 'ai.onnx:Shape,ai.onnx:Unsqueeze'), - ('default', 'INFO', 'NNAPI is not recommended with this model as there are 3 partitions ' - 'covering 37.0% of the nodes in the model. ' - 'This will most likely result in worse performance than just using the CPU EP.'), - ('default', 'INFO', 'Model should perform well with NNAPI as is: NO'), - ('default', 'INFO', 'Partition information if the model was updated to make the shapes fixed:'), - ('default', 'INFO', '3 partitions with a total of 23/46 nodes can be handled by the NNAPI EP.'), - ('default', 'INFO', 'Partition sizes: [3, 12, 8]'), - ('default', 'INFO', '3 partitions with a total of 15/46 nodes can be handled by the CoreML EP.'), - ('default', 'INFO', 'Partition sizes: [4, 4, 7]'), - ('default', 'INFO', 'Model should perform well with CoreML as is: NO') + ("default", "INFO", "3 partitions with a total of 17/46 nodes can be handled by the NNAPI EP."), + ("default", "INFO", "Partition sizes: [5, 4, 8]"), + ( + "default", + "INFO", + "Unsupported ops: ai.onnx:Gather,ai.onnx:ReduceProd,ai.onnx:ReduceSum," + "ai.onnx:Shape,ai.onnx:Unsqueeze", + ), + ( + "default", + "INFO", + "NNAPI is not recommended with this model as there are 3 partitions " + "covering 37.0% of the nodes in the model. " + "This will most likely result in worse performance than just using the CPU EP.", + ), + ("default", "INFO", "Model should perform well with NNAPI as is: NO"), + ("default", "INFO", "Partition information if the model was updated to make the shapes fixed:"), + ("default", "INFO", "3 partitions with a total of 23/46 nodes can be handled by the NNAPI EP."), + ("default", "INFO", "Partition sizes: [3, 12, 8]"), + ("default", "INFO", "3 partitions with a total of 15/46 nodes can be handled by the CoreML EP."), + ("default", "INFO", "Partition sizes: [4, 4, 7]"), + ("default", "INFO", "Model should perform well with CoreML as is: NO"), ) diff --git a/tools/python/util/mobile_helpers/usability_checker.py b/tools/python/util/mobile_helpers/usability_checker.py index 56e02901d4b86..f286544fa510e 100644 --- a/tools/python/util/mobile_helpers/usability_checker.py +++ b/tools/python/util/mobile_helpers/usability_checker.py @@ -3,47 +3,53 @@ import argparse import logging -import onnx import os import pathlib import tempfile - from collections import deque from enum import IntEnum -from ..onnx_model_utils import get_producer_consumer_maps, optimize_model, \ - iterate_graph_per_graph_func, iterate_graph_per_node_func, is_fixed_size_tensor + +import onnx + +from ..onnx_model_utils import ( + get_producer_consumer_maps, + is_fixed_size_tensor, + iterate_graph_per_graph_func, + iterate_graph_per_node_func, + optimize_model, +) class _SupportedOpsChecker: - ''' + """ Class to process the md file with list of supported ops and caveats for an execution provider. e.g. /tools/ci_build/github/android/nnapi_supported_ops.md /tools/ci_build/github/apple/coreml_supported_ops.md - ''' + """ def __init__(self, filename): self._filename = filename self._ops = {} # op to caveats self._ops_seen = set() - with open(filename, 'r') as f: + with open(filename, "r") as f: for line in f.readlines(): # we're looking for a markdown table with 2 columns. first is op name. second is caveats # op name is domain:op - if line.startswith('|'): - pieces = line.strip().split('|') + if line.startswith("|"): + pieces = line.strip().split("|") if len(pieces) == 4: # pre-first '|'. op, caveat, post-last '|' domain_op = pieces[1] caveat = pieces[2] - caveat = caveat.replace('
', ' ') # remove some HTML tags + caveat = caveat.replace("
", " ") # remove some HTML tags # skip lines that don't have the ':' which separates the domain and op # e.g. the table header will fail this check - if ':' in domain_op: + if ":" in domain_op: self._ops[domain_op] = caveat def is_op_supported(self, node): - domain = node.domain if node.domain else 'ai.onnx' - domain_op = domain + ':' + node.op_type + domain = node.domain if node.domain else "ai.onnx" + domain_op = domain + ":" + node.op_type is_supported = domain_op in self._ops if is_supported: @@ -56,15 +62,15 @@ def get_caveats(self): for op in sorted(self._ops_seen): caveat = self._ops[op] if caveat: - caveats.append(f'{op}:{caveat}') + caveats.append(f"{op}:{caveat}") return caveats class PartitioningInfo: class TryWithEP(IntEnum): - NO = 0, - MAYBE = 1, + NO = (0,) + MAYBE = (1,) YES = 2 def __init__(self): @@ -107,24 +113,28 @@ def suitability(self): return PartitioningInfo.TryWithEP.NO def dump_analysis(self, logger: logging.Logger, ep_name: str): - ''' + """ Analyze the partitioning information and log the analysis :param logger: Logger to use :param ep_name: Execution provider name to use in the log messages - ''' + """ num_nodes = self.num_nodes + self.num_nodes_in_subgraphs - logger.info(f'{self.num_partitions} partitions with a total of {self.num_supported_nodes}/{num_nodes} ' - f'nodes can be handled by the {ep_name} EP.') + logger.info( + f"{self.num_partitions} partitions with a total of {self.num_supported_nodes}/{num_nodes} " + f"nodes can be handled by the {ep_name} EP." + ) if self.num_nodes_in_subgraphs: - logger.info(f'{self.num_nodes_in_subgraphs} nodes are in subgraphs, which are currently not handled.') + logger.info(f"{self.num_nodes_in_subgraphs} nodes are in subgraphs, which are currently not handled.") if self.supported_groups: logger.info(f'Partition sizes: [{", ".join([str(len(partition)) for partition in self.supported_groups])}]') - logger.info(f'Unsupported nodes due to operator={self.nodes_unsupported_due_to_op}') + logger.info(f"Unsupported nodes due to operator={self.nodes_unsupported_due_to_op}") if self.nodes_unsupported_due_to_dynamic_input: - logger.info('Unsupported nodes due to input having a dynamic shape=%d', - self.nodes_unsupported_due_to_dynamic_input) + logger.info( + "Unsupported nodes due to input having a dynamic shape=%d", + self.nodes_unsupported_due_to_dynamic_input, + ) if logger.getEffectiveLevel() <= logging.DEBUG: # Enable this manually if you need to look at specific partitions. @@ -135,39 +145,53 @@ def dump_analysis(self, logger: logging.Logger, ep_name: str): caveats = self.supported_ops_checker.get_caveats() if caveats: - indent = ' ' * 5 - logger.debug('Caveats that have not been checked and may result in a node not being supported: ' - f'{"".join([os.linesep + indent + caveat for caveat in caveats])}') + indent = " " * 5 + logger.debug( + "Caveats that have not been checked and may result in a node not being supported: " + f'{"".join([os.linesep + indent + caveat for caveat in caveats])}' + ) pct_nodes_using_ep = self.num_supported_nodes / num_nodes * 100 if self.num_partitions == 0: logger.info(f"{ep_name} cannot run any nodes in this model.") elif self.num_partitions == 1: if pct_nodes_using_ep > 75: - logger.info(f"{ep_name} should work well for this model as there is one partition " - f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model.") + logger.info( + f"{ep_name} should work well for this model as there is one partition " + f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model." + ) elif pct_nodes_using_ep > 50: logger.info( f"{ep_name} may work well for this model, however only {pct_nodes_using_ep:.1f}% of nodes " - "will use it. Performance testing is required to validate.") + "will use it. Performance testing is required to validate." + ) else: logger.info( f"{ep_name} will probably not work will for this model as only {pct_nodes_using_ep:.2f}% " - "of nodes will use it.") + "of nodes will use it." + ) elif self.num_partitions == 2 and pct_nodes_using_ep > 75: - logger.info(f"{ep_name} can be considered for this model as there are two partitions " - f"covering {pct_nodes_using_ep:.1f}% of the nodes. " - "Performance testing is required to validate.") + logger.info( + f"{ep_name} can be considered for this model as there are two partitions " + f"covering {pct_nodes_using_ep:.1f}% of the nodes. " + "Performance testing is required to validate." + ) else: - logger.info(f"{ep_name} is not recommended with this model as there are {self.num_partitions} partitions " - f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model. " - "This will most likely result in worse performance than just using the CPU EP.") - - -def check_partitioning(graph: onnx.GraphProto, supported_ops_checker: _SupportedOpsChecker, - require_fixed_input_sizes: bool = False, value_info: dict = None): - ''' + logger.info( + f"{ep_name} is not recommended with this model as there are {self.num_partitions} partitions " + f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model. " + "This will most likely result in worse performance than just using the CPU EP." + ) + + +def check_partitioning( + graph: onnx.GraphProto, + supported_ops_checker: _SupportedOpsChecker, + require_fixed_input_sizes: bool = False, + value_info: dict = None, +): + """ Estimate the partitions the graph will be split into for nodes that is_node_supported_fn returns true for. The check on whether a node is supported is purely based on the operator type. Additional limitations @@ -182,7 +206,7 @@ def check_partitioning(graph: onnx.GraphProto, supported_ops_checker: _Supported :param value_info: Map of value name to ValueInfoProto. Required if require_fixed_input_sizes is True to lookup the shape of a value. :return PartitioningInfo instance with details - ''' + """ if require_fixed_input_sizes and not value_info: raise ValueError("value_info must be provided if require_fixed_input_sizes is True.") @@ -269,8 +293,7 @@ def close_group(): node = nodes_to_process.popleft() is_op_supported = supported_ops_checker.is_op_supported(node) - is_input_shape_supported = \ - not require_fixed_input_sizes or all(_is_fixed_shape_value(i) for i in node.input) + is_input_shape_supported = not require_fixed_input_sizes or all(_is_fixed_shape_value(i) for i in node.input) is_node_supported = is_op_supported and is_input_shape_supported if not is_node_supported: @@ -344,13 +367,12 @@ def check_nnapi_partitions(model, value_info: dict = None): # if we're running in the ORT python package the file should be local. otherwise assume we're running from the # ORT repo script_dir = pathlib.Path(__file__).parent - local_config = script_dir / 'nnapi_supported_ops.md' + local_config = script_dir / "nnapi_supported_ops.md" if local_config.exists(): config_path = local_config else: ort_root = script_dir.parents[3] - config_path = \ - ort_root / 'tools' / 'ci_build' / 'github' / 'android' / 'nnapi_supported_ops.md' + config_path = ort_root / "tools" / "ci_build" / "github" / "android" / "nnapi_supported_ops.md" return _check_ep_partitioning(model, config_path, value_info) @@ -359,25 +381,24 @@ def check_coreml_partitions(model, value_info: dict = None): # if we're running in the ORT python package the file should be local. otherwise assume we're running from the # ORT repo script_dir = pathlib.Path(__file__).parent - local_config = script_dir / 'coreml_supported_ops.md' + local_config = script_dir / "coreml_supported_ops.md" if local_config.exists(): config_path = local_config else: ort_root = script_dir.parents[3] - config_path = \ - ort_root / 'tools' / 'ci_build' / 'github' / 'apple' / 'coreml_supported_ops.md' + config_path = ort_root / "tools" / "ci_build" / "github" / "apple" / "coreml_supported_ops.md" return _check_ep_partitioning(model, config_path, value_info) def check_shapes(graph: onnx.GraphProto, logger: logging.Logger = None): - ''' + """ Check the shapes of graph inputs, values and graph outputs to determine if they have static or dynamic sizes. NNAPI and CoreML do not support dynamically sized values. :param graph: Graph to check. If shape inferencing has been run the checks on values will be meaningful. :param logger: Optional logger for diagnostic information. :return: Tuple of List of inputs with dynamic shapes, Number of dynamic values found - ''' + """ # it's OK if the input is dynamically sized and we do a Resize early to a fixed size. # it's not good if lots of ops have dynamic inputs @@ -410,8 +431,10 @@ def check_shapes(graph: onnx.GraphProto, logger: logging.Logger = None): # special case some test graphs with a single node which only have graph input and output values, and # a model where all inputs are dynamic (results in no value_info) if not graph.value_info and not (len(graph.node) == 1 or len(dynamic_inputs) == len(graph.input)): - logger.warning("Unable to check shapes within model. " - "ONNX shape inferencing should be run on the model prior to checking.") + logger.warning( + "Unable to check shapes within model. " + "ONNX shape inferencing should be run on the model prior to checking." + ) for vi in graph.value_info: if is_fixed_size_tensor(vi): @@ -420,19 +443,23 @@ def check_shapes(graph: onnx.GraphProto, logger: logging.Logger = None): num_dynamic_values += 1 if logger: - logger.info(f"Num values with fixed shape={num_fixed_values}. " - f"Num values with dynamic shape={num_dynamic_values}") + logger.info( + f"Num values with fixed shape={num_fixed_values}. " f"Num values with dynamic shape={num_dynamic_values}" + ) if dynamic_inputs: if dynamic_outputs: - logger.info("Model has dynamic inputs and outputs. Consider re-exporting model with fixed sizes " - "if NNAPI or CoreML can be used with this model.") + logger.info( + "Model has dynamic inputs and outputs. Consider re-exporting model with fixed sizes " + "if NNAPI or CoreML can be used with this model." + ) else: logger.info( - '''Model has dynamically sized inputs but fixed sized outputs. + """Model has dynamically sized inputs but fixed sized outputs. If the sizes become fixed early in the model (e.g. pre-processing of a dynamic input size results in a fixed input size for the majority of the model) performance with NNAPI and CoreML, - if applicable, should not be significantly impacted.''') + if applicable, should not be significantly impacted.""" + ) return dynamic_inputs, num_dynamic_values @@ -469,14 +496,16 @@ def check_ep(ep_name, checker_func): partition_info_with_fixed_shapes = checker_func(model_with_shape_info) if logger.getEffectiveLevel() <= logging.DEBUG: # analyze and log detailed info - logger.info('Partition information if the model was updated to make the shapes fixed:') + logger.info("Partition information if the model was updated to make the shapes fixed:") partition_info_with_fixed_shapes.dump_analysis(logger, ep_name) fixed_shape_suitability = partition_info_with_fixed_shapes.suitability() - logger.info(f"Model should perform well with {ep_name} if modified to have fixed input shapes: " - f"{fixed_shape_suitability.name}") + logger.info( + f"Model should perform well with {ep_name} if modified to have fixed input shapes: " + f"{fixed_shape_suitability.name}" + ) if fixed_shape_suitability != PartitioningInfo.TryWithEP.NO: - logger.info('Shapes can be altered using python -m onnxruntime.tools.make_dynamic_shape_fixed') + logger.info("Shapes can be altered using python -m onnxruntime.tools.make_dynamic_shape_fixed") if fixed_shape_suitability.value > suitability.value: suitability = fixed_shape_suitability @@ -486,28 +515,29 @@ def check_ep(ep_name, checker_func): nnapi_suitability = check_ep("NNAPI", check_nnapi_partitions) coreml_suitability = check_ep("CoreML", check_coreml_partitions) - if (nnapi_suitability != PartitioningInfo.TryWithEP.YES or coreml_suitability != PartitioningInfo.TryWithEP.YES) \ - and logger.getEffectiveLevel() > logging.DEBUG: - logger.info('Re-run with log level of DEBUG for more details on the NNAPI/CoreML issues.') + if ( + nnapi_suitability != PartitioningInfo.TryWithEP.YES or coreml_suitability != PartitioningInfo.TryWithEP.YES + ) and logger.getEffectiveLevel() > logging.DEBUG: + logger.info("Re-run with log level of DEBUG for more details on the NNAPI/CoreML issues.") - logger.info('---------------') + logger.info("---------------") return nnapi_suitability != PartitioningInfo.TryWithEP.NO or coreml_suitability != PartitioningInfo.TryWithEP.NO def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger: logging.Logger = None): - ''' + """ Analyze the provided model to determine if it's likely to work well with the NNAPI or CoreML Execution Providers :param model_path: Model to analyze. :param skip_optimize: Skip optimizing to BASIC level before checking. When exporting to ORT format we will do this optimization.. :param logger: Logger for output :return: True if either the NNAPI or CoreML Execution Providers may work well with this model. - ''' + """ if not logger: - logger = logging.getLogger('usability_checker') + logger = logging.getLogger("usability_checker") logger.setLevel(logging.INFO) - logger.info(f'Checking {model_path} for usability with ORT Mobile.') + logger.info(f"Checking {model_path} for usability with ORT Mobile.") with tempfile.TemporaryDirectory() as tmp: if not skip_optimize: @@ -522,30 +552,33 @@ def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger: def parse_args(): parser = argparse.ArgumentParser( - os.path.basename(__file__), - description='''Analyze an ONNX model for usage with the ORT mobile''' + os.path.basename(__file__), description="""Analyze an ONNX model for usage with the ORT mobile""" ) - parser.add_argument('--log_level', choices=['debug', 'info', 'warning', 'error'], - default='info', help='Logging level') - parser.add_argument('--skip_optimize', action='store_true', - help="Don't optimize the model to BASIC level prior to analyzing. " - "Optimization will occur when exporting the model to ORT format, so in general " - "should not be skipped unless you have a specific reason to do so.") - parser.add_argument('model_path', type=pathlib.Path, help='Provide path to ONNX model') + parser.add_argument( + "--log_level", choices=["debug", "info", "warning", "error"], default="info", help="Logging level" + ) + parser.add_argument( + "--skip_optimize", + action="store_true", + help="Don't optimize the model to BASIC level prior to analyzing. " + "Optimization will occur when exporting the model to ORT format, so in general " + "should not be skipped unless you have a specific reason to do so.", + ) + parser.add_argument("model_path", type=pathlib.Path, help="Provide path to ONNX model") return parser.parse_args() def run_analyze_model(): args = parse_args() - logger = logging.getLogger('default') + logger = logging.getLogger("default") - if args.log_level == 'debug': + if args.log_level == "debug": logger.setLevel(logging.DEBUG) - elif args.log_level == 'info': + elif args.log_level == "info": logger.setLevel(logging.INFO) - elif args.log_level == 'warning': + elif args.log_level == "warning": logger.setLevel(logging.WARNING) else: logger.setLevel(logging.ERROR) @@ -554,5 +587,5 @@ def run_analyze_model(): analyze_model(model_path, args.skip_optimize, logger) -if __name__ == '__main__': +if __name__ == "__main__": run_analyze_model() diff --git a/tools/python/util/onnx_model_utils.py b/tools/python/util/onnx_model_utils.py index 6d51d0b5da011..2d022eaf0ec36 100644 --- a/tools/python/util/onnx_model_utils.py +++ b/tools/python/util/onnx_model_utils.py @@ -2,63 +2,65 @@ # Licensed under the MIT License. import logging -import onnx -import onnxruntime as ort import pathlib +import onnx from onnx import version_converter +import onnxruntime as ort + def iterate_graph_per_node_func(graph, per_node_func, **func_args): - ''' + """ Iterate the graph including subgraphs calling the per_node_func for each node. :param graph: Graph to iterate :param per_node_func: Function to call for each node. Signature is fn(node: onnx:NodeProto, **kwargs) :param func_args: The keyword args to pass through. - ''' + """ for node in graph.node: per_node_func(node, **func_args) # recurse into subgraph for control flow nodes (Scan/Loop/If) for attr in node.attribute: - if attr.HasField('g'): + if attr.HasField("g"): iterate_graph_per_node_func(attr.g, per_node_func, **func_args) def iterate_graph_per_graph_func(graph, per_graph_func, **func_args): - ''' + """ Iterate the graph including subgraphs calling the per_graph_func for each Graph. :param graph: Graph to iterate :param per_graph_func: Function to call for each graph. Signature is fn(graph: onnx:GraphProto, **kwargs) :param func_args: The keyword args to pass through. - ''' + """ per_graph_func(graph, **func_args) for node in graph.node: # recurse into subgraph for control flow nodes (Scan/Loop/If) for attr in node.attribute: - if attr.HasField('g'): + if attr.HasField("g"): iterate_graph_per_graph_func(attr.g, per_graph_func, **func_args) def get_opsets_imported(model: onnx.ModelProto): - ''' + """ Get the opsets imported by the model :param model: Model to check. :return: Map of domain to opset. - ''' + """ opsets = {} for entry in model.opset_import: # if empty it's ai.onnx - domain = entry.domain or 'ai.onnx' + domain = entry.domain or "ai.onnx" opsets[domain] = entry.version return opsets -def update_onnx_opset(model_path: pathlib.Path, opset: int, out_path: pathlib.Path = None, - logger: logging.Logger = None): +def update_onnx_opset( + model_path: pathlib.Path, opset: int, out_path: pathlib.Path = None, logger: logging.Logger = None +): """ Helper to update the opset of a model using onnx version_converter. Target opset must be greater than current opset. :param model_path: Path to model to update @@ -84,30 +86,32 @@ def update_onnx_opset(model_path: pathlib.Path, opset: int, out_path: pathlib.Pa return new_model -def optimize_model(model_path: pathlib.Path, - output_path: pathlib.Path, - level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, - log_level: int = 3): - ''' +def optimize_model( + model_path: pathlib.Path, + output_path: pathlib.Path, + level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC, + log_level: int = 3, +): + """ Optimize an ONNX model using ONNX Runtime to the specified level :param model_path: Path to ONNX model :param output_path: Path to save optimized model to. :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC. :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed. Warning (2) or Info (1) may be desirable in some scenarios. - ''' + """ so = ort.SessionOptions() so.optimized_model_filepath = str(output_path.resolve()) so.graph_optimization_level = level so.log_severity_level = log_level # create session to optimize. this will write the updated model to output_path - _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=['CPUExecutionProvider']) + _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"]) def _replace_symbolic_dim_value(graph: onnx.GraphProto, **kwargs): - param_to_replace = kwargs['dim_param'] - value = kwargs['value'] + param_to_replace = kwargs["dim_param"] + value = kwargs["value"] def update_dim_values(value_infos): for vi in value_infos: @@ -115,7 +119,7 @@ def update_dim_values(value_infos): shape = vi.type.tensor_type.shape if shape: for dim in shape.dim: - if dim.HasField('dim_param') and dim.dim_param == param_to_replace: + if dim.HasField("dim_param") and dim.dim_param == param_to_replace: dim.Clear() dim.dim_value = value @@ -130,7 +134,7 @@ def clear_invalid_values(value): shape = value.type.tensor_type.shape if shape: for dim in shape.dim: - if dim.HasField('dim_value') and dim.dim_value < 1: + if dim.HasField("dim_value") and dim.dim_value < 1: dim.Clear() for i in graph.input: @@ -144,33 +148,33 @@ def clear_invalid_values(value): def remove_invalid_dim_values(graph: onnx.GraphProto): - ''' + """ Iterate the graph and subgraphs, unsetting any dim_value entries that have a value of less than 1. These are typically erroneously inserted by a converter to represent a dynamic dimension. :param graph: GraphProto to update - ''' + """ iterate_graph_per_graph_func(graph, _remove_invalid_dim_values_impl) def make_dim_param_fixed(graph: onnx.GraphProto, param_name: str, value: int): - ''' + """ Iterate all values in the graph, replacing dim_param in a tensor shape with the provided value. :param graph: GraphProto to update :param param_name: dim_param to set :param value: value to use - ''' + """ iterate_graph_per_graph_func(graph, _replace_symbolic_dim_value, dim_param=param_name, value=value) def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape: [int]): - ''' + """ Update the named graph input to set shape to the provided value. This can be used to set unknown dims as well as to replace dim values. If setting the input shape replaces a dim_param, update any other values in the graph that use the dim_param. :param graph: Graph to update :param input_name: Name of graph input to update. :param fixed_shape: Shape to use. - ''' + """ # remove any invalid dim values first. typically this is a dim_value of -1. remove_invalid_dim_values(graph) @@ -178,22 +182,22 @@ def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape: for i in graph.input: if i.name == input_name: if not i.type.HasField("tensor_type"): - raise ValueError(f'Input {input_name} is not a tensor') + raise ValueError(f"Input {input_name} is not a tensor") # graph inputs are required to have a shape to provide the rank shape = i.type.tensor_type.shape if len(shape.dim) != len(fixed_shape): - raise ValueError( - f'Rank mismatch. Existing:{len(shape.dim)} Replacement:{len(fixed_shape)}') + raise ValueError(f"Rank mismatch. Existing:{len(shape.dim)} Replacement:{len(fixed_shape)}") for idx, dim in enumerate(shape.dim): # check any existing fixed dims match - if dim.HasField('dim_value'): + if dim.HasField("dim_value"): if dim.dim_value != fixed_shape[idx]: raise ValueError( f"Can't replace existing fixed size of {dim.dim_value} with {fixed_shape[idx]} " - f"for dimension {idx + 1}") - elif dim.HasField('dim_param'): + f"for dimension {idx + 1}" + ) + elif dim.HasField("dim_param"): # replacing a dim_param so have to do that through the entire graph make_dim_param_fixed(graph, dim.dim_param, fixed_shape[idx]) else: @@ -203,16 +207,18 @@ def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape: return - raise ValueError(f'Input {input_name} was not found in graph inputs. ' - f'Valid input names are: {",".join([i.name for i in graph.input])}') + raise ValueError( + f"Input {input_name} was not found in graph inputs. " + f'Valid input names are: {",".join([i.name for i in graph.input])}' + ) def fix_output_shapes(model: onnx.ModelProto): - ''' + """ Update the output shapesof a model where the input shape/s were made fixed, if possible. This is mainly to make the model usage clearer if the output shapes can be inferred from the new input shapes. :param model: Model that had input shapes fixed. - ''' + """ # get a version of the model with shape inferencing info in it. this will provide fixed output shapes if possible. m2 = onnx.shape_inference.infer_shapes(model) @@ -225,15 +231,16 @@ def fix_output_shapes(model: onnx.ModelProto): o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape) -def _create_producer_consumer_link(node_to_producers: dict, node_to_consumers: dict, - producer: onnx.NodeProto, consumer: onnx.NodeProto): - ''' +def _create_producer_consumer_link( + node_to_producers: dict, node_to_consumers: dict, producer: onnx.NodeProto, consumer: onnx.NodeProto +): + """ Create links between two nodes for a value produced by one and consumed by the other. :param node_to_producers: Map of NodeProto to set of nodes that produce values the node consumes as inputs. :param node_to_consumers: Map of NodeProto to set of nodes that consume values the node produces as outputs. :param producer: Producer node :param consumer: Consumer node - ''' + """ if consumer not in node_to_producers: node_to_producers[consumer] = set() @@ -262,7 +269,7 @@ def is_local_value(value): inputs = [i for i in node.input] for attr in node.attribute: - if attr.HasField('g'): + if attr.HasField("g"): subgraph_implicit_inputs = _map_node_dependencies(attr.g, node_to_producers, node_to_consumers) inputs += subgraph_implicit_inputs @@ -285,7 +292,7 @@ def is_local_value(value): def get_producer_consumer_maps(graph: onnx.GraphProto): - ''' + """ Get maps for connections between the node that produces each value and the nodes that consume the value. Processing includes subgraphs. As the map key is a Node instance from the Graph there should be no ambiguity. :param graph: Graph to process. @@ -298,7 +305,7 @@ def get_producer_consumer_maps(graph: onnx.GraphProto): node_to_producers[NodeC] = set([NodeA, NodeB]) node_to_consumers[NodeC] = set([NodeD]) node_to_producers[NodeD] = set([NodeC]) - ''' + """ # use a hash of the object id for NodeProto. # we need this for the partitioning checker where we keep maps with nodes as the key. @@ -311,18 +318,19 @@ def get_producer_consumer_maps(graph: onnx.GraphProto): # top level graph should have no implicit inputs if implicit_inputs: - raise ValueError('This appears to be an invalid model with missing inputs of ' - f'{",".join(sorted(implicit_inputs))}') + raise ValueError( + "This appears to be an invalid model with missing inputs of " f'{",".join(sorted(implicit_inputs))}' + ) return node_to_producers, node_to_consumers def is_fixed_size_tensor(value: onnx.ValueInfoProto): - ''' + """ Check if value is a tensor with a fixed shape. :param value: onnx.ValueInfoProto to check :return: True if value is a tensor, with a shape, where all dimensions have fixed values. - ''' + """ is_fixed = False if value.type.HasField("tensor_type"): @@ -330,7 +338,7 @@ def is_fixed_size_tensor(value: onnx.ValueInfoProto): if shape: is_fixed = True # scalar has no dims so set to True and unset if we hit a dim without a valid value for dim in shape.dim: - if dim.HasField('dim_value') and dim.dim_value > 0: + if dim.HasField("dim_value") and dim.dim_value > 0: continue # anything else means it's a dynamic value @@ -341,16 +349,16 @@ def is_fixed_size_tensor(value: onnx.ValueInfoProto): def get_optimization_level(level): - '''Convert string to GraphOptimizationLevel.''' - if level == 'disable': + """Convert string to GraphOptimizationLevel.""" + if level == "disable": return ort.GraphOptimizationLevel.ORT_DISABLE_ALL - if level == 'basic': + if level == "basic": # Constant folding and other optimizations that only use ONNX operators return ort.GraphOptimizationLevel.ORT_ENABLE_BASIC - if level == 'extended': + if level == "extended": # Optimizations using custom operators, excluding NCHWc and NHWC layout optimizers return ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED - if level == 'all': + if level == "all": return ort.GraphOptimizationLevel.ORT_ENABLE_ALL - raise ValueError('Invalid optimization level of ' + level) + raise ValueError("Invalid optimization level of " + level) diff --git a/tools/python/util/optimize_onnx_model.py b/tools/python/util/optimize_onnx_model.py index 53c92577e3d89..4cb9b862b37cb 100644 --- a/tools/python/util/optimize_onnx_model.py +++ b/tools/python/util/optimize_onnx_model.py @@ -10,37 +10,46 @@ def optimize_model_helper(): - parser = argparse.ArgumentParser(f'{os.path.basename(__file__)}:{optimize_model_helper.__name__}', - description=''' + parser = argparse.ArgumentParser( + f"{os.path.basename(__file__)}:{optimize_model_helper.__name__}", + description=""" Optimize an ONNX model using ONNX Runtime to the specified level. See https://onnxruntime.ai/docs/performance/graph-optimizations.html for more - details of the optimization levels.''' - ) - - parser.add_argument('--opt_level', default='basic', - choices=['disable', 'basic', 'extended', 'all'], - help="Optimization level to use.") - parser.add_argument('--log_level', choices=['debug', 'info', 'warning', 'error'], type=str, required=False, - default='error', - help="Log level. Defaults to Error so we don't get output about unused initializers " - "being removed. Warning or Info may be desirable in some scenarios.") - - parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.') - parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write optimized ONNX model to.') + details of the optimization levels.""", + ) + + parser.add_argument( + "--opt_level", + default="basic", + choices=["disable", "basic", "extended", "all"], + help="Optimization level to use.", + ) + parser.add_argument( + "--log_level", + choices=["debug", "info", "warning", "error"], + type=str, + required=False, + default="error", + help="Log level. Defaults to Error so we don't get output about unused initializers " + "being removed. Warning or Info may be desirable in some scenarios.", + ) + + parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.") + parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write optimized ONNX model to.") args = parser.parse_args() - if args.log_level == 'error': + if args.log_level == "error": log_level = 3 - elif args.log_level == 'debug': + elif args.log_level == "debug": log_level = 0 # ORT verbose level - elif args.log_level == 'info': + elif args.log_level == "info": log_level = 1 - elif args.log_level == 'warning': + elif args.log_level == "warning": log_level = 2 optimize_model(args.input_model, args.output_model, get_optimization_level(args.opt_level), log_level) -if __name__ == '__main__': +if __name__ == "__main__": optimize_model_helper() diff --git a/tools/python/util/ort_format_model/__init__.py b/tools/python/util/ort_format_model/__init__.py index 6f37a8bf52866..dd1bcbb277226 100644 --- a/tools/python/util/ort_format_model/__init__.py +++ b/tools/python/util/ort_format_model/__init__.py @@ -7,20 +7,21 @@ # need to add the path to the ORT flatbuffers python module before we import anything else here. # we also auto-magically adjust to whether we're running from the ORT repo, or from within the ORT python package script_dir = os.path.dirname(os.path.realpath(__file__)) -fbs_py_schema_dirname = 'ort_flatbuffers_py' +fbs_py_schema_dirname = "ort_flatbuffers_py" if os.path.isdir(os.path.join(script_dir, fbs_py_schema_dirname)): # fbs bindings are in this directory, so we're running in the ORT python package ort_fbs_py_parent_dir = script_dir else: # running directly from ORT repo, so fbs bindings are under onnxruntime/core/flatbuffers - ort_root = os.path.abspath(os.path.join(script_dir, '..', '..', '..', '..')) - ort_fbs_py_parent_dir = os.path.join(ort_root, 'onnxruntime', 'core', 'flatbuffers') + ort_root = os.path.abspath(os.path.join(script_dir, "..", "..", "..", "..")) + ort_fbs_py_parent_dir = os.path.join(ort_root, "onnxruntime", "core", "flatbuffers") sys.path.append(ort_fbs_py_parent_dir) -from .utils import create_config_from_models # noqa -from .ort_model_processor import OrtFormatModelProcessor # noqa from .operator_type_usage_processors import ( # noqa GloballyAllowedTypesOpTypeImplFilter, + OperatorTypeUsageManager, OpTypeImplFilterInterface, - OperatorTypeUsageManager) +) +from .ort_model_processor import OrtFormatModelProcessor # noqa +from .utils import create_config_from_models # noqa diff --git a/tools/python/util/ort_format_model/operator_type_usage_processors.py b/tools/python/util/ort_format_model/operator_type_usage_processors.py index 92bde6adca291..8f21298518f87 100644 --- a/tools/python/util/ort_format_model/operator_type_usage_processors.py +++ b/tools/python/util/ort_format_model/operator_type_usage_processors.py @@ -3,31 +3,30 @@ import json import typing +from abc import ABC, abstractmethod + import ort_flatbuffers_py.fbs as fbs -from abc import ABC, abstractmethod from .types import FbsTypeInfo, value_name_to_typestr def _create_op_key(domain: str, optype: str): - return '{}:{}'.format(domain, optype) + return "{}:{}".format(domain, optype) def _ort_constant_for_domain(domain: str): - ''' + """ Map a string domain value to the internal ONNX Runtime constant for that domain. :param domain: Domain string to map. :return: Internal ONNX Runtime constant - ''' + """ # constants are defined in /include/onnxruntime/core/graph/constants.h # This list is limited to just the domains we have processors for - domain_to_constant_map = {'ai.onnx': 'kOnnxDomain', - 'ai.onnx.ml': 'kMLDomain', - 'com.microsoft': 'kMSDomain'} + domain_to_constant_map = {"ai.onnx": "kOnnxDomain", "ai.onnx.ml": "kMLDomain", "com.microsoft": "kMSDomain"} if domain not in domain_to_constant_map: - raise ValueError('Domain {} not found in map to ONNX Runtime constant. Please update map.'.format(domain)) + raise ValueError("Domain {} not found in map to ONNX Runtime constant. Please update map.".format(domain)) return domain_to_constant_map[domain] @@ -39,9 +38,9 @@ def _reg_type_to_cpp_type(reg_type: str): def _split_reg_types(reg_types_str: str): - ''' + """ Split on underscores but append "_t" to the previous element. - ''' + """ tokens = reg_types_str.split("_") reg_types = [] for token in tokens: @@ -53,9 +52,10 @@ def _split_reg_types(reg_types_str: str): class TypeUsageProcessor(ABC): - ''' + """ Abstract base class for processors which implement operator specific logic to determine the type or types required. - ''' + """ + def __init__(self, domain: str, optype: str): self.domain = domain self.optype = optype @@ -65,53 +65,60 @@ def __init__(self, domain: str, optype: str): def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): pass - def is_typed_registration_needed(self, type_in_registration: str, - globally_allowed_types: typing.Optional[typing.Set[str]]): - ''' + def is_typed_registration_needed( + self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] + ): + """ Given the string from a kernel registration, determine if the registration is required or not. :param type_in_registration: Type string from kernel registration :param globally_allowed_types: Optional set of globally allowed types. If provided, these types take precedence in determining the required types. :return: True is required. False if not. - ''' + """ # Not all operators have typed registrations, so this is optionally implemented by derived classes - raise RuntimeError('Did not expect processor for {} to have typed registrations.'.format(self.name)) + raise RuntimeError("Did not expect processor for {} to have typed registrations.".format(self.name)) def get_cpp_entry(self): - ''' + """ Get the C++ code that specifies this operator's required types. :return: List with any applicable C++ code for this operator's required types. One line per entry. - ''' + """ # Not applicable for some ops, so return no lines by default. return [] @abstractmethod def to_config_entry(self): - ''' + """ Generate a configuration file entry in JSON format with the required types for the operator. :return: JSON string with required type information. - ''' + """ pass @abstractmethod def from_config_entry(self, entry: str): - ''' + """ Re-create the types required from a configuration file entry created with to_config_entry. NOTE: Any existing type information should be cleared prior to re-creating from a config file entry. :param entry: Configuration file entry - ''' + """ pass class DefaultTypeUsageProcessor(TypeUsageProcessor): - ''' + """ Operator processor which tracks the types used for selected input/s and/or output/s. - ''' - - def __init__(self, domain: str, optype: str, inputs: [int] = [0], outputs: [int] = [], - required_input_types: typing.Dict[int, typing.Set[str]] = {}, - required_output_types: typing.Dict[int, typing.Set[str]] = {}): - ''' + """ + + def __init__( + self, + domain: str, + optype: str, + inputs: [int] = [0], + outputs: [int] = [], + required_input_types: typing.Dict[int, typing.Set[str]] = {}, + required_output_types: typing.Dict[int, typing.Set[str]] = {}, + ): + """ Create DefaultTypeUsageProcessor. Types for one or more inputs and/or outputs can be tracked by the processor. The default is to track the types required for input 0, as this is the most common use case in ONNX. @@ -125,7 +132,7 @@ def __init__(self, domain: str, optype: str, inputs: [int] = [0], outputs: [int] :param outputs: Outputs to track. Zero based index. May be empty. :param required_input_types: Required input types. May be empty. :param required_output_types: Required output types. May be empty. - ''' + """ super().__init__(domain, optype) self._input_types = {} self._output_types = {} @@ -137,7 +144,7 @@ def __init__(self, domain: str, optype: str, inputs: [int] = [0], outputs: [int] self._output_types[o] = set() if not inputs and not outputs: - raise ValueError('At least one input or output must be tracked') + raise ValueError("At least one input or output must be tracked") self._required_input_types = required_input_types self._required_output_types = required_output_types @@ -147,13 +154,13 @@ def _is_type_enabled(self, reg_type, index, required_types, allowed_type_set): return cpp_type in required_types.get(index, set()) or cpp_type in allowed_type_set def is_input_type_enabled(self, reg_type, index, allowed_type_set=None): - '''Whether input type is enabled based on required and allowed types.''' + """Whether input type is enabled based on required and allowed types.""" if allowed_type_set is None: allowed_type_set = self._input_types[index] return self._is_type_enabled(reg_type, index, self._required_input_types, allowed_type_set) def is_output_type_enabled(self, reg_type, index, allowed_type_set=None): - '''Whether output type is enabled based on required and allowed types.''' + """Whether output type is enabled based on required and allowed types.""" if allowed_type_set is None: allowed_type_set = self._output_types[index] return self._is_type_enabled(reg_type, index, self._required_output_types, allowed_type_set) @@ -174,18 +181,22 @@ def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): for o in self._output_types.keys(): # Don't know of any ops where the number of outputs changed across versions, so require a valid length if o >= node.OutputsLength(): - raise RuntimeError('Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.' - .format(node.OutputsLength(), self.name, o)) + raise RuntimeError( + "Node has {} outputs. Tracker for {} incorrectly configured as it requires {}.".format( + node.OutputsLength(), self.name, o + ) + ) type_str = value_name_to_typestr(node.Outputs(o), value_name_to_typeinfo) self._output_types[o].add(type_str) - def is_typed_registration_needed(self, type_in_registration: str, - globally_allowed_types: typing.Optional[typing.Set[str]]): + def is_typed_registration_needed( + self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] + ): if 0 not in self._input_types.keys(): # currently all standard typed registrations are for input 0. # custom registrations can be handled by operator specific processors (e.g. OneHotProcessor below). - raise RuntimeError('Expected typed registration to use type from input 0. Node:{}'.format(self.name)) + raise RuntimeError("Expected typed registration to use type from input 0. Node:{}".format(self.name)) return self.is_input_type_enabled(type_in_registration, 0, globally_allowed_types) @@ -194,34 +205,40 @@ def get_cpp_entry(self): domain = _ort_constant_for_domain(self.domain) for i in sorted(self._input_types.keys()): if self._input_types[i]: - entries.append('ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Input, {}, {});' - .format(domain, self.optype, i, ', '.join(sorted(self._input_types[i])))) + entries.append( + "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Input, {}, {});".format( + domain, self.optype, i, ", ".join(sorted(self._input_types[i])) + ) + ) for o in sorted(self._output_types.keys()): if self._output_types[o]: - entries.append('ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Output, {}, {});' - .format(domain, self.optype, o, ', '.join(sorted(self._output_types[o])))) + entries.append( + "ORT_SPECIFY_OP_KERNEL_ARG_ALLOWED_TYPES({}, {}, Output, {}, {});".format( + domain, self.optype, o, ", ".join(sorted(self._output_types[o])) + ) + ) return entries def to_config_entry(self): # convert the sets of types to lists so they can easily written out using the json model - aggregate_info = {'inputs': {}, 'outputs': {}} + aggregate_info = {"inputs": {}, "outputs": {}} # filter out empty entries and sort the types for i in sorted(self._input_types.keys()): if self._input_types[i]: - aggregate_info['inputs'][i] = sorted(self._input_types[i]) + aggregate_info["inputs"][i] = sorted(self._input_types[i]) for o in sorted(self._output_types.keys()): if self._output_types[o]: - aggregate_info['outputs'][o] = sorted(self._output_types[o]) + aggregate_info["outputs"][o] = sorted(self._output_types[o]) # remove any empty keys - if not aggregate_info['inputs']: - aggregate_info.pop('inputs') - if not aggregate_info['outputs']: - aggregate_info.pop('outputs') + if not aggregate_info["inputs"]: + aggregate_info.pop("inputs") + if not aggregate_info["outputs"]: + aggregate_info.pop("outputs") entry = json.dumps(aggregate_info) if aggregate_info else None return entry @@ -231,48 +248,53 @@ def from_config_entry(self, entry: str): self._output_types.clear() aggregate_info = json.loads(entry) - if 'inputs' in aggregate_info: - for i_str, values in aggregate_info['inputs'].items(): + if "inputs" in aggregate_info: + for i_str, values in aggregate_info["inputs"].items(): self._input_types[int(i_str)] = set(values) - if 'outputs' in aggregate_info: - for o_str, values in aggregate_info['outputs'].items(): + if "outputs" in aggregate_info: + for o_str, values in aggregate_info["outputs"].items(): self._output_types[int(o_str)] = set(values) class Input1TypedRegistrationProcessor(DefaultTypeUsageProcessor): - ''' + """ Processor for operators where the second input type is used in a typed kernel registration. - ''' + """ + def __init__(self, domain: str, optype: str): # init with tracking of input 1 only. super().__init__(domain, optype, inputs=[1], outputs=[]) - def is_typed_registration_needed(self, type_in_registration: str, - globally_allowed_types: typing.Optional[typing.Set[str]]): + def is_typed_registration_needed( + self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] + ): return self.is_input_type_enabled(type_in_registration, 1, globally_allowed_types) class Output0TypedRegistrationProcessor(DefaultTypeUsageProcessor): - ''' + """ Processor for operators where the first output type is used in a typed kernel registration. - ''' + """ + def __init__(self, domain: str, optype: str): # init with tracking of output 0 only. super().__init__(domain, optype, inputs=[], outputs=[0]) - def is_typed_registration_needed(self, type_in_registration: str, - globally_allowed_types: typing.Optional[typing.Set[str]]): + def is_typed_registration_needed( + self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] + ): return self.is_output_type_enabled(type_in_registration, 0, globally_allowed_types) class OneHotProcessor(TypeUsageProcessor): - ''' + """ Processor for the OneHot operator, which requires custom logic as the type registration key is a concatenation of the three types involved instead of a single type name. - ''' + """ + def __init__(self): - super().__init__('ai.onnx', 'OneHot') + super().__init__("ai.onnx", "OneHot") self._triples = set() def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): @@ -283,8 +305,9 @@ def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): key = (type0, type2, type1) self._triples.add(key) - def is_typed_registration_needed(self, type_in_registration: str, - globally_allowed_types: typing.Optional[typing.Set[str]]): + def is_typed_registration_needed( + self, type_in_registration: str, globally_allowed_types: typing.Optional[typing.Set[str]] + ): # the OneHot registration involves a concatenation of the 3 types involved reg_types = tuple([_reg_type_to_cpp_type(reg_type) for reg_type in _split_reg_types(type_in_registration)]) if globally_allowed_types is not None: @@ -296,27 +319,27 @@ def to_config_entry(self): if not self._triples: return None - aggregate_info = {'custom': sorted(self._triples)} + aggregate_info = {"custom": sorted(self._triples)} entry = json.dumps(aggregate_info) return entry def from_config_entry(self, entry: str): self._triples.clear() aggregate_info = json.loads(entry) - if 'custom' in aggregate_info: - self._triples = set([tuple(triple) for triple in aggregate_info['custom']]) + if "custom" in aggregate_info: + self._triples = set([tuple(triple) for triple in aggregate_info["custom"]]) def _create_operator_type_usage_processors(): - ''' + """ Create a set of processors that determine the required types for all enabled operators. :return: Dictionary of operator key to processor. Key is 'domain:operator (e.g. ai.onnx:Cast)'. - ''' + """ operator_processors = {} def add(processor): if processor.name in operator_processors: - raise RuntimeError('Duplicate processor for ' + processor.name) + raise RuntimeError("Duplicate processor for " + processor.name) operator_processors[processor.name] = processor @@ -335,39 +358,84 @@ def add(processor): # - Implementation does not have any significant type specific code: # ai.onnx: Concat, Flatten, Not, Reshape, Shape, Squeeze, Unsqueeze # - default_processor_onnx_ops = ['Abs', 'ArgMax', 'ArgMin', 'AveragePool', - 'BatchNormalization', 'BitShift', - 'Ceil', 'Clip', 'Conv', 'CumSum', - 'Exp', 'Expand', - 'Floor', - 'Gemm', - 'IsNaN', - 'Log', 'LogSoftmax', 'LpNormalization', - 'MatMul', 'Max', 'MaxPool', 'Mean', 'Min', - 'NonZero', - 'Pad', - 'QLinearConv', 'QLinearMatMul', - 'Range', 'Reciprocal', 'ReduceL1', 'ReduceL2', 'ReduceLogSum', 'ReduceLogSumExp', - 'ReduceMax', 'ReduceMean', 'ReduceMin', 'ReduceProd', 'ReduceSum', 'ReduceSumSquare', - 'Relu', 'Resize', 'ReverseSequence', 'RoiAlign', 'Round', - 'Scatter', 'ScatterElements', 'ScatterND', 'Shrink', 'Sigmoid', 'Sign', 'Sin', - 'Softmax', 'Split', 'SplitToSequence', 'Sqrt', 'Sum', - 'Tanh', 'TopK', 'Transpose', - 'Unique'] + default_processor_onnx_ops = [ + "Abs", + "ArgMax", + "ArgMin", + "AveragePool", + "BatchNormalization", + "BitShift", + "Ceil", + "Clip", + "Conv", + "CumSum", + "Exp", + "Expand", + "Floor", + "Gemm", + "IsNaN", + "Log", + "LogSoftmax", + "LpNormalization", + "MatMul", + "Max", + "MaxPool", + "Mean", + "Min", + "NonZero", + "Pad", + "QLinearConv", + "QLinearMatMul", + "Range", + "Reciprocal", + "ReduceL1", + "ReduceL2", + "ReduceLogSum", + "ReduceLogSumExp", + "ReduceMax", + "ReduceMean", + "ReduceMin", + "ReduceProd", + "ReduceSum", + "ReduceSumSquare", + "Relu", + "Resize", + "ReverseSequence", + "RoiAlign", + "Round", + "Scatter", + "ScatterElements", + "ScatterND", + "Shrink", + "Sigmoid", + "Sign", + "Sin", + "Softmax", + "Split", + "SplitToSequence", + "Sqrt", + "Sum", + "Tanh", + "TopK", + "Transpose", + "Unique", + ] # ops that are used to manipulate shapes or indices so require int32_t and int64_t to be available - default_processor_onnx_ops_requiring_ints_for_input_0 = ['Add', - 'Concat', - 'Div', - 'Equal', - 'Greater', - 'Less', - 'Mul', - 'Neg', # used in tflite TransposeConv conversion - 'Sub'] + default_processor_onnx_ops_requiring_ints_for_input_0 = [ + "Add", + "Concat", + "Div", + "Equal", + "Greater", + "Less", + "Mul", + "Neg", # used in tflite TransposeConv conversion + "Sub", + ] # NOTE: QLinearConv has ONNX and internal implementations - internal_ops = ['QLinearAdd', 'QLinearMul', 'QLinearConv'] + internal_ops = ["QLinearAdd", "QLinearMul", "QLinearConv"] # TODO - review and add ML ops as needed # ML Op notes. @@ -380,45 +448,50 @@ def add(processor): # ZipMap: Switch on output type (derived from attributes) default_processor_onnxml_ops = [] - [add(DefaultTypeUsageProcessor('ai.onnx', op)) for op in default_processor_onnx_ops] - [add(DefaultTypeUsageProcessor('ai.onnx', op, required_input_types={0: {"int32_t", "int64_t"}})) - for op in default_processor_onnx_ops_requiring_ints_for_input_0] - [add(DefaultTypeUsageProcessor('ai.onnx.ml', op)) for op in default_processor_onnxml_ops] - [add(DefaultTypeUsageProcessor('com.microsoft', op)) for op in internal_ops] + [add(DefaultTypeUsageProcessor("ai.onnx", op)) for op in default_processor_onnx_ops] + [ + add(DefaultTypeUsageProcessor("ai.onnx", op, required_input_types={0: {"int32_t", "int64_t"}})) + for op in default_processor_onnx_ops_requiring_ints_for_input_0 + ] + [add(DefaultTypeUsageProcessor("ai.onnx.ml", op)) for op in default_processor_onnxml_ops] + [add(DefaultTypeUsageProcessor("com.microsoft", op)) for op in internal_ops] # # Operators that require custom handling # # Cast switches on types of input 0 and output 0 - add(DefaultTypeUsageProcessor('ai.onnx', 'Cast', inputs=[0], outputs=[0])) + add(DefaultTypeUsageProcessor("ai.onnx", "Cast", inputs=[0], outputs=[0])) # Operators that switch on the type of input 0 and 1 - add(DefaultTypeUsageProcessor('ai.onnx', 'Gather', inputs=[0, 1])) - add(DefaultTypeUsageProcessor('ai.onnx', 'GatherElements', inputs=[0, 1])) - add(DefaultTypeUsageProcessor('ai.onnx', 'Pow', inputs=[0, 1])) - add(DefaultTypeUsageProcessor('ai.onnx', 'Slice', inputs=[0, 1])) + add(DefaultTypeUsageProcessor("ai.onnx", "Gather", inputs=[0, 1])) + add(DefaultTypeUsageProcessor("ai.onnx", "GatherElements", inputs=[0, 1])) + add(DefaultTypeUsageProcessor("ai.onnx", "Pow", inputs=[0, 1])) + add(DefaultTypeUsageProcessor("ai.onnx", "Slice", inputs=[0, 1])) # Operators that switch on output type - add(DefaultTypeUsageProcessor('ai.onnx', 'ConstantOfShape', inputs=[], outputs=[0])) + add(DefaultTypeUsageProcessor("ai.onnx", "ConstantOfShape", inputs=[], outputs=[0])) # Random generator ops produce new data so we track the output type - onnx_random_ops = ['RandomNormal', 'RandomNormalLike', 'RandomUniform', 'RandomUniformLike', 'Multinomial'] - [add(DefaultTypeUsageProcessor('ai.onnx', op, inputs=[], outputs=[0])) for op in onnx_random_ops] + onnx_random_ops = ["RandomNormal", "RandomNormalLike", "RandomUniform", "RandomUniformLike", "Multinomial"] + [add(DefaultTypeUsageProcessor("ai.onnx", op, inputs=[], outputs=[0])) for op in onnx_random_ops] # Where always has a boolean first input so track the second input type for typed registration - add(Input1TypedRegistrationProcessor('ai.onnx', 'Where')) + add(Input1TypedRegistrationProcessor("ai.onnx", "Where")) # we only support 'float' as input for [Dynamic]QuantizeLinear so just track the output type # as that's what is used in the typed registration - add(Output0TypedRegistrationProcessor('ai.onnx', 'QuantizeLinear')) - add(Output0TypedRegistrationProcessor('ai.onnx', 'DynamicQuantizeLinear')) + add(Output0TypedRegistrationProcessor("ai.onnx", "QuantizeLinear")) + add(Output0TypedRegistrationProcessor("ai.onnx", "DynamicQuantizeLinear")) # make sure all the dequantize types are enabled. we use int32_t for parts of GEMM and Conv so just # enabling int8 and uint8 is not enough. # TODO: Only apply required types to the global type list and ignore if it's model based per-op type reduction - add(DefaultTypeUsageProcessor('ai.onnx', 'DequantizeLinear', inputs=[0], - required_input_types={0: {'int8_t', 'uint8_t', 'int32_t'}})) + add( + DefaultTypeUsageProcessor( + "ai.onnx", "DequantizeLinear", inputs=[0], required_input_types={0: {"int8_t", "uint8_t", "int32_t"}} + ) + ) # OneHot concatenates type strings into a triple in the typed registration # e.g. float_int64_t_int64_t @@ -428,42 +501,44 @@ def add(processor): class OpTypeImplFilterInterface(ABC): - ''' + """ Class that filters operator implementations based on type. - ''' + """ + @abstractmethod def is_typed_registration_needed(self, domain: str, optype: str, type_registration_str: str): - ''' + """ Given the string from a kernel registration, determine if the registration is required or not. :param domain: Operator domain. :param optype: Operator type. :param type_registration_str: Type string from kernel registration :return: True is required. False if not. - ''' + """ pass @abstractmethod def get_cpp_entries(self): - ''' + """ Get the C++ code that specifies the operator types to enable. :return: List of strings. One line of C++ code per entry. - ''' + """ pass class OperatorTypeUsageManager: - ''' + """ Class to manage the operator type usage processors. TODO: Currently the type tracking is not specific to a version of the operator. It's unclear how/where version specific logic could/should be added, and it would add significant complexity to track types on a per-version basis. Not clear there's enough benefit from doing so either. - ''' + """ + def __init__(self): self._all_operator_processors = _create_operator_type_usage_processors() # all possible processors self._operator_processors = {} # processors we have actually used so we can limit output to be meaningful def _get_op_processor(self, key): - 'Add the processor to _operator_processors as it is about to be used.' + "Add the processor to _operator_processors as it is about to be used." processor = None if key in self._all_operator_processors: if key not in self._operator_processors: @@ -474,13 +549,13 @@ def _get_op_processor(self, key): return processor def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): - ''' + """ Process a Node and record info on the types used. :param node: Node from ORT format model :param value_name_to_typeinfo: Map of value names to TypeInfo instances - ''' + """ optype = node.OpType().decode() - domain = node.Domain().decode() or 'ai.onnx' # empty domain defaults to ai.onnx + domain = node.Domain().decode() or "ai.onnx" # empty domain defaults to ai.onnx key = _create_op_key(domain, optype) op_processor = self._get_op_processor(key) @@ -488,12 +563,12 @@ def process_node(self, node: fbs.Node, value_name_to_typeinfo: dict): op_processor.process_node(node, value_name_to_typeinfo) def get_config_entry(self, domain: str, optype: str): - ''' + """ Get the config entry specifying the types for this operator. :param domain: Operator domain. :param optype: Operator type. :return: JSON string with type info if available, else None - ''' + """ key = _create_op_key(domain, optype) config_str = None if key in self._operator_processors: @@ -502,12 +577,12 @@ def get_config_entry(self, domain: str, optype: str): return config_str def restore_from_config_entry(self, domain: str, optype: str, config_entry: str): - ''' + """ Restore the per-operator type information from a configuration file entry. :param domain: Operator domain. :param optype: Operator type. :param config_entry: JSON string with type info as created by get_config_entry - ''' + """ key = _create_op_key(domain, optype) op_processor = self._get_op_processor(key) if op_processor: @@ -515,19 +590,19 @@ def restore_from_config_entry(self, domain: str, optype: str, config_entry: str) def debug_dump(self): - print('C++ code that will be emitted:') + print("C++ code that will be emitted:") [print(cpp_line) for cpp_line in self.get_cpp_entries()] - print('Config file type information that will be returned by get_config_entry:') + print("Config file type information that will be returned by get_config_entry:") for key in sorted(self._operator_processors.keys()): entry = self._operator_processors[key].to_config_entry() if entry: - print('{} -> {}'.format(key, entry)) + print("{} -> {}".format(key, entry)) # roundtrip test to validate that we can initialize the processor from the entry and get the # same values back self._operator_processors[key].from_config_entry(entry) - assert(entry == self._operator_processors[key].to_config_entry()) + assert entry == self._operator_processors[key].to_config_entry() class _OpTypeImplFilter(OpTypeImplFilterInterface): def __init__(self, manager): @@ -538,7 +613,8 @@ def is_typed_registration_needed(self, domain: str, optype: str, type_registrati key = _create_op_key(domain, optype) if key in self._manager._operator_processors: needed = self._manager._operator_processors[key].is_typed_registration_needed( - type_in_registration=type_registration_str, globally_allowed_types=None) + type_in_registration=type_registration_str, globally_allowed_types=None + ) return needed @@ -550,25 +626,29 @@ def get_cpp_entries(self): return entries def make_op_type_impl_filter(self): - ''' + """ Creates an OpTypeImplFilterInterface instance from this manager. Filtering uses the manager's operator type usage processor state. - ''' + """ return OperatorTypeUsageManager._OpTypeImplFilter(self) class GloballyAllowedTypesOpTypeImplFilter(OpTypeImplFilterInterface): - ''' + """ Operator implementation filter which uses globally allowed types. - ''' + """ + _valid_allowed_types = set(FbsTypeInfo.tensordatatype_to_string.values()) def __init__(self, globally_allowed_types: typing.Set[str]): self._operator_processors = _create_operator_type_usage_processors() if not globally_allowed_types.issubset(self._valid_allowed_types): - raise ValueError("Globally allowed types must all be valid. Invalid types: {}" - .format(sorted(globally_allowed_types - self._valid_allowed_types))) + raise ValueError( + "Globally allowed types must all be valid. Invalid types: {}".format( + sorted(globally_allowed_types - self._valid_allowed_types) + ) + ) self._globally_allowed_types = globally_allowed_types @@ -576,16 +656,17 @@ def is_typed_registration_needed(self, domain: str, optype: str, type_registrati key = _create_op_key(domain, optype) if key in self._operator_processors: needed = self._operator_processors[key].is_typed_registration_needed( - type_in_registration=type_registration_str, - globally_allowed_types=self._globally_allowed_types) + type_in_registration=type_registration_str, globally_allowed_types=self._globally_allowed_types + ) else: needed = _reg_type_to_cpp_type(type_registration_str) in self._globally_allowed_types return needed def get_cpp_entries(self): - return ["ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({});".format( - ", ".join(sorted(self._globally_allowed_types)))] + return [ + "ORT_SPECIFY_OP_KERNEL_GLOBAL_ALLOWED_TYPES({});".format(", ".join(sorted(self._globally_allowed_types))) + ] def global_type_list(self): return self._globally_allowed_types diff --git a/tools/python/util/ort_format_model/ort_model_processor.py b/tools/python/util/ort_format_model/ort_model_processor.py index 6c03f09dd8395..7c65930e4cd0e 100644 --- a/tools/python/util/ort_format_model/ort_model_processor.py +++ b/tools/python/util/ort_format_model/ort_model_processor.py @@ -2,21 +2,22 @@ # Licensed under the MIT License. import ort_flatbuffers_py.fbs as fbs + from .operator_type_usage_processors import OperatorTypeUsageManager class OrtFormatModelProcessor: - 'Class to process an ORT format model and determine required operators and types.' + "Class to process an ORT format model and determine required operators and types." def __init__(self, model_path: str, required_ops: dict, processors: OperatorTypeUsageManager): - ''' + """ Initialize ORT format model processor :param model_path: Path to model to load :param required_ops: Dictionary required operator information will be added to. :param processors: Operator type usage processors which will be called for each matching Node. - ''' + """ self._required_ops = required_ops # dictionary of {domain: {opset:[operators]}} - self._file = open(model_path, 'rb').read() + self._file = open(model_path, "rb").read() self._buffer = bytearray(self._file) if not fbs.InferenceSession.InferenceSession.InferenceSessionBufferHasIdentifier(self._buffer, 0): raise RuntimeError("File does not appear to be a valid ORT format model: '{}'".format(model_path)) @@ -25,14 +26,14 @@ def __init__(self, model_path: str, required_ops: dict, processors: OperatorType @staticmethod def _setup_type_info(graph: fbs.Graph, outer_scope_value_typeinfo={}): - ''' + """ Setup the node args for this level of Graph. We copy the current list which represents the outer scope values, and add the local node args to that to create the valid list of values for the current Graph. :param graph: Graph to create NodeArg list for :param outer_scope_value_typeinfo: TypeInfo for outer scope values. Empty for the top-level graph in a model. :return: Dictionary of NodeArg name to TypeInfo - ''' + """ value_name_to_typeinfo = outer_scope_value_typeinfo.copy() for j in range(0, graph.NodeArgsLength()): n = graph.NodeArgs(j) @@ -49,10 +50,10 @@ def _add_required_op(self, domain: str, opset: int, op_type: str): self._required_ops[domain][opset].add(op_type) def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict): - ''' + """ Process one level of the Graph, descending into any subgraphs when they are found :param outer_scope_value_typeinfo: Outer scope NodeArg dictionary from ancestor graphs - ''' + """ # Merge the TypeInfo for all values in this level of the graph with the outer scope value TypeInfo. value_name_to_typeinfo = OrtFormatModelProcessor._setup_type_info(graph, outer_scope_value_typeinfo) @@ -60,7 +61,7 @@ def _process_graph(self, graph: fbs.Graph, outer_scope_value_typeinfo: dict): node = graph.Nodes(i) optype = node.OpType().decode() - domain = node.Domain().decode() or 'ai.onnx' # empty domain defaults to ai.onnx + domain = node.Domain().decode() or "ai.onnx" # empty domain defaults to ai.onnx self._add_required_op(domain, node.SinceVersion(), optype) diff --git a/tools/python/util/ort_format_model/types.py b/tools/python/util/ort_format_model/types.py index 79a0876d5ea0f..5cac69f4b3319 100644 --- a/tools/python/util/ort_format_model/types.py +++ b/tools/python/util/ort_format_model/types.py @@ -7,29 +7,29 @@ class FbsTypeInfo: "Class to provide conversion between ORT flatbuffers schema values and C++ types" tensordatatype_to_string = { - fbs.TensorDataType.TensorDataType.FLOAT: 'float', - fbs.TensorDataType.TensorDataType.UINT8: 'uint8_t', - fbs.TensorDataType.TensorDataType.INT8: 'int8_t', - fbs.TensorDataType.TensorDataType.UINT16: 'uint16_t', - fbs.TensorDataType.TensorDataType.INT16: 'int16_t', - fbs.TensorDataType.TensorDataType.INT32: 'int32_t', - fbs.TensorDataType.TensorDataType.INT64: 'int64_t', - fbs.TensorDataType.TensorDataType.STRING: 'std::string', - fbs.TensorDataType.TensorDataType.BOOL: 'bool', - fbs.TensorDataType.TensorDataType.FLOAT16: 'MLFloat16', - fbs.TensorDataType.TensorDataType.DOUBLE: 'double', - fbs.TensorDataType.TensorDataType.UINT32: 'uint32_t', - fbs.TensorDataType.TensorDataType.UINT64: 'uint64_t', + fbs.TensorDataType.TensorDataType.FLOAT: "float", + fbs.TensorDataType.TensorDataType.UINT8: "uint8_t", + fbs.TensorDataType.TensorDataType.INT8: "int8_t", + fbs.TensorDataType.TensorDataType.UINT16: "uint16_t", + fbs.TensorDataType.TensorDataType.INT16: "int16_t", + fbs.TensorDataType.TensorDataType.INT32: "int32_t", + fbs.TensorDataType.TensorDataType.INT64: "int64_t", + fbs.TensorDataType.TensorDataType.STRING: "std::string", + fbs.TensorDataType.TensorDataType.BOOL: "bool", + fbs.TensorDataType.TensorDataType.FLOAT16: "MLFloat16", + fbs.TensorDataType.TensorDataType.DOUBLE: "double", + fbs.TensorDataType.TensorDataType.UINT32: "uint32_t", + fbs.TensorDataType.TensorDataType.UINT64: "uint64_t", # fbs.TensorDataType.TensorDataType.COMPLEX64: 'complex64 is not supported', # fbs.TensorDataType.TensorDataType.COMPLEX128: 'complex128 is not supported', - fbs.TensorDataType.TensorDataType.BFLOAT16: 'BFloat16' + fbs.TensorDataType.TensorDataType.BFLOAT16: "BFloat16", } @staticmethod def typeinfo_to_str(type: fbs.TypeInfo): value_type = type.ValueType() value = type.Value() - type_str = 'unknown' + type_str = "unknown" if value_type == fbs.TypeInfoValue.TypeInfoValue.tensor_type: tensor_type_and_shape = fbs.TensorTypeAndShape.TensorTypeAndShape() @@ -44,7 +44,7 @@ def typeinfo_to_str(type: fbs.TypeInfo): key_type_str = FbsTypeInfo.tensordatatype_to_string[key_type] value_type = map_type.ValueType() # TypeInfo value_type_str = FbsTypeInfo.typeinfo_to_str(value_type) - type_str = 'std::map<{},{}>'.format(key_type_str, value_type_str) + type_str = "std::map<{},{}>".format(key_type_str, value_type_str) elif value_type == fbs.TypeInfoValue.TypeInfoValue.sequence_type: sequence_type = fbs.SequenceType.SequenceType() @@ -60,21 +60,21 @@ def typeinfo_to_str(type: fbs.TypeInfo): # due to this). type_str = elem_type_str else: - raise ValueError('Unknown or missing value type of {}'.format(value_type)) + raise ValueError("Unknown or missing value type of {}".format(value_type)) return type_str def get_typeinfo(name: str, value_name_to_typeinfo: dict) -> fbs.TypeInfo: - 'Lookup a name in a dictionary mapping value name to TypeInfo.' + "Lookup a name in a dictionary mapping value name to TypeInfo." if name not in value_name_to_typeinfo: - raise RuntimeError('Missing TypeInfo entry for ' + name) + raise RuntimeError("Missing TypeInfo entry for " + name) return value_name_to_typeinfo[name] # TypeInfo object def value_name_to_typestr(name: str, value_name_to_typeinfo: dict): - 'Lookup TypeInfo for value name and convert to a string representing the C++ type.' + "Lookup TypeInfo for value name and convert to a string representing the C++ type." type = get_typeinfo(name, value_name_to_typeinfo) type_str = FbsTypeInfo.typeinfo_to_str(type) return type_str diff --git a/tools/python/util/ort_format_model/utils.py b/tools/python/util/ort_format_model/utils.py index 2be004dc9cfaf..83ffaac5e2e80 100644 --- a/tools/python/util/ort_format_model/utils.py +++ b/tools/python/util/ort_format_model/utils.py @@ -4,10 +4,10 @@ import pathlib import typing +from ..logger import get_logger from .operator_type_usage_processors import OperatorTypeUsageManager from .ort_model_processor import OrtFormatModelProcessor -from ..logger import get_logger log = get_logger("ort_format_model.utils") @@ -24,20 +24,21 @@ def _extract_ops_and_types_from_ort_models(model_files: typing.Iterable[pathlib. return required_ops, op_type_usage_manager -def create_config_from_models(model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, - enable_type_reduction: bool): - ''' +def create_config_from_models( + model_files: typing.Iterable[pathlib.Path], output_file: pathlib.Path, enable_type_reduction: bool +): + """ Create a configuration file with required operators and optionally required types. :param model_files: Model files to use to generate the configuration file. :param output_file: File to write configuration to. :param enable_type_reduction: Include required type information for individual operators in the configuration. - ''' + """ required_ops, op_type_processors = _extract_ops_and_types_from_ort_models(model_files, enable_type_reduction) output_file.parent.mkdir(parents=True, exist_ok=True) - with open(output_file, 'w') as out: + with open(output_file, "w") as out: out.write("# Generated from model/s:\n") for model_file in sorted(model_files): out.write(f"# - {model_file}\n") @@ -49,11 +50,13 @@ def create_config_from_models(model_files: typing.Iterable[pathlib.Path], output out.write("{};{};".format(domain, opset)) if enable_type_reduction: # type string is empty if op hasn't been seen - entries = ['{}{}'.format(op, op_type_processors.get_config_entry(domain, op) or '') - for op in sorted(ops)] + entries = [ + "{}{}".format(op, op_type_processors.get_config_entry(domain, op) or "") + for op in sorted(ops) + ] else: entries = sorted(ops) - out.write("{}\n".format(','.join(entries))) + out.write("{}\n".format(",".join(entries))) log.info("Created config in %s", output_file) diff --git a/tools/python/util/pytorch_export_helpers.py b/tools/python/util/pytorch_export_helpers.py index b90f2d83572b1..0ab7689f378c3 100644 --- a/tools/python/util/pytorch_export_helpers.py +++ b/tools/python/util/pytorch_export_helpers.py @@ -2,10 +2,10 @@ # Licensed under the MIT License. import inspect -import torch - from collections import abc +import torch + def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs): # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433 # noqa @@ -54,13 +54,15 @@ def _add_input(name, input): if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL: # VAR_POSITIONAL parameter carries all *args parameters from original forward method for args_i in range(input_idx, len(inputs)): - name = f'{input_parameter.name}_{var_positional_idx}' + name = f"{input_parameter.name}_{var_positional_idx}" var_positional_idx += 1 inp = inputs[args_i] num_expanded_non_none_positional_inputs += _add_input(name, inp) - elif input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY or \ - input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD or \ - input_parameter.kind == inspect.Parameter.KEYWORD_ONLY: + elif ( + input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY + or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD + or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY + ): # All positional non-*args and non-**kwargs are processed here name = input_parameter.name inp = None @@ -84,15 +86,19 @@ def _add_input(name, input): def _flatten_module_input(names, args, kwargs): - '''Flatten args and kwargs in a single tuple of tensors.''' + """Flatten args and kwargs in a single tuple of tensors.""" # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110 # noqa - def is_primitive_type(value): return type(value) in {int, bool, float} - def to_tensor(value): return torch.tensor(value) + def is_primitive_type(value): + return type(value) in {int, bool, float} + + def to_tensor(value): + return torch.tensor(value) ret = [to_tensor(arg) if is_primitive_type(arg) else arg for arg in args] - ret += [to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) - else kwargs[name] for name in names if name in kwargs] + ret += [ + to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) else kwargs[name] for name in names if name in kwargs + ] # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise. @@ -103,7 +109,7 @@ def to_tensor(value): return torch.tensor(value) def infer_input_info(module: torch.nn.Module, *inputs, **kwargs): - ''' + """ Infer the input names and order from the arguments used to execute a PyTorch module for usage exporting the model via torch.onnx.export. Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't. @@ -117,7 +123,7 @@ def infer_input_info(module: torch.nn.Module, *inputs, **kwargs): :param kwargs: Keyword argument inputs :return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the `input_names` and `inputs` arguments. - ''' + """ module_parameters = inspect.signature(module.forward).parameters.values() input_names = _parse_inputs_for_onnx_export(module_parameters, inputs, kwargs) inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs) diff --git a/tools/python/util/qdq_helpers/optimize_qdq_model.py b/tools/python/util/qdq_helpers/optimize_qdq_model.py index 7fbc227b4597e..7691f407a8da4 100644 --- a/tools/python/util/qdq_helpers/optimize_qdq_model.py +++ b/tools/python/util/qdq_helpers/optimize_qdq_model.py @@ -3,22 +3,25 @@ # Licensed under the MIT License. import argparse -import onnx import os import pathlib +import onnx + from .qdq_model_utils import fix_dq_nodes_with_multiple_consumers def optimize_qdq_model(): - parser = argparse.ArgumentParser(os.path.basename(__file__), - description=''' + parser = argparse.ArgumentParser( + os.path.basename(__file__), + description=""" Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime. - ''') + """, + ) - parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.') - parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write updated ONNX model to.') + parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.") + parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.") args = parser.parse_args() @@ -30,5 +33,5 @@ def optimize_qdq_model(): onnx.save(model, str(args.output_model.resolve())) -if __name__ == '__main__': +if __name__ == "__main__": optimize_qdq_model() diff --git a/tools/python/util/qdq_helpers/qdq_model_utils.py b/tools/python/util/qdq_helpers/qdq_model_utils.py index 62e95f948b24e..7aac6f892880b 100644 --- a/tools/python/util/qdq_helpers/qdq_model_utils.py +++ b/tools/python/util/qdq_helpers/qdq_model_utils.py @@ -2,16 +2,17 @@ # Licensed under the MIT License. import onnx + from ..onnx_model_utils import get_producer_consumer_maps, iterate_graph_per_graph_func def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs): - updated_graphs = kwargs['updated_graphs'] - node_to_consumers = kwargs['node_to_consumers'] - validate_updates = kwargs['validate_updates'] + updated_graphs = kwargs["updated_graphs"] + node_to_consumers = kwargs["node_to_consumers"] + validate_updates = kwargs["validate_updates"] nodes_to_update = [] - for node in filter(lambda node: node.op_type == 'DequantizeLinear', graph.node): + for node in filter(lambda node: node.op_type == "DequantizeLinear", graph.node): # node providing graph output won't have consumer nodes consumers = node_to_consumers[node] if node in node_to_consumers else [] if len(consumers) > 1: @@ -19,15 +20,16 @@ def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs # TODO: If this does ever occur, as long as it's only consumed in one subgraph we could leave that # value as is (no need to handle recursing into the subgraph) and update the consumers in this # graph only - raise IndexError("DequantizeLinear node output is consumed by a subgraph. " - "This is not currently supported.") + raise IndexError( + "DequantizeLinear node output is consumed by a subgraph. " "This is not currently supported." + ) nodes_to_update.append(node) if validate_updates: if nodes_to_update: # internal error. we somehow missed an update in the first pass when validate_upates was false - raise ValueError('Graph still has DequantizeLinear nodes with multiple consumers.') + raise ValueError("Graph still has DequantizeLinear nodes with multiple consumers.") return @@ -51,11 +53,11 @@ def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs duplicate = onnx.NodeProto() duplicate.CopyFrom(node) # update node name for debugging. use the global dup idx for node duplication - duplicate.name += f'/qdq_utils_dup_{dup_idx}' + duplicate.name += f"/qdq_utils_dup_{dup_idx}" # update output. use the local idx for value duplication orig_output = node.output[0] - new_output = f'{orig_output}/qdq_utils_dup_{idx}' + new_output = f"{orig_output}/qdq_utils_dup_{idx}" duplicate.output[0] = new_output # update input on the consumer node. @@ -73,25 +75,33 @@ def _duplicate_dq_nodes_with_multiple_consumers(graph: onnx.GraphProto, **kwargs def fix_dq_nodes_with_multiple_consumers(model): - ''' + """ Update a model if any DequantizeLinear nodes have multiple consumers. The QDQ node unit processing is overly complicated if this is the case, as the DQ node would be in multiple units, and the units may end up in different partitions at runtime. :param model: QDQ model to update - ''' + """ node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) updated_graphs = [] # list of GraphProto instances that were updated_graphs - iterate_graph_per_graph_func(model.graph, _duplicate_dq_nodes_with_multiple_consumers, - node_to_consumers=node_to_consumers, validate_updates=False, - updated_graphs=updated_graphs) + iterate_graph_per_graph_func( + model.graph, + _duplicate_dq_nodes_with_multiple_consumers, + node_to_consumers=node_to_consumers, + validate_updates=False, + updated_graphs=updated_graphs, + ) if updated_graphs: updated_graphs = [] node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) - iterate_graph_per_graph_func(model.graph, _duplicate_dq_nodes_with_multiple_consumers, - node_to_consumers=node_to_consumers, validate_updates=True, - updated_graphs=updated_graphs) + iterate_graph_per_graph_func( + model.graph, + _duplicate_dq_nodes_with_multiple_consumers, + node_to_consumers=node_to_consumers, + validate_updates=True, + updated_graphs=updated_graphs, + ) # validate with check and by running shape inference. onnx.checker.check_model(model) diff --git a/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py b/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py index 26e2c74707537..a360ffe91568a 100644 --- a/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py +++ b/tools/python/util/qdq_helpers/test/test_qdq_model_utils.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import onnx import pathlib import unittest +import onnx + from ..qdq_model_utils import fix_dq_nodes_with_multiple_consumers script_dir = pathlib.Path(__file__).parent @@ -17,14 +18,13 @@ class TestQDQUtils(unittest.TestCase): def test_fix_DQ_with_multiple_consumers(self): - ''' - ''' - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'qdq_with_multi_consumer_dq_nodes.onnx' + """ """ + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "qdq_with_multi_consumer_dq_nodes.onnx" model = onnx.load(str(model_path)) - orig_dq_nodes = [n for n in model.graph.node if n.op_type == 'DequantizeLinear'] + orig_dq_nodes = [n for n in model.graph.node if n.op_type == "DequantizeLinear"] fix_dq_nodes_with_multiple_consumers(model) - new_dq_nodes = [n for n in model.graph.node if n.op_type == 'DequantizeLinear'] + new_dq_nodes = [n for n in model.graph.node if n.op_type == "DequantizeLinear"] # there are 3 DQ nodes with 2 consumers (an earlier Conv and later Add) # additionally the last one also provides a graph output diff --git a/tools/python/util/reduced_build_config_parser.py b/tools/python/util/reduced_build_config_parser.py index a55f81489e555..3a2210ccda6e6 100644 --- a/tools/python/util/reduced_build_config_parser.py +++ b/tools/python/util/reduced_build_config_parser.py @@ -6,6 +6,7 @@ # Check if the flatbuffers module is available. If not we cannot handle type reduction information in the config. try: import flatbuffers # noqa + have_flatbuffers = True from .ort_format_model import GloballyAllowedTypesOpTypeImplFilter, OperatorTypeUsageManager # noqa except ImportError: @@ -13,7 +14,7 @@ def parse_config(config_file: str, enable_type_reduction: bool = False): - ''' + """ Parse the configuration file and return the required operators dictionary and an OpTypeImplFilterInterface instance. @@ -78,10 +79,10 @@ def parse_config(config_file: str, enable_type_reduction: bool = False): required. op_type_impl_filter: OpTypeImplFilterInterface instance if type reduction is enabled, the flatbuffers module is available, and type reduction information is present. None otherwise. - ''' + """ if not os.path.isfile(config_file): - raise ValueError('Configuration file {} does not exist'.format(config_file)) + raise ValueError("Configuration file {} does not exist".format(config_file)) # only enable type reduction when flatbuffers is available enable_type_reduction = enable_type_reduction and have_flatbuffers @@ -101,7 +102,7 @@ def process_non_op_line(line): nonlocal globally_allowed_types if globally_allowed_types is not None: raise RuntimeError("Globally allowed types were already specified.") - globally_allowed_types = set(segment.strip() for segment in line.split(';')[1].split(',')) + globally_allowed_types = set(segment.strip() for segment in line.split(";")[1].split(",")) return True if line == "!no_ops_specified_means_all_ops_are_required": # handle all ops required line @@ -111,17 +112,17 @@ def process_non_op_line(line): return False - with open(config_file, 'r') as config: + with open(config_file, "r") as config: for line in [orig_line.strip() for orig_line in config.readlines()]: if process_non_op_line(line): continue - domain, opset_str, operators_str = [segment.strip() for segment in line.split(';')] - opsets = [int(s) for s in opset_str.split(',')] + domain, opset_str, operators_str = [segment.strip() for segment in line.split(";")] + opsets = [int(s) for s in opset_str.split(",")] # any type reduction information is serialized json that starts/ends with { and }. # type info is optional for each operator. - if '{' in operators_str: + if "{" in operators_str: has_op_type_reduction_info = True # parse the entries in the json dictionary with type info @@ -129,8 +130,8 @@ def process_non_op_line(line): cur = 0 end = len(operators_str) while cur < end: - next_comma = operators_str.find(',', cur) - next_open_brace = operators_str.find('{', cur) + next_comma = operators_str.find(",", cur) + next_open_brace = operators_str.find("{", cur) if next_comma == -1: next_comma = end @@ -150,14 +151,14 @@ def process_non_op_line(line): i = next_open_brace + 1 num_open_braces = 1 while num_open_braces > 0 and i < end: - if operators_str[i] == '{': + if operators_str[i] == "{": num_open_braces += 1 - elif operators_str[i] == '}': + elif operators_str[i] == "}": num_open_braces -= 1 i += 1 if num_open_braces != 0: - raise RuntimeError('Mismatched { and } in type string: ' + operators_str[next_open_brace:]) + raise RuntimeError("Mismatched { and } in type string: " + operators_str[next_open_brace:]) if op_type_usage_manager: type_str = operators_str[next_open_brace:i] @@ -171,7 +172,7 @@ def process_non_op_line(line): cur = end_str + 1 else: - operators = set([op.strip() for op in operators_str.split(',')]) + operators = set([op.strip() for op in operators_str.split(",")]) for opset in opsets: if domain not in required_ops: @@ -190,7 +191,8 @@ def process_non_op_line(line): op_type_usage_manager = None if globally_allowed_types is not None and op_type_usage_manager is not None: raise RuntimeError( - "Specifying globally allowed types and per-op type reduction info together is unsupported.") + "Specifying globally allowed types and per-op type reduction info together is unsupported." + ) if globally_allowed_types is not None: op_type_impl_filter = GloballyAllowedTypesOpTypeImplFilter(globally_allowed_types) diff --git a/tools/python/util/run.py b/tools/python/util/run.py index 8d6d4b44720ee..a8e45e9694359 100644 --- a/tools/python/util/run.py +++ b/tools/python/util/run.py @@ -6,13 +6,20 @@ import shlex import subprocess - _log = logging.getLogger("util.run") -def run(*args, cwd=None, - input=None, capture_stdout=False, capture_stderr=False, - shell=False, env=None, check=True, quiet=False): +def run( + *args, + cwd=None, + input=None, + capture_stdout=False, + capture_stderr=False, + shell=False, + env=None, + check=True, + quiet=False +): """Runs a subprocess. Args: @@ -32,19 +39,24 @@ def run(*args, cwd=None, """ cmd = [*args] - _log.info("Running subprocess in '{0}'\n {1}".format( - cwd or os.getcwd(), " ".join([shlex.quote(arg) for arg in cmd]))) + _log.info( + "Running subprocess in '{0}'\n {1}".format(cwd or os.getcwd(), " ".join([shlex.quote(arg) for arg in cmd])) + ) def output(is_stream_captured): - return subprocess.PIPE if is_stream_captured else \ - (subprocess.DEVNULL if quiet else None) + return subprocess.PIPE if is_stream_captured else (subprocess.DEVNULL if quiet else None) completed_process = subprocess.run( - cmd, cwd=cwd, check=check, input=input, - stdout=output(capture_stdout), stderr=output(capture_stderr), - env=env, shell=shell) - - _log.debug("Subprocess completed. Return code: {}".format( - completed_process.returncode)) + cmd, + cwd=cwd, + check=check, + input=input, + stdout=output(capture_stdout), + stderr=output(capture_stderr), + env=env, + shell=shell, + ) + + _log.debug("Subprocess completed. Return code: {}".format(completed_process.returncode)) return completed_process diff --git a/tools/python/util/test/test_onnx_model_utils.py b/tools/python/util/test/test_onnx_model_utils.py index 9fe6bfc8456f4..9c79aabaebe8a 100644 --- a/tools/python/util/test/test_onnx_model_utils.py +++ b/tools/python/util/test/test_onnx_model_utils.py @@ -1,17 +1,20 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import onnx import pathlib import unittest -from onnx import helper -from onnx import shape_inference -from onnx import TensorProto +import onnx +from onnx import TensorProto, helper, shape_inference from ..mobile_helpers.usability_checker import check_shapes -from ..onnx_model_utils import get_producer_consumer_maps, \ - make_dim_param_fixed, make_input_shape_fixed, fix_output_shapes, is_fixed_size_tensor +from ..onnx_model_utils import ( + fix_output_shapes, + get_producer_consumer_maps, + is_fixed_size_tensor, + make_dim_param_fixed, + make_input_shape_fixed, +) script_dir = pathlib.Path(__file__).parent ort_root = script_dir.parents[3] @@ -30,30 +33,26 @@ def _create_model(): # shadow a1 in main graph. # LoopAdd_SubgraphOutput should be linked to this and a1 should not be an implicit input helper.make_node("Add", ["loop_state_in", "iter"], ["a1"], "LoopAdd_Shadows"), - # main_graph_initializer should be handled (implicit input but no producer node) # graph input 'x' from main graph should also be handled helper.make_node("Add", ["main_graph_initializer", "x"], ["a2"], "LoopAdd_OuterScopeInitializer"), - # implicit input should be handled - 'z' can be accessed from outside scope # Add2 in main graph should be implicit input of the Loop node helper.make_node("Add", ["z", "a1"], ["a3"], "LoopAdd_ImplicitInput"), - # create subgraph output helper.make_node("Add", ["a2", "a3"], ["loop_state_out"], "LoopAdd_SubgraphOutput"), ], "Loop_body", [ - helper.make_tensor_value_info('iter', TensorProto.INT64, [1]), - helper.make_tensor_value_info('cond', TensorProto.BOOL, [1]), - helper.make_tensor_value_info('loop_state_in', TensorProto.FLOAT, [1]) + helper.make_tensor_value_info("iter", TensorProto.INT64, [1]), + helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_in", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor_value_info('cond', TensorProto.BOOL, [1]), - helper.make_tensor_value_info('loop_state_out', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("cond", TensorProto.BOOL, [1]), + helper.make_tensor_value_info("loop_state_out", TensorProto.FLOAT, [1]), ], - [ - ] + [], ) # Create the main graph @@ -65,31 +64,32 @@ def _create_model(): helper.make_node("Add", ["a1", "main_graph_initializer"], ["z"], "Add2"), # rename 'z' to use as explicit input to Loop helper.make_node("Identity", ["z"], ["state_var_in"], "RenameZ"), - helper.make_node("Loop", ["max_trip_count", "keep_going", "state_var_in"], ["state_var_out"], "Loop1", - body=body), - helper.make_node("Sub", ["a1", "state_var_out"], ["graph_output"], "sub_1") + helper.make_node( + "Loop", ["max_trip_count", "keep_going", "state_var_in"], ["state_var_out"], "Loop1", body=body + ), + helper.make_node("Sub", ["a1", "state_var_out"], ["graph_output"], "sub_1"), ], "Main_graph", [ - helper.make_tensor_value_info('x', TensorProto.FLOAT, [1]), - helper.make_tensor_value_info('y', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("x", TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("y", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor_value_info('graph_output', TensorProto.FLOAT, [1]), + helper.make_tensor_value_info("graph_output", TensorProto.FLOAT, [1]), ], [ - helper.make_tensor('max_trip_count', TensorProto.INT64, [1], [2]), - helper.make_tensor('main_graph_initializer', TensorProto.FLOAT, [1], [1.]), - helper.make_tensor('keep_going', TensorProto.BOOL, [1], [True]), - ] + helper.make_tensor("max_trip_count", TensorProto.INT64, [1], [2]), + helper.make_tensor("main_graph_initializer", TensorProto.FLOAT, [1], [1.0]), + helper.make_tensor("keep_going", TensorProto.BOOL, [1], [True]), + ], ) return helper.make_model(graph_proto) def test_model_with_subgraph(self): - ''' + """ Test a manually created model that has a subgraph and implicit inputs of all possible types. - ''' + """ model = self._create_model() node_to_producers, node_to_consumers = get_producer_consumer_maps(model.graph) @@ -107,19 +107,26 @@ def test_model_with_subgraph(self): loop_add_subgraph_output = subgraph.node[3] def node_name(node): - return f'{node.name}:{node.op_type}' + return f"{node.name}:{node.op_type}" def check_linked(producer, consumer): - self.assertTrue(producer in node_to_producers[consumer], - f'{node_name(producer)} not in producers for {node_name(consumer)}') - self.assertTrue(consumer in node_to_consumers[producer], - f'{node_name(consumer)} not in consumers for {node_name(producer)}') + self.assertTrue( + producer in node_to_producers[consumer], + f"{node_name(producer)} not in producers for {node_name(consumer)}", + ) + self.assertTrue( + consumer in node_to_consumers[producer], + f"{node_name(consumer)} not in consumers for {node_name(producer)}", + ) def check_not_linked(producer, consumer): - self.assertFalse(producer in node_to_producers[consumer], - f'{node_name(producer)} in producers for {node_name(consumer)}') - self.assertFalse(consumer in node_to_consumers[producer], - f'{node_name(consumer)} not in consumers for {node_name(producer)}') + self.assertFalse( + producer in node_to_producers[consumer], f"{node_name(producer)} in producers for {node_name(consumer)}" + ) + self.assertFalse( + consumer in node_to_consumers[producer], + f"{node_name(consumer)} not in consumers for {node_name(producer)}", + ) check_linked(main_graph_add_create_a1, main_graph_add_create_z) # a1 in main graph shouldn't be implicit input to loop as it is shadowed @@ -137,12 +144,13 @@ def check_not_linked(producer, consumer): class TestDynamicDimReplacement(unittest.TestCase): def test_replace_symbolic_dim(self): - ''' + """ Update a model with a single symbolic input dimension. After replacement run shape inferencing to verify that all shapes in the model have fixed sizes. - ''' - model_path = \ - ort_root / 'onnxruntime' / 'test' / 'testdata' / 'CNTK' / 'test_LSTM.tanh.bidirectional' / 'model.onnx' + """ + model_path = ( + ort_root / "onnxruntime" / "test" / "testdata" / "CNTK" / "test_LSTM.tanh.bidirectional" / "model.onnx" + ) model = onnx.load_model(str(model_path)) @@ -150,11 +158,11 @@ def test_replace_symbolic_dim(self): m2 = shape_inference.infer_shapes(model, True) dynamic_inputs, num_dynamic_values = check_shapes(m2.graph) self.assertEqual(len(dynamic_inputs), 1) - self.assertEqual(dynamic_inputs[0].name, 'Input3') + self.assertEqual(dynamic_inputs[0].name, "Input3") self.assertGreater(num_dynamic_values, 0) # update original model - make_dim_param_fixed(model.graph, 'None', 4) + make_dim_param_fixed(model.graph, "None", 4) # and validate the model no longer has dynamic values model = shape_inference.infer_shapes(model, True) @@ -163,11 +171,11 @@ def test_replace_symbolic_dim(self): self.assertEqual(num_dynamic_values, 0) def test_replace_input_shape(self): - ''' + """ Replace the entire shape for an input. This can be used when the model has inputs with unknown dimensions i.e. the dimension has no value and no symbolic name so it's harder to replace. - ''' - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'gh_issue_9671.onnx' + """ + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "gh_issue_9671.onnx" model = onnx.load_model(str(model_path)) @@ -175,15 +183,15 @@ def test_replace_input_shape(self): m2 = shape_inference.infer_shapes(model, True) dynamic_inputs, num_dynamic_values = check_shapes(m2.graph) self.assertEqual(len(dynamic_inputs), 3) - self.assertEqual(dynamic_inputs[0].name, 'X1') - self.assertEqual(dynamic_inputs[1].name, 'X2') - self.assertEqual(dynamic_inputs[2].name, 'X3') + self.assertEqual(dynamic_inputs[0].name, "X1") + self.assertEqual(dynamic_inputs[1].name, "X2") + self.assertEqual(dynamic_inputs[2].name, "X3") self.assertGreater(num_dynamic_values, 0) # update original model - make_input_shape_fixed(model.graph, 'X1', [2, 2, 4]) - make_input_shape_fixed(model.graph, 'X2', [2, 4]) - make_input_shape_fixed(model.graph, 'X3', [2, 2, 4]) + make_input_shape_fixed(model.graph, "X1", [2, 2, 4]) + make_input_shape_fixed(model.graph, "X2", [2, 4]) + make_input_shape_fixed(model.graph, "X3", [2, 2, 4]) # and validate the model no longer has dynamic values model = shape_inference.infer_shapes(model, True) @@ -194,18 +202,18 @@ def test_replace_input_shape_with_dim_params(self): # replace the input shape where the existing shape also has dim_param entries. # in this case we should also iterate the rest of the model and replace other instances # of the dim_param with the new value. - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'fuse_mul_1.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "fuse_mul_1.onnx" model = onnx.load_model(str(model_path)) m2 = shape_inference.infer_shapes(model, True) dynamic_inputs, num_dynamic_values = check_shapes(m2.graph) self.assertEqual(len(dynamic_inputs), 1) - self.assertEqual(dynamic_inputs[0].name, 'X1') + self.assertEqual(dynamic_inputs[0].name, "X1") # input as well as other values in model have shape ['D'] so check > 1 self.assertGreater(num_dynamic_values, 1) # replace X1's shape of ['D'] -> [4] - make_input_shape_fixed(model.graph, 'X1', [4]) + make_input_shape_fixed(model.graph, "X1", [4]) # validate the model no longer has dynamic values # we don't run shape_inference here as 'D' is the only dimension in the whole model, and we should have @@ -216,14 +224,14 @@ def test_replace_input_shape_with_dim_params(self): self.assertEqual(num_dynamic_values, 0) def test_fix_output_shape(self): - ''' + """ Replace an input shape in a model where that won't update the output shape automatically. Manually fix the output so the usage of the model is clearer. - ''' - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'transform' / 'fusion' / 'bias_gelu_fusion.onnx' + """ + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "transform" / "fusion" / "bias_gelu_fusion.onnx" model = onnx.load_model(str(model_path)) - make_input_shape_fixed(model.graph, 'A', [2, 2, 3072]) + make_input_shape_fixed(model.graph, "A", [2, 2, 3072]) # symbolic dim names in graph inputs don't match graph outputs so they won't have been updated yet self.assertFalse(is_fixed_size_tensor(model.graph.output[0])) @@ -231,14 +239,27 @@ def test_fix_output_shape(self): self.assertTrue(is_fixed_size_tensor(model.graph.output[0])) def test_invalid_replace_input_shape(self): - model_path = ort_root / 'onnxruntime' / 'test' / 'testdata' / 'sklearn_bin_voting_classifier_soft.onnx' + model_path = ort_root / "onnxruntime" / "test" / "testdata" / "sklearn_bin_voting_classifier_soft.onnx" model = onnx.load_model(str(model_path)) # test some invalid usages - self.assertRaisesRegex(ValueError, "Rank mismatch. Existing:2 Replacement:3", - make_input_shape_fixed, model.graph, 'input', [1, 2, 3]) + self.assertRaisesRegex( + ValueError, + "Rank mismatch. Existing:2 Replacement:3", + make_input_shape_fixed, + model.graph, + "input", + [1, 2, 3], + ) - self.assertRaisesRegex(ValueError, "Can't replace existing fixed size of 2 with 3 for dimension 2", - make_input_shape_fixed, model.graph, 'input', [4, 3]) + self.assertRaisesRegex( + ValueError, + "Can't replace existing fixed size of 2 with 3 for dimension 2", + make_input_shape_fixed, + model.graph, + "input", + [4, 3], + ) - self.assertRaisesRegex(ValueError, "Input X1 was not found in graph inputs.", - make_input_shape_fixed, model.graph, 'X1', [2, 3]) + self.assertRaisesRegex( + ValueError, "Input X1 was not found in graph inputs.", make_input_shape_fixed, model.graph, "X1", [2, 3] + ) diff --git a/tools/python/util/test/test_pytorch_export_helpers.py b/tools/python/util/test/test_pytorch_export_helpers.py index 4e1bf3bcd9f76..e1201f9a63634 100644 --- a/tools/python/util/test/test_pytorch_export_helpers.py +++ b/tools/python/util/test/test_pytorch_export_helpers.py @@ -1,9 +1,10 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. -import torch import unittest +import torch + from ..pytorch_export_helpers import infer_input_info # example usage from /tools/python @@ -32,10 +33,10 @@ def setUpClass(cls): def test_positional(self): # test we can infer the input names from the forward method when positional args are used input_names, inputs_as_tuple = infer_input_info(self._model, self._input, 0, 1) - self.assertEqual(input_names, ['x', 'min', 'max']) + self.assertEqual(input_names, ["x", "min", "max"]) def test_keywords(self): # test that we sort keyword args and the inputs to match the module input_names, inputs_as_tuple = infer_input_info(self._model, self._input, max=1, min=0) - self.assertEqual(input_names, ['x', 'min', 'max']) + self.assertEqual(input_names, ["x", "min", "max"]) self.assertEqual(inputs_as_tuple, (self._input, 0, 1)) diff --git a/tools/python/util/update_onnx_opset.py b/tools/python/util/update_onnx_opset.py index 712f45304550b..0f51c12bb9769 100644 --- a/tools/python/util/update_onnx_opset.py +++ b/tools/python/util/update_onnx_opset.py @@ -10,20 +10,22 @@ def update_onnx_opset_helper(): - parser = argparse.ArgumentParser(f'{os.path.basename(__file__)}:{update_onnx_opset_helper.__name__}', - description=''' + parser = argparse.ArgumentParser( + f"{os.path.basename(__file__)}:{update_onnx_opset_helper.__name__}", + description=""" Update the ONNX opset of the model. New opset must be later than the existing one. If not specified will update to opset 15. - ''') + """, + ) - parser.add_argument('--opset', type=int, required=False, default=15, help="ONNX opset to update to.") - parser.add_argument('input_model', type=pathlib.Path, help='Provide path to ONNX model to update.') - parser.add_argument('output_model', type=pathlib.Path, help='Provide path to write updated ONNX model to.') + parser.add_argument("--opset", type=int, required=False, default=15, help="ONNX opset to update to.") + parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.") + parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.") args = parser.parse_args() update_onnx_opset(args.input_model, args.opset, args.output_model) -if __name__ == '__main__': +if __name__ == "__main__": update_onnx_opset_helper()