Various fixes.
* Change HloDomainIsolator to take a factory of DomainCreators. The pass was previously not safe to run on an HloModuleGroup containing multiple modules, because the single DomainCreator held stale data from earlier modules. (A usage sketch follows the hlo_domain_isolator.h diff below.)
* Move entry ComputationLayout verification in HloVerifier into a virtual method on ShapeVerifier. This enables backend-specific checking. (A subclassing sketch follows the hlo_verifier.h diff below.)
* Make HloPassPipeline copyable by removing TF_DISALLOW_COPY_AND_ASSIGN. (A sketch follows the hlo_pass_pipeline.h diff below.)

PiperOrigin-RevId: 217962214
meheffernan authored and tensorflower-gardener committed Oct 20, 2018
1 parent 938c08e commit 388ce5d
Showing 7 changed files with 94 additions and 101 deletions.
34 changes: 13 additions & 21 deletions tensorflow/compiler/xla/service/hlo_domain_isolator.cc
@@ -23,23 +23,14 @@ limitations under the License.
 
 namespace xla {
 
-class HloDomainIsolator::RunContext {
- public:
-  RunContext(HloModule* module, HloDomainIsolator* isolator)
-      : module_(module), isolator_(isolator) {}
+namespace {
 
-  StatusOr<bool> Run();
-
- private:
-  HloModule* module_;
-  HloDomainIsolator* isolator_;
-};
-
-StatusOr<bool> HloDomainIsolator::RunContext::Run() {
-  hlo_graph_dumper::MaybeDumpHloModule(*module_, "Before Domain Isolator");
+StatusOr<bool> RunInternal(HloModule* module,
+                           HloDomainIsolator::DomainCreator* creator) {
+  hlo_graph_dumper::MaybeDumpHloModule(*module, "Before Domain Isolator");
 
   int64 added_domains = 0;
-  for (HloComputation* computation : module_->computations()) {
+  for (HloComputation* computation : module->computations()) {
     // Walk in post order and place all the required kDomain instructions.
     for (HloInstruction* instruction :
          computation->MakeInstructionPostOrder()) {
@@ -55,8 +46,7 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
         root = root->mutable_operand(0);
       }
       // Check whether a kDomain is necessary between instruction and operand.
-      HloInstruction* domain =
-          isolator_->creator_(instruction, root, operand);
+      HloInstruction* domain = (*creator)(instruction, root, operand);
       if (domain != nullptr) {
         VLOG(4) << "New domain: " << domain->ToString();
         TF_RETURN_IF_ERROR(operand->ReplaceUseWith(instruction, domain));
@@ -67,17 +57,19 @@ StatusOr<bool> HloDomainIsolator::RunContext::Run() {
   }
   VLOG(3) << "Added " << added_domains << " kDomain instructions";
   if (added_domains > 0) {
-    hlo_graph_dumper::MaybeDumpHloModule(*module_, "After Domain Isolator");
+    hlo_graph_dumper::MaybeDumpHloModule(*module, "After Domain Isolator");
   }
   return added_domains > 0;
 }
 
-HloDomainIsolator::HloDomainIsolator(DomainCreator creator)
-    : creator_(std::move(creator)) {}
+}  // namespace
+
+HloDomainIsolator::HloDomainIsolator(DomainCreatorFactory creator_factory)
+    : creator_factory_(std::move(creator_factory)) {}
 
 StatusOr<bool> HloDomainIsolator::Run(HloModule* module) {
-  RunContext run_context(module, this);
-  return run_context.Run();
+  DomainCreator creator = creator_factory_();
+  return RunInternal(module, &creator);
 }
 
 }  // namespace xla
8 changes: 3 additions & 5 deletions tensorflow/compiler/xla/service/hlo_domain_isolator.h
@@ -40,17 +40,15 @@ class HloDomainIsolator : public HloModulePass {
   // Returns nullptr in case no domain separation is necessary.
   using DomainCreator = std::function<HloInstruction*(
       HloInstruction*, HloInstruction*, HloInstruction*)>;
-
-  explicit HloDomainIsolator(DomainCreator creator);
+  using DomainCreatorFactory = std::function<DomainCreator()>;
+  explicit HloDomainIsolator(DomainCreatorFactory creator_factory_);
 
   absl::string_view name() const override { return "domain_isolator"; }
 
   StatusOr<bool> Run(HloModule* module) override;
 
  private:
-  class RunContext;
-
-  DomainCreator creator_;
+  DomainCreatorFactory creator_factory_;
 };
 
 }  // namespace xla
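
Note (not part of the diff): a minimal sketch of the new factory interface. The lambda below is the DomainCreatorFactory; Run() invokes it once per call, so each module is processed with a fresh DomainCreator and no state leaks across the modules of a module group. The `modules` collection is hypothetical.

// Sketch: one fresh DomainCreator per Run() call.
HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
for (HloModule* module : modules) {  // e.g. the modules of an HloModuleGroup
  TF_ASSIGN_OR_RETURN(bool changed, isolator.Run(module));
}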
50 changes: 26 additions & 24 deletions tensorflow/compiler/xla/service/hlo_domain_test.cc
@@ -106,20 +106,22 @@ class OpNameMetadata : public DomainMetadata {
 };
 
 // Creator function for OpNameMetadata domains.
-HloInstruction* OpNameDomainCreator(HloInstruction* instruction,
-                                    HloInstruction* root,
-                                    HloInstruction* operand) {
-  if (instruction->metadata().op_name() == root->metadata().op_name()) {
-    return nullptr;
+class OpNameDomainCreator {
+ public:
+  HloInstruction* operator()(HloInstruction* instruction, HloInstruction* root,
+                             HloInstruction* operand) {
+    if (instruction->metadata().op_name() == root->metadata().op_name()) {
+      return nullptr;
+    }
+    std::unique_ptr<DomainMetadata> operand_side_metadata =
+        absl::make_unique<OpNameMetadata>(root->metadata().op_name());
+    std::unique_ptr<DomainMetadata> user_side_metadata =
+        absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
+    return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
+        operand->shape(), operand, std::move(operand_side_metadata),
+        std::move(user_side_metadata)));
   }
-  std::unique_ptr<DomainMetadata> operand_side_metadata =
-      absl::make_unique<OpNameMetadata>(root->metadata().op_name());
-  std::unique_ptr<DomainMetadata> user_side_metadata =
-      absl::make_unique<OpNameMetadata>(instruction->metadata().op_name());
-  return operand->parent()->AddInstruction(HloInstruction::CreateDomain(
-      operand->shape(), operand, std::move(operand_side_metadata),
-      std::move(user_side_metadata)));
-}
+};
 
 Status OpNameDomainNormalizer(const DomainMetadata::Domain& domain,
                               const DomainMetadata* metadata) {
@@ -145,7 +147,7 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(isolator_changed);
 
@@ -187,7 +189,7 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(!isolator_changed);
 }
@@ -214,7 +216,7 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(isolator_changed);
 
@@ -251,7 +253,7 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_FALSE(isolator_changed);
 }
@@ -305,12 +307,12 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator sharding_isolator(ShardingDomainCreator{});
+  HloDomainIsolator sharding_isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool sharding_isolator_changed,
                           sharding_isolator.Run(module));
   EXPECT_TRUE(sharding_isolator_changed);
 
-  HloDomainIsolator opname_isolator(OpNameDomainCreator);
+  HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
                           opname_isolator.Run(module));
   EXPECT_TRUE(opname_isolator_changed);
@@ -360,7 +362,7 @@ ENTRY entry {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(isolator_changed);
 
@@ -446,7 +448,7 @@ ENTRY entry {
 
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(isolator_changed);
 
@@ -507,7 +509,7 @@ ENTRY entry {
 
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(isolator_changed);
 
@@ -556,7 +558,7 @@ ENTRY %entry (p0: (f32[4], f32[4])) -> (f32[4], f32[4], f32[4]) {
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
   LOG(INFO) << "Original module:\n" << module->ToString();
 
-  HloDomainIsolator opname_isolator(OpNameDomainCreator);
+  HloDomainIsolator opname_isolator([]() { return OpNameDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool opname_isolator_changed,
                           opname_isolator.Run(module));
   EXPECT_TRUE(opname_isolator_changed);
@@ -603,7 +605,7 @@ ENTRY entry {
 
   TF_ASSERT_OK_AND_ASSIGN(HloModule * module, ParseModule(hlo_string));
 
-  HloDomainIsolator isolator(ShardingDomainCreator{});
+  HloDomainIsolator isolator([]() { return ShardingDomainCreator{}; });
   TF_ASSERT_OK_AND_ASSIGN(bool isolator_changed, isolator.Run(module));
   EXPECT_TRUE(isolator_changed);
 
1 change: 1 addition & 0 deletions tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -116,6 +116,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
       StrCat("after ", after_pass_name, ", before ", before_pass_name);
   hlo_graph_dumper::MaybeDumpHloModule(module, message);
   VLOG(3) << "HLO " << message << ":";
+  VLOG(3) << module.entry_computation_layout().ToString();
   XLA_VLOG_LINES(3, module.ToString());
 }
 
2 changes: 0 additions & 2 deletions tensorflow/compiler/xla/service/hlo_pass_pipeline.h
@@ -105,8 +105,6 @@ class HloPassPipeline : public HloPassInterface {
   std::vector<std::unique_ptr<HloPassInterface>> passes_;
   std::vector<std::unique_ptr<HloPassInterface>> invariant_checkers_;
   bool run_called_ = false;
-
-  TF_DISALLOW_COPY_AND_ASSIGN(HloPassPipeline);
 };
 
 }  // namespace xla
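
Note (not part of the diff): a sketch of what dropping TF_DISALLOW_COPY_AND_ASSIGN permits. With the explicit ban gone the compiler can generate move operations (the unique_ptr members still preclude plain deep copies), so a pipeline can be assembled in a helper and returned by value. The helper below is hypothetical.

// Hypothetical helper: build a pipeline and hand it back by value.
HloPassPipeline MakeIsolationPipeline() {
  HloPassPipeline pipeline("isolation");
  pipeline.AddPass<HloDomainIsolator>(
      []() { return ShardingDomainCreator{}; });
  return pipeline;
}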
94 changes: 46 additions & 48 deletions tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -828,6 +828,50 @@ Status ShapeVerifier::CheckVariadicShape(const HloInstruction* instruction) {
                         instruction->opcode(), instruction->operands()));
 }
 
+Status ShapeVerifier::VerifyEntryComputationLayout(const HloModule& module) {
+  const HloComputation* computation = module.entry_computation();
+  const auto& layout = module.entry_computation_layout();
+  const ShapeLayout& result_layout = layout.result_layout();
+
+  TF_RETURN_IF_ERROR(
+      ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
+
+  TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape()));
+
+  if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
+                             result_layout.shape())) {
+    return InternalError(
+        "Shape of the root instruction of entry computation (%s) should be "
+        "compatible to one specified in module's entry computation layout (%s)",
+        ShapeUtil::HumanString(computation->root_instruction()->shape()),
+        ShapeUtil::HumanString(result_layout.shape()));
+  }
+
+  if (computation->num_parameters() != layout.parameter_count()) {
+    return InternalError(
+        "Number of parameters in entry computation layout (%d) must be same "
+        "as number of parameters of entry computation computation (%d)",
+        layout.parameter_count(), computation->num_parameters());
+  }
+
+  for (int i = 0; i < computation->num_parameters(); ++i) {
+    const HloInstruction* parameter = computation->parameter_instruction(i);
+    TF_RETURN_IF_ERROR(
+        ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
+    TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i)));
+    if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
+      return InternalError(
+          "Shape of the entry computation parameter %d is %s should be "
+          "compatible to the one specified in module's entry computation "
+          "layout %s",
+          i, ShapeUtil::HumanString(parameter->shape()),
+          ShapeUtil::HumanString(layout.parameter_shape(i)));
+    }
+  }
+
+  return Status::OK();
+}
+
 string ComputationsToString(absl::Span<HloComputation* const> computations) {
   return absl::StrJoin(computations, ",",
                        [](string* s, const HloComputation* computation) {
@@ -923,52 +967,6 @@ Status VerifyEntryAndExitShapes(const HloModule& module) {
   return Status::OK();
 }
 
-// Verifies that entry computation layout matches characteristics of
-// entry computation.
-Status CheckEntryComputationLayout(const HloModule& module) {
-  const HloComputation* computation = module.entry_computation();
-  const auto& layout = module.entry_computation_layout();
-  const ShapeLayout& result_layout = layout.result_layout();
-
-  TF_RETURN_IF_ERROR(
-      ShapeUtil::ValidateShapeWithOptionalLayout(result_layout.shape()));
-
-  TF_RETURN_IF_ERROR(VerifyNotSparse(result_layout.shape()));
-
-  if (!ShapeUtil::Compatible(computation->root_instruction()->shape(),
-                             result_layout.shape())) {
-    return InternalError(
-        "Shape of the root instruction of entry computation (%s) should be "
-        "compatible to one specified in module's entry computation layout (%s)",
-        ShapeUtil::HumanString(computation->root_instruction()->shape()),
-        ShapeUtil::HumanString(result_layout.shape()));
-  }
-
-  if (computation->num_parameters() != layout.parameter_count()) {
-    return InternalError(
-        "Number of parameters in entry computation layout (%d) must be same "
-        "as number of parameters of entry computation computation (%d)",
-        layout.parameter_count(), computation->num_parameters());
-  }
-
-  for (int i = 0; i < computation->num_parameters(); ++i) {
-    const HloInstruction* parameter = computation->parameter_instruction(i);
-    TF_RETURN_IF_ERROR(
-        ShapeUtil::ValidateShapeWithOptionalLayout(layout.parameter_shape(i)));
-    TF_RETURN_IF_ERROR(VerifyNotSparse(layout.parameter_shape(i)));
-    if (!ShapeUtil::Compatible(parameter->shape(), layout.parameter_shape(i))) {
-      return InternalError(
-          "Shape of the entry computation parameter %d is %s should be "
-          "compatible to the one specified in module's entry computation "
-          "layout %s",
-          i, ShapeUtil::HumanString(parameter->shape()),
-          ShapeUtil::HumanString(layout.parameter_shape(i)));
-    }
-  }
-
-  return Status::OK();
-}
-
 // Checks if the given two instructions share the same channel id.
 Status CheckSameChannel(const HloInstruction* instr1,
                         const HloInstruction* instr2) {
@@ -1363,16 +1361,16 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
   TF_RETURN_IF_ERROR(VerifyHloStructure(module));
   TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
 
+  std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
   for (auto* computation : module->computations()) {
-    std::unique_ptr<ShapeVerifier> shape_verifier = shape_verifier_factory_();
     TF_RETURN_IF_ERROR(computation->Accept(shape_verifier.get()));
 
     InstructionVerifier instruction_verifier(
        instruction_can_change_layout_func_);
     TF_RETURN_IF_ERROR(computation->Accept(&instruction_verifier));
   }
 
-  TF_RETURN_IF_ERROR(CheckEntryComputationLayout(*module));
+  TF_RETURN_IF_ERROR(shape_verifier->VerifyEntryComputationLayout(*module));
   TF_RETURN_IF_ERROR(VerifyEntryAndExitShapes(*module));
 
   // If the module has a schedule, it must be valid.
6 changes: 5 additions & 1 deletion tensorflow/compiler/xla/service/hlo_verifier.h
@@ -28,10 +28,14 @@ namespace xla {
 // TODO(b/26024837): Check output shape for all instruction types.
 class ShapeVerifier : public DfsHloVisitor {
  public:
-  explicit ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
+  ShapeVerifier(bool layout_sensitive, bool allow_mixed_precision)
       : layout_sensitive_(layout_sensitive),
         allow_mixed_precision_(allow_mixed_precision) {}
 
+  // Verifies that entry computation layout matches parameters and root shape of
+  // the module's entry computation.
+  virtual Status VerifyEntryComputationLayout(const HloModule& module);
+
   Status Preprocess(HloInstruction* hlo) override;
 
   Status HandleElementwiseUnary(HloInstruction* hlo) override;
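
Note (not part of the diff): a hedged sketch of the backend-specific checking the new virtual method enables. GpuShapeVerifier and its extra rule are hypothetical; only the override mechanism comes from this commit.

// Hypothetical backend verifier: run the base checks, then apply one
// additional backend-specific constraint on the entry computation layout.
class GpuShapeVerifier : public ShapeVerifier {
 public:
  GpuShapeVerifier()
      : ShapeVerifier(/*layout_sensitive=*/true,
                      /*allow_mixed_precision=*/false) {}

  Status VerifyEntryComputationLayout(const HloModule& module) override {
    TF_RETURN_IF_ERROR(ShapeVerifier::VerifyEntryComputationLayout(module));
    // Hypothetical backend rule: require the entry result layout to be set.
    if (!module.entry_computation_layout().result_layout().LayoutIsSet()) {
      return InternalError("Entry result layout must be set on this backend");
    }
    return Status::OK();
  }
};

A backend would then supply such a subclass through the shape verifier factory that HloVerifier::Run already consults (the shape_verifier_factory_ call visible in the hlo_verifier.cc diff above).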
