Skip to content

Commit

Permalink
Merge pull request tensorflow#17899 from benoitsteiner/branch_189913309
Browse files Browse the repository at this point in the history
Branch 189913309
  • Loading branch information
sb2nov authored Mar 21, 2018
2 parents 4e108ef + a926014 commit 00c90e6
Show file tree
Hide file tree
Showing 60 changed files with 1,952 additions and 540 deletions.
22 changes: 22 additions & 0 deletions tensorflow/compiler/xla/client/executable_build_options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,28 @@ ExecutableBuildOptions::generate_hlo_graph() const {
return generate_hlo_graph_;
}

// Remembers `dirpath` as the directory into which the end-of-optimization
// HLO protobuf should be dumped; returns *this so setter calls can be
// chained. StringPiece does not own its bytes, so a copy is stored.
ExecutableBuildOptions&
ExecutableBuildOptions::set_dump_optimized_hlo_proto_to(
    tensorflow::StringPiece dirpath) {
  const string dir = dirpath.ToString();
  dump_optimized_hlo_proto_to_ = dir;
  return *this;
}

// Returns the directory path previously supplied via
// set_dump_optimized_hlo_proto_to(); the optional is unset if no path was
// ever configured.
const tensorflow::gtl::optional<string>&
ExecutableBuildOptions::dump_optimized_hlo_proto_to() const {
  return dump_optimized_hlo_proto_to_;
}

// Remembers `dirpath` as the directory into which per-pass HLO protobufs
// should be dumped; returns *this to allow chained option setting. The
// non-owning StringPiece is copied into owned storage.
ExecutableBuildOptions&
ExecutableBuildOptions::set_dump_per_pass_hlo_proto_to(
    tensorflow::StringPiece dirpath) {
  const string dir = dirpath.ToString();
  dump_per_pass_hlo_proto_to_ = dir;
  return *this;
}

// Returns the directory path previously supplied via
// set_dump_per_pass_hlo_proto_to(); the optional is unset if no path was
// ever configured.
const tensorflow::gtl::optional<string>&
ExecutableBuildOptions::dump_per_pass_hlo_proto_to() const {
  return dump_per_pass_hlo_proto_to_;
}

