Skip to content

Commit

Permalink
Add more meta info for custom exporting (pytorch#4272)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#4272

During exporting:
- Add type attributes for inputs, prefixed with `i#_`
- Always dump shape info for outputs even if not needed by the importer

Differential Revision: D20313228

fbshipit-source-id: 562af82961c142587a810e594593f88e98b186d3
  • Loading branch information
jfix71 authored and facebook-github-bot committed Mar 7, 2020
1 parent 3777ddd commit 7f24426
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
10 changes: 6 additions & 4 deletions include/glow/Graph/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -366,10 +366,12 @@ constexpr char qScaleSignifier[] = "qScale";
constexpr char qOffsetSignifier[] = "qOffset";
constexpr char shapeSignifier[] = "shape";

/// \returns the string ID for a type attribute property for a specific \p resNo
/// and \p signifier, e.g. to retrieve result number 0's shape.
inline std::string getTypeAttrID(unsigned resNo, const std::string &signifier) {
return "o" + std::to_string(resNo) + "_" + signifier;
/// \returns the string ID for a type attribute property for a specific \p ioNum
/// and \p signifier and whether \p isInput. E.g. to retrieve result number 0's
/// shape, you'd pass `(0, "shape", false)`.
inline std::string getTypeAttrID(unsigned ioNum, const std::string &signifier,
bool isInput = false) {
return (isInput ? "i" : "o") + std::to_string(ioNum) + "_" + signifier;
}

} // namespace glow
Expand Down
21 changes: 12 additions & 9 deletions lib/Exporter/ONNXModelWriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,28 +118,31 @@ void addValueAttribute(ONNX_NAMESPACE::NodeProto *proto,
T>::assign(attr, container);
}

/// Add the type attributes from \p NV to \p proto. This includes the ElemKind,
/// the Shape, and scale/offset if ElemKind is quantized. Note that the result
/// name is prefixed onto the specific attribute being appended, as some ops
/// have multiple outputs and so this allows differentiating between them.
void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, NodeValue NV) {
/// Add the type attributes from the \p ioNum number input or output (depending
/// on \p isInput) of \p N to \p proto. This includes the ElemKind, the Shape,
/// and scale/offset if ElemKind is quantized. Note that 'i' or 'o' along with
/// \p ioNum is prefixed onto the specific attribute being appended, as ops may
/// have multiple inputs/outputs.
void addTypeAttributes(ONNX_NAMESPACE::NodeProto *proto, const Node *N,
unsigned ioNum, bool isInput) {
NodeValue NV = isInput ? N->getNthInput(ioNum) : N->getNthResult(ioNum);
const TypeRef ty = NV.getType();

// Add ElemKind.
auto *elemKindAttr = proto->add_attribute();
elemKindAttr->set_name(getTypeAttrID(NV.getResNo(), elemKindSignifier));
elemKindAttr->set_name(getTypeAttrID(ioNum, elemKindSignifier, isInput));
AttributeAssigner<false, false, llvm::StringRef>::assign(
elemKindAttr, ty->getElementName());

// Add Shape.
addValueAttribute(proto, getTypeAttrID(NV.getResNo(), shapeSignifier),
addValueAttribute(proto, getTypeAttrID(ioNum, shapeSignifier, isInput),
NV.dims());

// Write out scale/offset if quantized ElemKind.
if (isQuantizedElemKind(ty->getElementType())) {
addValueAttribute(proto, getTypeAttrID(NV.getResNo(), qScaleSignifier),
addValueAttribute(proto, getTypeAttrID(ioNum, qScaleSignifier, isInput),
ty->getScale());
addValueAttribute(proto, getTypeAttrID(NV.getResNo(), qOffsetSignifier),
addValueAttribute(proto, getTypeAttrID(ioNum, qOffsetSignifier, isInput),
ty->getOffset());
}
}
Expand Down
12 changes: 9 additions & 3 deletions tools/ClassGen/NodeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -677,15 +677,21 @@ void NodeBuilder::emitExportMethods(std::ostream &os) const {
for (const auto &op : nodeInputs_) {
os << " opProto->add_input(N__->get" << op
<< "().generateNodeOutputName(/* stripResNoFor0thInput */ true));\n";
// Note: Add each input's type attributes so that other tools have easy
// visibility into types. This info may go ignored by the importer.
os << " addTypeAttributes(opProto, N__, " << name_ << "Node::" << op
<< "Idx, /* isInput */ true);\n";
}

// Add all of the node's outputs.
for (const auto &op : nodeOutputs_) {
os << " opProto->add_output(N__->get" << op.second
<< "().generateNodeOutputName(/* stripResNoFor0thInput */ true));\n";
if (hasCtorTypeParams(op.second)) {
os << " addTypeAttributes(opProto, N__->get" << op.second << "());\n";
}
// Note: export the type attributes even if not needed by the importer, so
// that other tools have easy visibility into types. This info may go
// ignored by the importer.
os << " addTypeAttributes(opProto, N__, " << name_ << "Node::" << op.second
<< "Idx, /* isInput */ false);\n";
}

// Add any members the node has.
Expand Down

0 comments on commit 7f24426

Please sign in to comment.