Skip to content

Commit

Permalink
Support non-constant param input of AvgPoolGrad and Sum.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 179962212
  • Loading branch information
Yao Zhang authored and tensorflower-gardener committed Dec 22, 2017
1 parent 3e27b27 commit bb45399
Showing 1 changed file with 10 additions and 45 deletions.
55 changes: 10 additions & 45 deletions tensorflow/core/grappler/optimizers/layout_optimizer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ class AvgPoolGradProcessor : public NodeProcessor {
protected:
std::vector<int> GetInputPos() const override { return {1}; }
Status CustomizedProcessing() override {
return UpdateAttrValueOfInput(0, true);
return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
}
};

Expand Down Expand Up @@ -1062,9 +1062,7 @@ class Conv2DBackpropInputProcessor : public Conv2DProcessor {
std::vector<int> GetInputPos() const override { return {2}; }

Status CustomizedProcessing() override {
TF_RETURN_IF_ERROR(
UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32));
return Status::OK();
return UpdateOrTransformParamInput(0, "DataFormatVecPermute", DT_INT32);
}
};

Expand Down Expand Up @@ -1371,9 +1369,7 @@ class FillProcessor : public AgnosticNodeProcessor {

Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("index_type").type();
TF_RETURN_IF_ERROR(
UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype));
return Status::OK();
return UpdateOrTransformParamInput(0, "DataFormatVecPermute", dtype);
}
};

Expand Down Expand Up @@ -1470,9 +1466,7 @@ class PadProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("Tpaddings").type();
TF_RETURN_IF_ERROR(
UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype));
return Status::OK();
return UpdateOrTransformParamInput(1, "DataFormatVecPermute", dtype);
}
};

Expand All @@ -1484,9 +1478,7 @@ class ReverseProcessor : public AgnosticNodeProcessor {
protected:
Status CustomizedProcessing() override {
DataType dtype = node_->attr().at("Tidx").type();
TF_RETURN_IF_ERROR(
UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype));
return Status::OK();
return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
}
};

Expand All @@ -1511,9 +1503,8 @@ class SplitProcessor : public AgnosticNodeProcessor {
}

Status CustomizedProcessing() override {
TF_RETURN_IF_ERROR(UpdateOrTransformParamInput(
axis_node_pos_, "DataFormatDimMap", DT_INT32));
return Status::OK();
return UpdateOrTransformParamInput(axis_node_pos_, "DataFormatDimMap",
DT_INT32);
}

int axis_node_pos_;
Expand Down Expand Up @@ -1629,40 +1620,14 @@ class SumProcessor : public AgnosticNodeProcessor {
int port;
ParseNodeName(node_->input(0), &port);
return !MustPreserve() && HasOutputs() && IsNodeAfterNCHWToNHWC() &&
IsPortDimsFour(*input0, port) && IsAlongDimNHW() && IsOnGPU();
IsPortDimsFour(*input0, port) && IsOnGPU();
}

Status AddLayoutTransposeToOutputs() override { return Status::OK(); }

Status CustomizedProcessing() override {
return UpdateAttrValueOfInput(1, false);
}

private:
bool IsAlongDimNHW() const {
NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
if (!IsConstant(*reduction_indices)) {
return false;
}
Tensor tensor;
if (reduction_indices->attr().find({"value"}) ==
reduction_indices->attr().end()) {
return false;
}
auto success =
tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
if (!success) {
LOG(ERROR) << "Failed to parse TensorProto.";
return false;
}
if (tensor.flat<int>().size() != 3) {
return false;
}
if (tensor.flat<int>()(0) == 0 && tensor.flat<int>()(1) == 1 &&
tensor.flat<int>()(2) == 2) {
return true;
}
return false;
DataType dtype = node_->attr().at("Tidx").type();
return UpdateOrTransformParamInput(1, "DataFormatDimMap", dtype);
}
};

Expand Down

0 comments on commit bb45399

Please sign in to comment.