ExecutableBuildOptions& ExecutableBuildOptions::set_hlo_profile(bool enabled) {
hlo_profile_ = enabled;
return *this;
Expand Down
15 changes: 15 additions & 0 deletions tensorflow/compiler/xla/client/executable_build_options.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/optional.h"

namespace xla {
Expand Down Expand Up @@ -57,6 +58,18 @@ class ExecutableBuildOptions {
ExecutableBuildOptions& set_generate_hlo_graph(string regex);
const tensorflow::gtl::optional<string>& generate_hlo_graph() const;

// If set, specifies a dirpath to dump the end-of-optimization-pipeline HLO
// protobuf to (as in DebugOptions).
ExecutableBuildOptions& set_dump_optimized_hlo_proto_to(
tensorflow::StringPiece dirpath);
const tensorflow::gtl::optional<string>& dump_optimized_hlo_proto_to() const;

// If set, specifies a dirpath to dump the per-pass-in-pipeline HLO protobufs
// to (as in DebugOptions).
ExecutableBuildOptions& set_dump_per_pass_hlo_proto_to(
tensorflow::StringPiece dirpath);
const tensorflow::gtl::optional<string>& dump_per_pass_hlo_proto_to() const;

// If set, specifies that we should record an HLO profile during execution and
// log it after execution (as in DebugOptions).
ExecutableBuildOptions& set_hlo_profile(bool enabled);
Expand All @@ -72,6 +85,8 @@ class ExecutableBuildOptions {
Shape result_layout_;
bool result_layout_set_ = false;
tensorflow::gtl::optional<string> generate_hlo_graph_;
tensorflow::gtl::optional<string> dump_optimized_hlo_proto_to_;
tensorflow::gtl::optional<string> dump_per_pass_hlo_proto_to_;
DeviceMemoryAllocator* device_allocator_ = nullptr;
};

Expand Down
72 changes: 34 additions & 38 deletions tensorflow/compiler/xla/client/xla_client/xla_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -51,21 +51,16 @@ bool CanBeRoot(HloOpcode opcode) {
}
}

// Records `opcode` on `instr` in its string form — HloInstructionProto
// carries the opcode as a string rather than an enum value.
void SetOpcode(HloInstructionProto* instr, HloOpcode opcode) {
  instr->set_opcode(HloOpcodeString(opcode));
}

} // namespace

StatusOr<std::unique_ptr<Shape>> XlaBuilder::GetShape(const XlaOp& op) const {
StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
TF_ASSIGN_OR_RETURN(auto instr, LookUpInstruction(op));
return MakeUnique<Shape>(instr->shape());
return instr->shape();
}

StatusOr<Shape> XlaOp::GetShape() const {
TF_RET_CHECK(builder_ != nullptr);
TF_ASSIGN_OR_RETURN(auto shape, builder_->GetShape(*this));
return *shape;
return builder_->GetShape(*this);
}

XlaBuilder::XlaBuilder(const string& computation_name)
Expand Down Expand Up @@ -99,16 +94,17 @@ StatusOr<XlaComputation> XlaBuilder::Build() {

// Not all instructions can be roots. Walk backwards from the last added
// instruction until a valid root is found.
entry.set_root_id(-1);
for (int64 i = instructions_.size() - 1; i >= 0; i--) {
TF_ASSIGN_OR_RETURN(HloOpcode opcode,
StringToHloOpcode(instructions_[i].opcode()));
if (CanBeRoot(opcode)) {
entry.set_root_name(instructions_[i].name());
entry.set_root_id(instructions_[i].id());
*program_shape->mutable_result() = instructions_[i].shape();
break;
}
}
if (entry.root_name().empty()) {
if (entry.root_id() == -1) {
return FailedPrecondition("no root instruction was found");
}

Expand Down Expand Up @@ -141,7 +137,9 @@ StatusOr<XlaComputation> XlaBuilder::Build() {
XlaComputation computation(id);
HloModuleProto* module = computation.mutable_proto();
module->set_name(entry.name());
module->set_id(entry.id());
module->set_entry_computation_name(entry.name());
module->set_entry_computation_id(entry.id());
*module->mutable_program_shape() = entry.program_shape();
for (auto& e : embedded_) {
module->add_computations()->Swap(&e.second);
Expand All @@ -155,56 +153,49 @@ XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
auto op = [&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
SetOpcode(&instr, HloOpcode::kAdd);
TF_ASSIGN_OR_RETURN(const auto* lhs_instr, LookUpInstruction(lhs));
TF_ASSIGN_OR_RETURN(const auto* rhs_instr, LookUpInstruction(rhs));
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferBinaryOpShape(
HloOpcode::kAdd, lhs_instr->shape(),
rhs_instr->shape(), broadcast_dimensions));
instr.add_operand_names(lhs_instr->name());
instr.add_operand_names(rhs_instr->name());
return AddInstruction(std::move(instr));
TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, lhs.GetShape());
TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, rhs.GetShape());
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferBinaryOpShape(HloOpcode::kAdd, lhs_shape,
rhs_shape, broadcast_dimensions));
return AddInstruction(std::move(instr), HloOpcode::kAdd, {lhs, rhs});
};
return NoteErrorOrReturn(op());
}

XlaOp XlaBuilder::ConstantLiteral(const Literal& literal) {
HloInstructionProto instr;
SetOpcode(&instr, HloOpcode::kConstant);
*instr.mutable_shape() = literal.shape();
*instr.mutable_literal() = literal.ToProto();
return AddInstruction(std::move(instr));
return AddInstruction(std::move(instr), HloOpcode::kConstant);
}

XlaOp XlaBuilder::Call(const XlaComputation& computation,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
auto op = [&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
SetOpcode(&instr, HloOpcode::kCall);
std::vector<const Shape*> operand_shapes;
std::vector<const Shape*> operand_shape_ptrs;
std::vector<Shape> operand_shapes;
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(const auto* input, LookUpInstruction(operand));
operand_shapes.push_back(&input->shape());
TF_ASSIGN_OR_RETURN(const Shape& shape, operand.GetShape());
operand_shapes.push_back(shape);
}
c_transform(operand_shapes, std::back_inserter(operand_shape_ptrs),
[](const Shape& shape) { return &shape; });
TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
ShapeInference::InferCallShape(
operand_shapes,
operand_shape_ptrs,
/*to_apply=*/computation.GetProgramShape()));

