Skip to content

Commit

Permalink
Fix the precision issue when slice/add are connected.
Browse files Browse the repository at this point in the history
  • Loading branch information
BlueSkyB committed Jan 2, 2025
1 parent 34b6d8a commit f2ffc61
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
10 changes: 5 additions & 5 deletions ppq/api/espdl_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,11 @@ def espdl_quantize_onnx(
BaseGraph: The Quantized Graph, containing all information needed for backend execution
"""

model = onnx.load(onnx_import_file)
model_sim, check = simplify(model)
if check:
onnx.save(model_sim, onnx_import_file)

export_path = os.path.dirname(os.path.abspath(espdl_export_file))
os.makedirs(export_path, exist_ok=True)

Expand Down Expand Up @@ -357,11 +362,6 @@ def espdl_quantize_torch(
opset_version=13,
do_constant_folding=True,
)

model = onnx.load(onnx_file_path)
model_sim, check = simplify(model)
if check:
onnx.save(model_sim, onnx_file_path)

# step2: quantize onnx model and export espdl model
return espdl_quantize_onnx(
Expand Down
6 changes: 3 additions & 3 deletions ppq/executor/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,9 @@ def __forward(
for input, config in zip(inputs, input_configs)])

# PATCH 20220208
for idx, var in enumerate(operation.inputs):
if var.name in output_names:
result_collector[output_names.index(var.name)] = inputs[idx]
# for idx, var in enumerate(operation.inputs):
# if var.name in output_names:
# result_collector[output_names.index(var.name)] = inputs[idx]

# invoking pre-forward hook
if operation_runtime_hook is not None:
Expand Down
2 changes: 1 addition & 1 deletion ppq/quantization/optim/refine.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def optimize(self, graph: BaseGraph, **kwargs) -> None:
for up_op in graph.get_upstream_operations(operation):
if not isinstance(up_op, QuantableOperation): continue

if len(graph.get_downstream_operations(up_op)) != 1 and not self.force_overlap: continue
if (up_op.type in PASSIVE_OPERATIONS or len(graph.get_downstream_operations(up_op)) != 1) and not self.force_overlap: continue
# for cfg, var in up_op.config_with_variable:
for cfg, var in zip(up_op.config.output_quantization_config, up_op.outputs):
if operation in var.dest_ops:
Expand Down

0 comments on commit f2ffc61

Please sign in to comment.