Skip to content

Commit

Permalink
Assorted small updates to toString of Nodes/Graph (pytorch#4243)
Browse files Browse the repository at this point in the history
Summary:
To be used in future PR:

- Print out `Constants` and `Placeholders` used by a `Function`
- `skipUsers`: Do not print out number of users for `Constant` and `Placeholder`
- `skipUsersForStorage`: For Function, plumbed into `skipUsers`
Pull Request resolved: pytorch#4243

Test Plan: Updated tests.

Reviewed By: yinghai

Differential Revision: D20213867

Pulled By: jfix71

fbshipit-source-id: b8497be9394782142ade532bcfaa9c791ef79acb
  • Loading branch information
jfix71 authored and facebook-github-bot committed Mar 3, 2020
1 parent dadf6f1 commit a82b168
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 16 deletions.
8 changes: 5 additions & 3 deletions include/glow/Graph/Graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -1707,11 +1707,13 @@ class Function final : public Named {
/// Dump a textual representation of the Function into provided output stream.
void dump() const;

/// Dump a textual representation of the Function to std::string.
std::string toString() const;
/// Dump a textual representation of the Function to std::string. If
/// \p skipUsersForStorage then user counts for Storage will not be dumped.
std::string toString(bool skipUsersForStorage = false) const;

/// Dump a textual representation of the Function into default output stream.
void dump(llvm::raw_ostream &os) const;
/// If \p skipUsersForStorage then user counts for Storage will not be dumped.
void dump(llvm::raw_ostream &os, bool skipUsersForStorage = false) const;

/// Dump a dotty graph that depicts the function into a file.
/// \returns full path to the file.
Expand Down
4 changes: 2 additions & 2 deletions include/glow/Graph/Nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class Constant : public Storage {

bool isDataParallel() const { return false; }

std::string getDebugDesc() const;
std::string getDebugDesc(bool skipUsers = false) const;

llvm::hash_code getHash() const;

Expand Down Expand Up @@ -187,7 +187,7 @@ class Placeholder : public Storage {

bool isDataParallel() const { return false; }

std::string getDebugDesc() const;
std::string getDebugDesc(bool skipUsers = false) const;

llvm::hash_code getHash() const;
};
Expand Down
12 changes: 9 additions & 3 deletions lib/Graph/Graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4173,14 +4173,14 @@ void Function::dump() const {
}
}

std::string Function::toString() const {
std::string Function::toString(bool skipUsersForStorage) const {
std::string storage;
llvm::raw_string_ostream os(storage);
dump(os);
dump(os, skipUsersForStorage);
return os.str();
}

void Function::dump(llvm::raw_ostream &os) const {
void Function::dump(llvm::raw_ostream &os, bool skipUsersForStorage) const {
os << "Graph structure " << getName() << ":\n";
std::set<const Node *, SortNamed> sorted;
for (const Node &n : nodes_) {
Expand All @@ -4189,6 +4189,12 @@ void Function::dump(llvm::raw_ostream &os) const {
for (auto *n : sorted) {
os << n->getDebugDesc();
}
for (auto *C : getNamedSorted(findConstants())) {
os << C->getDebugDesc(skipUsersForStorage);
}
for (auto *P : getNamedSorted(findPlaceholders())) {
os << P->getDebugDesc(skipUsersForStorage);
}
}

/// We can't use NodeWalker here, because it ignores result indices, which
Expand Down
14 changes: 9 additions & 5 deletions lib/Graph/Nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,26 @@ Node *Storage::clone() const { llvm_unreachable("Storage can't be cloned."); }
// Debug description methods
//===----------------------------------------------------------------------===//

std::string Constant::getDebugDesc() const {
std::string Constant::getDebugDesc(bool skipUsers) const {
DescriptionBuilder db(getKindName());
db.addParam("name", quote(getName()))
.addParam("layout", getLayout())
.addParam("output", *getType())
.addParam("users", getNumUsers());
.addParam("output", *getType());
if (!skipUsers) {
db.addParam("users", getNumUsers());
}
return db;
}

std::string Placeholder::getDebugDesc() const {
std::string Placeholder::getDebugDesc(bool skipUsers) const {
DescriptionBuilder db(getKindName());
db.addParam("name", quote(getName()))
.addParam("layout", getLayout())
.addParam("output", *getType())
.addParam("users", getNumUsers())
.addParam("trainable", isTraining());
if (!skipUsers) {
db.addParam("users", getNumUsers());
}
return db;
}

Expand Down
31 changes: 28 additions & 3 deletions tests/unittests/GraphTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1867,8 +1867,8 @@ TEST(Graph, testDumpStructure) {
name : "input"
layout : *
output : float<4 x 320 x 200 x 100 x 3>
users : 0
trainable : 1
users : 0
)";
EXPECT_EQ(mesN, expectMes);
EXPECT_EQ(mesN, osN1.str());
Expand All @@ -1893,13 +1893,38 @@ K : 3
users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
users : 1
)";
EXPECT_EQ(mesF, expectMesF);
EXPECT_EQ(mesF, osF1.str());
std::string storageF2;
llvm::raw_string_ostream osF2(storageF2);
osF2 << F2;
EXPECT_EQ(mesF, osF2.str());
storageF1.clear();
F2->dump(osF1, /* skipUsersForStorage */ true);
mesF = F2->toString(/* skipUsersForStorage */ true);
expectMesF = R"(Graph structure F2:
TopK
name : topk
Input : float<10 x 10>
K : 3
users : 0
Values : float<10 x 3>
Indices : index64<10 x 3>
Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
trainable : 1
)";
EXPECT_EQ(mesF, expectMesF);
EXPECT_EQ(mesF, osF1.str());
// Test Module
MD.createConstant(ElemKind::FloatTy, {1, 1}, "dummy");
std::string storageM1;
Expand All @@ -1917,15 +1942,15 @@ Placeholder
name : "input__1"
layout : *
output : float<10 x 10>
users : 1
trainable : 1
users : 1
Placeholder
name : "input"
layout : *
output : float<4 x 320 x 200 x 100 x 3>
users : 0
trainable : 1
users : 0
Function : F2
Function : F
Expand Down

0 comments on commit a82b168

Please sign in to comment.