Skip to content

Commit

Permalink
Add device assignment export to graph dumpers.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 173472156
  • Loading branch information
tensorflower-gardener committed Oct 25, 2017
1 parent b4e09b4 commit a80b929
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
7 changes: 6 additions & 1 deletion tensorflow/compiler/xla/service/hlo_graph_dumper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
[](int64 stride) { return stride == 1; })
? ""
: StrCat("stride=", VectorString(instr->slice_strides()));
case HloOpcode::kSend:
case HloOpcode::kRecv:
return StrCat("channel_id=", instr->channel_id());
default:
return "";
}
Expand All @@ -935,7 +938,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) {
if (!opcode_specific_info.empty()) {
lines.push_back(opcode_specific_info);
}

if (instr->device_assignment().has_device()) {
lines.push_back(StrCat("device=", instr->device_assignment().device()));
}
// Show the shape and layout of the instruction, unless it's an inlined fusion
// node -- there the shape and layout is present in the output node.
if (instr->opcode() != HloOpcode::kFusion ||
Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) {
return tensor_shape;
}

string GetDeviceName(int device) { return StrCat("/device/XLA:", device); }

} // namespace

void CleanNodeName(string* name) {
Expand Down Expand Up @@ -178,6 +180,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction,
case HloOpcode::kCustomCall:
attrs["custom_call_target"].set_s(instruction->custom_call_target());
break;
case HloOpcode::kSend:
case HloOpcode::kRecv:
attrs["channel_id"].set_i(instruction->channel_id());
break;
default:
break;
}
Expand All @@ -192,6 +198,10 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) {
NodeDef* node_def = graph_def_.add_node();
node_def->set_name(GetNodeNameForInstruction(instruction));
node_def->set_op(GetOpDefName(instruction));
if (instruction->device_assignment().has_device()) {
node_def->set_device(
GetDeviceName(instruction->device_assignment().device()));
}
SetNodeAttrs(instruction, node_def);
if (instruction->opcode() == HloOpcode::kFusion) {
for (auto* fused_instruction : instruction->fused_instructions()) {
Expand Down

0 comments on commit a80b929

Please sign in to comment.