Commit: reformat code

inisis committed Jun 5, 2024
1 parent cd7f6ae commit b2f226a
Showing 10 changed files with 223 additions and 220 deletions.
6 changes: 3 additions & 3 deletions onnxslim/__init__.py
@@ -1,9 +1,9 @@
 import os
 import warnings
 
-from .cli import slim
-from .core.optimizer import DEFAULT_FUSION_PATTERNS
-from .version import __version__
+from onnxslim.cli import slim
+from onnxslim.core.optimizer import DEFAULT_FUSION_PATTERNS
+from onnxslim.version import __version__
 
 if os.path.dirname(os.path.realpath(__file__)) == os.path.join(os.path.realpath(os.getcwd()), "onnxslim"):
     message = (
2 changes: 1 addition & 1 deletion onnxslim/__main__.py
@@ -1,4 +1,4 @@
-from .cli._main import main
+from onnxslim.cli._main import main
 
 if __name__ == "__main__":
     main()
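
Since `__main__.py` only forwards to `main()`, the package can be executed as a module. A minimal sketch of driving it programmatically (the argument list is hypothetical; the CLI's actual flags are not shown in this diff):

```python
import sys

from onnxslim.cli._main import main

# Hypothetical argv: an input model and an output path; main() is assumed
# to parse sys.argv, as argparse-based entry points typically do.
sys.argv = ["onnxslim", "model.onnx", "model_slim.onnx"]
main()
```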
2 changes: 1 addition & 1 deletion onnxslim/cli/__init__.py
@@ -1 +1 @@
-from ._main import main, slim
+from onnxslim.cli._main import main, slim
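
For orientation, the `slim` function re-exported here (and from `onnxslim/__init__.py` above) is the package's public Python entry point. A minimal usage sketch; the positional output-path argument is an assumption based on common usage, not confirmed by this diff:

```python
# Hypothetical use of the re-exported public API.
from onnxslim import slim

# Assumed call form: input model path, then output path.
slim("model.onnx", "model_slim.onnx")
```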
22 changes: 11 additions & 11 deletions onnxslim/cli/_main.py
@@ -2,7 +2,7 @@
 
 import onnx
 
-from onnxslim.utils.utils import logger
+from onnxslim.utils import logger
 
 
 def slim(
@@ -61,25 +61,25 @@ def slim(
     from pathlib import Path
 
     from onnxslim.core.slim import (
-        check_onnx,
-        check_point,
-        check_result,
         convert_data_format,
-        freeze,
-        init_logging,
         input_shape_modification,
-        model_save_as_external_data,
-        optimize,
         output_modification,
-        save,
         shape_infer,
-        summarize_model,
+        optimize,
+        freeze,
     )
 
-    from ..utils.utils import (
+    from onnxslim.utils import (
         dump_model_info_to_disk,
         onnxruntime_inference,
         print_model_info_as_table,
+        model_save_as_external_data,
+        summarize_model,
+        init_logging,
+        check_result,
+        check_onnx,
+        check_point,
+        save,
     )
 
     init_logging(verbose)
2 changes: 1 addition & 1 deletion onnxslim/core/optimizer.py
@@ -9,7 +9,7 @@
 from onnxslim.onnx_graphsurgeon.exporters.onnx_exporter import dtype_to_onnx
 from onnxslim.onnx_graphsurgeon.ir.graph import Graph
 from onnxslim.onnx_graphsurgeon.ir.tensor import Constant, Variable
-from onnxslim.utils.utils import logger
+from onnxslim.utils import logger
 
 DEFAULT_FUSION_PATTERNS = OrderedDict()
 
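
The `DEFAULT_FUSION_PATTERNS` OrderedDict visible in this hunk serves as a registry mapping pattern names to fusion functions. A generic sketch of that registry pattern follows; the decorator name and the example pattern are illustrative assumptions, not taken from this diff:

```python
from collections import OrderedDict

DEFAULT_FUSION_PATTERNS = OrderedDict()


def register_fusion_pattern(name):
    """Hypothetical decorator that records a fusion function under `name`."""
    def decorator(fn):
        DEFAULT_FUSION_PATTERNS[name] = fn
        return fn
    return decorator


@register_fusion_pattern("GemmRelu")
def fuse_gemm_relu(graph):
    # A real implementation would match Gemm -> Relu chains in `graph`
    # and rewrite them as a single fused node.
    pass
```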
204 changes: 3 additions & 201 deletions onnxslim/core/slim.py
@@ -10,140 +10,21 @@
 
 import onnxslim.onnx_graphsurgeon as gs
 from onnxslim.onnx_graphsurgeon.ir.tensor import Constant
-from onnxslim.onnx_graphsurgeon.logger.logger import G_LOGGER
 
-from ..utils.utils import (
+from onnxslim.utils import (
     dump_model_info_to_disk,
     gen_onnxruntime_input_data,
     logger,
     onnxruntime_inference,
     print_model_info_as_table,
 )
-from .optimizer import delete_node, optimize_model
-from .symbolic_shape_infer import SymbolicShapeInference
+from onnxslim.core.optimizer import delete_node, optimize_model
+from onnxslim.core.symbolic_shape_infer import SymbolicShapeInference
 
 DEBUG = bool(os.getenv("ONNXSLIM_DEBUG"))
 AUTO_MERGE = True if os.getenv("ONNXSLIM_AUTO_MERGE") is None else bool(int(os.getenv("ONNXSLIM_AUTO_MERGE")))
 
 
-def init_logging(verbose=False):
-    """Configure the logging settings for the application based on the verbosity level."""
-    for handler in logging.root.handlers[:]:
-        logging.root.removeHandler(handler)
-
-    if verbose:  # DEBUG
-        logging.basicConfig(
-            level=logging.DEBUG,
-            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-            handlers=[logging.StreamHandler(sys.stderr)],
-        )
-        G_LOGGER.severity = logging.DEBUG
-    else:  # ERROR
-        logging.basicConfig(
-            level=logging.ERROR,
-            format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
-            handlers=[logging.StreamHandler(sys.stderr)],
-        )
-        G_LOGGER.severity = logging.ERROR
-
-    G_LOGGER.colors = False
-
-    import onnxruntime as ort
-
-    ort.set_default_logger_severity(3)
-
-
-def get_opset(model: onnx.ModelProto) -> int:
-    try:
-        for importer in model.opset_import:
-            if importer.domain == "" or importer.domain == "ai.onnx":
-                return importer.version
-
-        return None
-    except:
-        return None
-
-
-def summarize_model(model: onnx.ModelProto) -> Dict:
-    logger.debug("Start summarizing model.")
-    model_info = {}
-
-    model_size = model.ByteSize()
-    model_info["model_size"] = model_size
-
-    op_info = {}
-    op_type_counts = {}
-
-    def get_tensor_dtype_shape(tensor):
-        """Extract the data type and shape of an ONNX tensor."""
-        type_str = onnx.mapping.TENSOR_TYPE_TO_NP_TYPE.get(tensor.type.tensor_type.elem_type, "Unknown")
-        shape = None
-        if tensor.type.tensor_type.HasField("shape"):
-            shape = []
-            for dim in tensor.type.tensor_type.shape.dim:
-                if dim.HasField("dim_param"):
-                    shape.append(dim.dim_param)
-                elif dim.HasField("dim_value"):
-                    shape.append(dim.dim_value)
-                else:
-                    shape.append(None)
-
-        return (type_str, shape)
-
-    def get_shape(inputs: onnx.ModelProto) -> Dict[str, List[int]]:
-        op_shape_info = {}
-        for input in inputs:
-            type_str, shape = get_tensor_dtype_shape(input)
-            if shape:
-                op_shape_info[input.name] = str(type_str) + ": " + str(tuple(shape))
-            else:
-                op_shape_info[input.name] = str(type_str) + ": None"
-
-        return op_shape_info
-
-    value_info_dict = {value_info.name: value_info for value_info in model.graph.value_info}
-
-    for node in model.graph.node:
-        op_type = node.op_type
-        if op_type in op_type_counts:
-            op_type_counts[op_type] += 1
-        else:
-            op_type_counts[op_type] = 1
-
-        for output in node.output:
-            shapes = []
-            if output in value_info_dict:
-                tensor = value_info_dict[output]
-                type_str, shape = get_tensor_dtype_shape(tensor)
-                shapes.append([type_str, shape])
-
-        op_info[node.name] = [node.op_type, shapes]
-
-    model_info["op_set"] = str(get_opset(model))
-    model_info["op_info"] = op_info
-    model_info["op_type_counts"] = op_type_counts
-
-    model_info["op_input_info"] = get_shape(model.graph.input)
-    model_info["op_output_info"] = get_shape(model.graph.output)
-
-    logger.debug("Finish summarizing model.")
-    return model_info
-
-
-def model_save_as_external_data(model: onnx.ModelProto, model_path: str):
-    """Save an ONNX model with tensor data as an external file."""
-    location = os.path.basename(model_path) + ".data"
-    if os.path.exists(location):
-        os.remove(location)
-    onnx.save(
-        model,
-        model_path,
-        save_as_external_data=True,
-        all_tensors_to_one_file=True,
-        location=location,
-    )
-
-
 def input_shape_modification(model: onnx.ModelProto, input_shapes: str) -> onnx.ModelProto:
     if not input_shapes:
         return
@@ -205,14 +86,6 @@ def output_modification(model: onnx.ModelProto, outputs: str) -> onnx.ModelProto
     return model
 
 
-def check_onnx(model: onnx.ModelProto, model_check_inputs=None):
-    """Validates an ONNX model by generating input data and performing inference to check outputs."""
-    input_data_dict = gen_onnxruntime_input_data(model, model_check_inputs)
-    raw_onnx_output = onnxruntime_inference(model, input_data_dict)
-
-    return input_data_dict, raw_onnx_output
-
-
 def shape_infer(model: onnx.ModelProto):
     """Infer tensor shapes in an ONNX model using symbolic and static shape inference techniques."""
     logger.debug("Start shape inference.")
@@ -253,24 +126,6 @@ def optimize(model: onnx.ModelProto, skip_fusion_patterns: str = None):
     return model
 
 
-def check_point(model: onnx.ModelProto):
-    """Imports an ONNX model checkpoint into a Graphsurgeon graph representation."""
-    graph_check_point = gs.import_onnx(model)
-
-    return graph_check_point
-
-
-def is_converged(model: onnx.ModelProto, graph_ckpt, iter: int) -> bool:
-    logger.debug(f"optimization iter: {iter}")
-    graph = gs.import_onnx(model)
-    if graph == graph_ckpt:
-        print(f"converged at iter: {iter}")
-        return None
-    else:
-        graph_ckpt = graph
-        return False
-
-
 def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
     if dtype == "fp16":
         from onnxconverter_common import float16
@@ -297,59 +152,6 @@ def convert_data_format(model: onnx.ModelProto, dtype: str) -> onnx.ModelProto:
     return model
 
 
-def save(model: onnx.ModelProto, model_path: str, model_check: bool = False):
-    """Save an ONNX model to a specified path, with optional model checking for validity."""
-    if model_check:
-        try:
-            checker.check_model(model)
-        except ValueError:
-            logger.warning("Model too large and cannot be checked.")
-
-    if model_path:
-        if (
-            model.ByteSize() <= checker.MAXIMUM_PROTOBUF
-        ):  # model larger than 2GB can be saved, but compiler like trtexec won't parse it
-            onnx.save(model, model_path)
-        else:
-            import os
-
-            location = os.path.basename(model_path) + ".data"
-            if os.path.exists(location):
-                os.remove(location)
-            onnx.save(
-                model,
-                model_path,
-                save_as_external_data=True,
-                all_tensors_to_one_file=True,
-                location=location,
-            )
-            logger.debug("Model too large and saved as external data automatically.")
-
-
-def check_result(raw_onnx_output, slimmed_onnx_output):
-    """Verify the consistency of outputs between the raw and slimmed ONNX models, logging warnings if discrepancies are
-    detected.
-    """
-    if set(raw_onnx_output.keys()) != set(slimmed_onnx_output.keys()):
-        logger.warning("Model output mismatch after slimming.")
-        logger.warning("Raw model output keys: {}".format(raw_onnx_output.keys()))
-        logger.warning("Slimmed model output keys: {}".format(slimmed_onnx_output.keys()))
-        logger.warning("Please check the model carefully.")
-        return
-    else:
-        for key in raw_onnx_output.keys():
-            if not np.allclose(
-                raw_onnx_output[key],
-                slimmed_onnx_output[key],
-                rtol=1e-03,
-                atol=1e-04,
-                equal_nan=True,
-            ):
-                logger.warning("Model output mismatch after slimming.")
-                logger.warning("Please check the model carefully.")
-                return
-
-
 def freeze(model: onnx.ModelProto):
     """Freeze the input layers of an ONNX model by removing the initializers from the input graph."""
     inputs = model.graph.input
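
The `check_result` helper removed above (relocated to `onnxslim.utils`) compares raw and slimmed model outputs with `np.allclose`. A standalone sketch of that comparison, reusing the tolerances shown in the diff (the example arrays are illustrative, not real model outputs):

```python
import numpy as np

# Illustrative output dicts standing in for onnxruntime inference results.
raw_onnx_output = {"logits": np.array([1.0, 2.0, 3.0])}
slimmed_onnx_output = {"logits": np.array([1.0, 2.0001, 3.0])}

# Tolerances match the diff: rtol=1e-03, atol=1e-04, equal_nan=True.
for key in raw_onnx_output:
    ok = np.allclose(
        raw_onnx_output[key],
        slimmed_onnx_output[key],
        rtol=1e-03,
        atol=1e-04,
        equal_nan=True,
    )
    print(key, "match" if ok else "mismatch")
```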
File renamed without changes.
File renamed without changes.
File renamed without changes.