Skip to content

Commit

Permalink
Add tensor attribute support.
Browse files Browse the repository at this point in the history
  • Loading branch information
BlueSkyB committed Jan 23, 2025
1 parent 9b51d64 commit d899dca
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion ppq/parser/espdl/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
cast,
)

import onnx
from onnx import numpy_helper
import flatbuffers
import numpy as np
from cryptography.hazmat.backends import default_backend
Expand Down Expand Up @@ -450,6 +452,7 @@ def make_attribute(
attr_i = None
attr_f = None
attr_s = None
attr_t = None
attr_ints = None
attr_floats = None
attr_strings = None
Expand All @@ -462,6 +465,21 @@ def make_attribute(
# Encode strings into utf-8
value = _to_bytes(value)
attr_s = builder.CreateByteVector(value)
elif isinstance(value, onnx.TensorProto):
value_data = numpy_helper.to_array(value)
if (value.data_type >= TensorDataType.TensorDataType.UNDEFINED
and value.data_type <= TensorDataType.TensorDataType.UINT64):
attr_t = make_tensor(
name = value.name,
data_type = value.data_type,
dims = value.dims,
vals = value_data,
raw = True,
doc_string = value.doc_string,
)
else:
raise ValueError(f"Don't support type {onnx.TensorProto.DataType.Name(value.data_type)}")

# Iterable cases
elif isinstance(value, collections.abc.Iterable):
value = list(value)
Expand Down Expand Up @@ -515,6 +533,9 @@ def make_attribute(
if attr_s is not None:
Attribute.AddS(builder, attr_s)
Attribute.AddAttrType(builder, AttributeType.AttributeType().STRING)
if attr_t is not None:
Attribute.AddT(builder, attr_t)
Attribute.AddAttrType(builder, AttributeType.AttributeType().TENSOR)
if attr_ints is not None:
Attribute.AddInts(builder, attr_ints)
Attribute.AddAttrType(builder, AttributeType.AttributeType().INTS)
Expand Down Expand Up @@ -701,7 +722,7 @@ def str_list(str_elem: Callable[[_T], str], xs: Sequence[_T]) -> str:
# TODO: Bit nervous about Python 2 / Python 3 determinism implications
content.append(repr(_sanitize_str(attr.SAsNumpy().tobytes())))
elif attr_type == AttributeType.AttributeType.TENSOR:
if attr.t().dimsLength() > 0:
if attr.T().DimsLength() > 0:
content.append("<Tensor>")
else:
# special case to print scalars
Expand Down

0 comments on commit d899dca

Please sign in to comment.