Commit: Use opset 13
sun-xiangyu committed Oct 22, 2024
1 parent 34055b9 commit 5b1c404
Showing 1 changed file with 79 additions and 73 deletions: ppq/api/espdl_interface.py
@@ -41,7 +41,7 @@ def get_target_platform(target: str, num_of_bits: int = 8, float: bool = False):
         platform = TargetPlatform.ESPDL_S3_INT8
     else:
         platform = TargetPlatform.FP32
-        print("Warning: Do not support num_of_bits:{num_of_bits}, will change to TargetPlatform.FP32")
+        logger.warning(f"Do not support num_of_bits:{num_of_bits}, will change to TargetPlatform.FP32")
 
     return platform

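Note: beyond swapping print for logger.warning, this hunk fixes a silent formatting bug: the old string lacked the f prefix, so {num_of_bits} was printed literally rather than interpolated. A minimal sketch of the fallback behavior (the target string and import path are assumptions, not taken from this diff):

    from ppq.core import TargetPlatform

    # Hypothetical calls; "esp32s3" is assumed to select the ESPDL_S3_* platforms.
    platform = get_target_platform("esp32s3", num_of_bits=8)
    assert platform == TargetPlatform.ESPDL_S3_INT8

    platform = get_target_platform("esp32s3", num_of_bits=3)
    # logs "Do not support num_of_bits:3, will change to TargetPlatform.FP32"
    assert platform == TargetPlatform.FP32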
@@ -171,88 +171,94 @@ def espdl_quantize_onnx(
     collate_fn = partial(collate_fn_template, dtype=input_dtype, device=device)
 
     ppq_graph = load_onnx_graph(onnx_import_file=onnx_import_file)
-    quantizer = PFL.Quantizer(platform=target_platform, graph=ppq_graph)
-    dispatching_table = PFL.Dispatcher(
-        graph=ppq_graph, method=dispatching_method
-    ).dispatch(quantizer.quant_operation_types)
-
-    # Override dispatching result
-    if dispatching_override is not None:
-        for opname, platform in dispatching_override.items():
-            if opname not in ppq_graph.operations:
-                continue
-            assert isinstance(platform, int) or isinstance(platform, TargetPlatform), (
-                f"Your dispatching_override table contains a invalid setting of operation {opname}, "
-                "All platform setting given in dispatching_override table is expected given as int or TargetPlatform, "
-                f"however {type(platform)} was given."
-            )
-            dispatching_table[opname] = TargetPlatform(platform)
-
-    for opname, platform in dispatching_table.items():
-        if platform == TargetPlatform.UNSPECIFIED:
-            dispatching_table[opname] = target_platform
 
     if inputs:
         dummy_inputs = inputs
     else:
         dummy_inputs = get_random_inputs(input_shape, input_dtype, device)
 
-    # initial quantizer
-    for op in ppq_graph.operations.values():
-        quantizer.quantize_operation(
-            op_name=op.name, platform=dispatching_table[op.name]
-        )
-    executor = TorchExecutor(graph=ppq_graph, device=device)
-    executor.tracing_operation_meta(inputs=dummy_inputs)
-
-    # Create the optimization pipeline,
-    pipeline = PFL.Pipeline(
-        [
-            QuantizeSimplifyPass(),
-            QuantizeFusionPass(activation_type=quantizer.activation_fusion_types),
-            ParameterQuantizePass(),
-            RuntimeCalibrationPass(method="kl"),
-            PassiveParameterQuantizePass(
-                clip_visiblity=QuantizationVisibility.EXPORT_WHEN_ACTIVE
-            ),
-            QuantAlignmentPass(elementwise_alignment="Align to Output"),
-            # LearnedStepSizePass(steps=500, block_size=5)
-        ]
-    )
-    logger.info(
-        f"Calibration dataset samples: {len(calib_dataloader.dataset)}, len(Calibrate iter): {len(calib_dataloader)}"
-    )
-    pipeline.optimize(
-        calib_steps=calib_steps,
-        collate_fn=collate_fn,
-        graph=ppq_graph,
-        dataloader=calib_dataloader,
-        executor=executor,
-    )
-    if verbose:
-        logger.info(quantizer.report())
-        logger.info("Network Quantization Finished.")
-
-
-    # ------------------------------------------------------------
-    #
-    # 2: Analyze Quantization Errors.
-    #
-    # ------------------------------------------------------------
-    if error_report:
-        graphwise_error_analyse(
-            graph=ppq_graph,
-            running_device=device,
-            collate_fn=collate_fn,
-            dataloader=calib_dataloader,
-        )
-
-        layerwise_error_analyse(
-            graph=ppq_graph,
-            running_device=device,
-            collate_fn=collate_fn,
-            dataloader=calib_dataloader,
-        )
+    if target_platform != TargetPlatform.FP32:
+        quantizer = PFL.Quantizer(platform=target_platform, graph=ppq_graph)
+        dispatching_table = PFL.Dispatcher(
+            graph=ppq_graph, method=dispatching_method
+        ).dispatch(quantizer.quant_operation_types)
+
+        # Override dispatching result
+        if dispatching_override is not None:
+            for opname, platform in dispatching_override.items():
+                if opname not in ppq_graph.operations:
+                    continue
+                assert isinstance(platform, int) or isinstance(platform, TargetPlatform), (
+                    f"Your dispatching_override table contains a invalid setting of operation {opname}, "
+                    "All platform setting given in dispatching_override table is expected given as int or TargetPlatform, "
+                    f"however {type(platform)} was given."
+                )
+                dispatching_table[opname] = TargetPlatform(platform)
+
+        for opname, platform in dispatching_table.items():
+            if platform == TargetPlatform.UNSPECIFIED:
+                dispatching_table[opname] = target_platform
+
+        # initial quantizer
+        for op in ppq_graph.operations.values():
+            quantizer.quantize_operation(
+                op_name=op.name, platform=dispatching_table[op.name]
+            )
+        executor = TorchExecutor(graph=ppq_graph, device=device)
+        executor.tracing_operation_meta(inputs=dummy_inputs)
+
+        # Create the optimization pipeline,
+        pipeline = PFL.Pipeline(
+            [
+                QuantizeSimplifyPass(),
+                QuantizeFusionPass(activation_type=quantizer.activation_fusion_types),
+                ParameterQuantizePass(),
+                RuntimeCalibrationPass(method="kl"),
+                PassiveParameterQuantizePass(
+                    clip_visiblity=QuantizationVisibility.EXPORT_WHEN_ACTIVE
+                ),
+                QuantAlignmentPass(elementwise_alignment="Align to Output"),
+                # LearnedStepSizePass(steps=500, block_size=5)
+            ]
+        )
+        logger.info(
+            f"Calibration dataset samples: {len(calib_dataloader.dataset)}, len(Calibrate iter): {len(calib_dataloader)}"
+        )
+        pipeline.optimize(
+            calib_steps=calib_steps,
+            collate_fn=collate_fn,
+            graph=ppq_graph,
+            dataloader=calib_dataloader,
+            executor=executor,
+        )
+        if verbose:
+            logger.info(quantizer.report())
+            logger.info("Network Quantization Finished.")
+
+        # ------------------------------------------------------------
+        #
+        # 2: Analyze Quantization Errors.
+        #
+        # ------------------------------------------------------------
+        if error_report:
+            graphwise_error_analyse(
+                graph=ppq_graph,
+                running_device=device,
+                collate_fn=collate_fn,
+                dataloader=calib_dataloader,
+            )
+
+            layerwise_error_analyse(
+                graph=ppq_graph,
+                running_device=device,
+                collate_fn=collate_fn,
+                dataloader=calib_dataloader,
+            )
+    else:
+        # support TargetPlatform.FP32
+        executor = TorchExecutor(graph=ppq_graph, device=device)
+        executor.tracing_operation_meta(inputs=dummy_inputs)
+        target_platform = TargetPlatform.ESPDL_INT8
 
     # ------------------------------------------------------------
     #
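Note: after this restructuring, the quantizer, dispatching table, KL calibration pipeline, and error analysis run only when a quantized target platform is requested; for TargetPlatform.FP32 the graph is merely traced and the export platform is switched to ESPDL_INT8. A hedged usage sketch follows; argument names not visible in this hunk (espdl_export_file, target) are assumptions, so verify them against the current espdl_interface.py:

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from ppq.api import espdl_quantize_onnx  # import path assumed

    # 128 representative samples to drive RuntimeCalibrationPass(method="kl").
    calib_set = TensorDataset(torch.randn(128, 3, 224, 224))
    calib_loader = DataLoader(calib_set, batch_size=16)

    espdl_quantize_onnx(
        onnx_import_file="model.onnx",
        espdl_export_file="model.espdl",   # assumed parameter name
        calib_dataloader=calib_loader,
        calib_steps=8,
        input_shape=[1, 3, 224, 224],
        target="esp32p4",                  # assumed target string
        num_of_bits=8,
        device="cpu",
        error_report=True,                 # triggers graphwise/layerwise error analysis
    )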
@@ -345,7 +351,7 @@ def espdl_quantize_torch(
             ]
         ),
         f=onnx_file_path,
-        opset_version=11,
+        opset_version=13,
         do_constant_folding=True,
     )

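Note on the headline change: torch.onnx.export now targets opset 13 instead of 11. Opset 13 is the first opset where QuantizeLinear/DequantizeLinear accept a per-channel axis, and where Squeeze/Unsqueeze take axes as inputs, both relevant to quantized export. A self-contained sketch of the export call this hunk modifies (the toy model and file name are placeholders):

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU()).eval()  # placeholder model
    dummy_input = torch.randn(1, 3, 32, 32)

    torch.onnx.export(
        model,
        (dummy_input,),
        f="model.onnx",
        opset_version=13,          # previously 11
        do_constant_folding=True,
    )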