// Add input operands.
for (const auto& operand : operands) {
TF_ASSIGN_OR_RETURN(auto operand_instr, LookUpInstruction(operand));
instr.add_operand_names(operand_instr->name());
}

// Add called computation.
*instr.add_called_computation_names() = computation.proto().name();
instr.add_called_computation_ids(
computation.proto().entry_computation_id());
for (const HloComputationProto& e : computation.proto().computations()) {
embedded_.insert({e.id(), e});
}

return AddInstruction(std::move(instr));
return AddInstruction(std::move(instr), HloOpcode::kCall, operands);
};
return NoteErrorOrReturn(op());
}
Expand All @@ -213,7 +204,6 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
const string& name) {
auto op = [&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
SetOpcode(&instr, HloOpcode::kParameter);
if (parameter_numbers_.find(parameter_number) != parameter_numbers_.end()) {
return InvalidArgument("parameter %lld already registered",
parameter_number);
Expand All @@ -222,19 +212,25 @@ XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
instr.set_parameter_number(parameter_number);
instr.set_name(name);
*instr.mutable_shape() = shape;
return AddInstruction(std::move(instr));
return AddInstruction(std::move(instr), HloOpcode::kParameter);
};
return NoteErrorOrReturn(op());
}

XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr) {
XlaOp XlaBuilder::AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands) {
const int64 handle = instructions_.size();
instr.set_id(handle);
instr.set_opcode(HloOpcodeString(opcode));
if (instr.name().empty()) {
instr.set_name(StrCat(instr.opcode(), ".", handle));
} else {
// Append the handle to make sure the name is unique.
instr.set_name(StrCat(instr.name(), ".", handle));
}
for (const auto& operand : operands) {
instr.add_operand_ids(operand.handle());
}
instructions_.push_back(instr);

XlaOp op(handle, this);
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/compiler/xla/client/xla_client/xla_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.

#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
Expand Down Expand Up @@ -157,14 +158,15 @@ class XlaBuilder {
XlaOp ConstantR0(NativeT value);

// Returns the shape of the given op.
StatusOr<std::unique_ptr<Shape>> GetShape(const XlaOp& op) const;
StatusOr<Shape> GetShape(const XlaOp& op) const;

// Builds the computation with the requested operations, or returns a non-ok
// status.
StatusOr<XlaComputation> Build();

private:
XlaOp AddInstruction(HloInstructionProto&& instr);
XlaOp AddInstruction(HloInstructionProto&& instr, HloOpcode opcode,
tensorflow::gtl::ArraySlice<XlaOp> operands = {});

// Notes that the error occurred by:
// * storing it internally and capturing a backtrace if it's the first error
Expand Down
52 changes: 41 additions & 11 deletions tensorflow/compiler/xla/python/local_computation_builder.i
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,33 @@ bool GetIntAttr(PyObject* o, const char* field, int64* result) {
return true;
}

// Reads attribute `attr_name` from Python object `o` and, when it is present
// and not None, invokes `f` with its string value.
//
// Returns "ok": true when there was no error — the attribute is absent, is
// None, or was a string successfully handed to `f`. Returns false when
// fetching the attribute failed (a Python error is already set) or the
// attribute had a non-string type (a Python TypeError is raised here).
//
// `f` is taken by const reference so the std::function (and any captured
// state) is not copied on every call.
bool HandleStringAttribute(PyObject* o,
                           const char* attr_name,
                           const std::function<void(string)>& f) {
  if (!PyObject_HasAttrString(o, attr_name)) {
    return true;  // It's ok for the object to not have the attribute.
  }
  // PyObject_GetAttrString returns a new reference; every exit path below
  // must Py_DECREF it.
  PyObject* attr = PyObject_GetAttrString(o, attr_name);
  if (attr == nullptr) {
    return false;  // An error occurred getting the attribute.
  }
  if (attr == Py_None) {
    Py_DECREF(attr);
    return true;  // The attribute is None, which we consider ok.
  }
  if (!PyString_Check(attr)) {
    string message = tensorflow::strings::Printf(
        "%s must be a string or none; got %s", attr_name,
        numpy::PyObjectCppRepr(attr).c_str());
    PyErr_SetString(PyExc_TypeError, message.c_str());
    Py_DECREF(attr);
    return false;  // Type error, not ok.
  }
  f(PyString_AsString(attr));
  Py_DECREF(attr);
  return true;  // Handled string attribute, ok!
}

}
}
%}
Expand Down Expand Up @@ -820,20 +847,23 @@ tensorflow::ImportNumpy();
if ($input == Py_None) {
$1 = NULL;
} else {
PyObject* o = PyObject_GetAttrString($input, "generate_hlo_graph");
if (!o) {
return NULL;
if (!HandleStringAttribute($input, "generate_hlo_graph", [&](string s) {
build_options.set_generate_hlo_graph(std::move(s));
})) {
return nullptr;
}
if (o != Py_None) {
if (!PyString_Check(o)) {
PyErr_SetString(PyExc_TypeError, "ExecutableBuildOptions.generate_hlo_graph must be a string or None.");
return NULL;
}
build_options.set_generate_hlo_graph(PyString_AsString(o));
if (!HandleStringAttribute($input, "dump_optimized_hlo_proto_to", [&](string s) {
build_options.set_dump_optimized_hlo_proto_to(std::move(s));
})) {
return nullptr;
}
if (!HandleStringAttribute($input, "dump_per_pass_hlo_proto_to", [&](string s) {
build_options.set_dump_per_pass_hlo_proto_to(std::move(s));
})) {
return nullptr;
}
Py_DECREF(o);

o = PyObject_GetAttrString($input, "hlo_profile");
PyObject* o = PyObject_GetAttrString($input, "hlo_profile");
if (o == NULL) {
return NULL;
}
Expand Down
3 changes: 1 addition & 2 deletions tensorflow/compiler/xla/python/numpy_bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,7 @@ static string PyObjectCppStr(PyObject* o) {
return ExtractStringAndDecref(s);
}

// Safely returns a repr of the given Python object o as a C++ string.
static string PyObjectCppRepr(PyObject* o) {
// Converts repr(o) into a C++ string; ExtractStringAndDecref releases the
// intermediate Python string object produced by PyObject_Repr.
string PyObjectCppRepr(PyObject* o) {
  return ExtractStringAndDecref(PyObject_Repr(o));
}
Expand Down
3 changes: 3 additions & 0 deletions tensorflow/compiler/xla/python/numpy_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ void CopyLiteralToNumpyArray(const Literal& literal, PyArrayObject* py_array) {
std::copy(source.begin(), source.end(), dest);
}

// Safely returns a repr of the given Python object o as a C++ string.
string PyObjectCppRepr(PyObject* o);

// Workarounds for Python 2 and 3 interop

PyObject* LongToPyIntOrPyLong(long x); // NOLINT
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/python/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,8 @@ class CompileOptions(object):

def __init__(self):
self.generate_hlo_graph = None
self.dump_optimized_hlo_proto_to = None
self.dump_per_pass_hlo_proto_to = None
self.hlo_profile = False


Expand Down
4 changes: 4 additions & 0 deletions tensorflow/compiler/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ cc_library(
"gpu_executable.cc",
"infeed_thunk.cc",
"kernel_thunk.cc",
"memset_thunk.cc",
"sequential_thunk.cc",
"thunk_schedule.cc",
"tuple_thunk.cc",
Expand All @@ -257,6 +258,7 @@ cc_library(
"gpu_executable.h",
"infeed_thunk.h",
"kernel_thunk.h",
"memset_thunk.h",
"sequential_thunk.h",
"thunk.h",
"thunk_schedule.h",
Expand All @@ -273,6 +275,7 @@ cc_library(
"//tensorflow/compiler/xla:array2d",
"//tensorflow/compiler/xla:shape_tree",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:types",
Expand All @@ -293,6 +296,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:cudnn_plugin",
"//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
],
)

Expand Down
Loading

0 comments on commit 00c90e6

Please sign in to comment.