[CINN] Support passing value symbol info from forward to backward program (PaddlePaddle#65088)

* bind shape_analysis and shape_or_data to python level

* 1. delete print_shape_or_data 2. modify get_shape_constraint_ir_analysis as a lambda func

* delete using in pir.cc

* delete note

* support registering symbol info for shape_analysis

* support passing value symbol info from forward to backward program

* refine code

* refine print log

---------

Co-authored-by: gongshaotian <[email protected]>
zyfncg and gongshaotian authored Jun 14, 2024
1 parent d8c997d commit 4afbfa4
Showing 6 changed files with 153 additions and 51 deletions.
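
At a high level, this commit lets the backward program's shape analysis inherit the symbol constraints already inferred for the forward program, then shares per-value symbolic shapes explicitly. A minimal Python-level sketch using the pybind names added below (the programs and values are assumed to come from dy2static; this is an illustration, not part of the diff):

import paddle

def share_symbol_info(forward_program, backward_program,
                      forward_value, backward_value):
    # Sketch only: the programs and values are assumed inputs.
    pir = paddle.base.libpaddle.pir
    fwd_analysis = pir.get_shape_constraint_ir_analysis(forward_program)
    bwd_analysis = pir.get_shape_constraint_ir_analysis(backward_program)

    # New binding in this commit: copy the forward symbol constraints
    # (equal / broadcastable / greater-than-one) into the backward analysis
    # and start its fresh-symbol counter after the forward one.
    bwd_analysis.register_symbol_cstr_from_shape_analysis(fwd_analysis)

    # Per-value symbolic shapes are then shared value by value.
    bwd_analysis.set_shape_or_data_for_var(
        backward_value,
        fwd_analysis.get_shape_or_data_for_var(forward_value),
    )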
7 changes: 7 additions & 0 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
@@ -54,6 +54,7 @@
#include "paddle/fluid/pir/transforms/general/dead_code_elimination_pass.h"

COMMON_DECLARE_bool(print_ir);
COMMON_DECLARE_bool(pir_debug);
COMMON_DECLARE_bool(disable_dyshape_in_train);
COMMON_DECLARE_bool(enable_cinn_accuracy_check);
COMMON_DECLARE_bool(enable_fuse_parallel_matmul_pass);
@@ -246,6 +247,12 @@ void ApplyCinnPass(::pir::Program* program,
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
if (FLAGS_pir_debug) {
auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program);
std::cout << "Program before lowering: \n"
<< pir::CustomPrintHelper(*program, shape_analysis.PrintHook())
<< std::endl;
}
ApplyCinnLowerPass(program, CreatePassManager);
}

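The new FLAGS_pir_debug branch prints the program before lowering with symbolic shape annotations attached through shape_analysis.PrintHook(). Paddle boolean flags of this kind are normally toggled via the FLAGS_pir_debug=1 environment variable or paddle.set_flags; treat the exact toggle mechanism as an assumption here, since it is outside this diff.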
34 changes: 22 additions & 12 deletions paddle/fluid/pybind/pir.cc
@@ -261,15 +261,22 @@ void SetValueName(Value value, const std::string name) {
}
}

bool IsUsedByShadowOutput(const pir::Value &value) {
for (auto iter = value.use_begin(); iter != value.use_end(); ++iter) {
if (iter->owner()->isa<::pir::ShadowOutputOp>()) {
return true;
}
}
return false;
}

bool HasValueName(const Value &value) {
if (IsFakeValue(value)) {
return false;
}
if (value.defining_op()->isa<::pir::ParameterOp>() ||
value.defining_op()->isa<paddle::dialect::DataOp>() ||
value.isa<BlockArgument>() ||
(value.first_use() &&
(value.first_use().owner()->isa<::pir::ShadowOutputOp>()))) {
value.isa<BlockArgument>() || IsUsedByShadowOutput(value)) {
return true;
} else {
return false;
@@ -287,15 +294,15 @@ std::string GetValueName(Value value) {
} else {
return "arg_" + std::to_string(block_arg.index());
}
} else if (value.first_use()) {
auto nextOp = value.first_use().owner();
if (nextOp->isa<::pir::ShadowOutputOp>()) {
return nextOp->attribute<StrAttribute>("output_name").AsString();
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get name of Value which is "
"shadowoutput "));
} else if (IsUsedByShadowOutput(value)) {
for (auto iter = value.use_begin(); iter != value.use_end(); ++iter) {
if (iter->owner()->isa<::pir::ShadowOutputOp>()) {
return iter->owner()->attribute<StrAttribute>("output_name").AsString();
}
}
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get name of Value which is "
"shadowoutput "));
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get name of Value that "
@@ -2575,7 +2582,10 @@ void BindShapeConstraintIRAnalysis(pybind11::module *m) {
&pir::ShapeConstraintIRAnalysis::GetShapeOrDataForValue,
return_value_policy::reference)
.def("set_shape_or_data_for_var",
&pir::ShapeConstraintIRAnalysis::SetShapeOrDataForValue);
&pir::ShapeConstraintIRAnalysis::SetShapeOrDataForValue)
.def("register_symbol_cstr_from_shape_analysis",
&pir::ShapeConstraintIRAnalysis::
RegisterSymbolConstraintFromShapeAnalysis);
}

void BindPir(pybind11::module *module) {
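Behavioral note on the refactor above: HasValueName and GetValueName previously inspected only value.first_use(), so a value whose first consumer was an ordinary op failed the name lookup even if a later use fed a pir::ShadowOutputOp. IsUsedByShadowOutput walks every use, so such values now resolve to the ShadowOutputOp's output_name instead of throwing.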
8 changes: 8 additions & 0 deletions paddle/pir/include/dialect/shape/utils/shape_analysis.h
@@ -36,6 +36,10 @@ class IR_API InferSymbolicShapeContext {
InferSymbolicShapeContext(InferSymbolicShapeContext&&) = delete;
void Init();

// Note: Only initializes the symbol info; the value info is not updated.
void RegisterSymbolConstraintFromContext(
const InferSymbolicShapeContext& other);

const std::string GetNextSymName();

bool HasShapeOrDataForValue(Value val) const;
@@ -74,6 +78,7 @@ class IR_API InferSymbolicShapeContext {
void SubstituteDimExpr(const symbol::DimExpr& origin,
const symbol::DimExpr& substituted);

int64_t sym_idx_begin_ = 0;
int64_t next_sym_idx_ = 0;

std::unordered_map<uint64_t, symbol::ShapeOrDataDimExprs>
@@ -94,6 +99,9 @@ class IR_API ShapeConstraintIRAnalysis final
ShapeConstraintIRAnalysis(ShapeConstraintIRAnalysis&&) = delete;
void Init();

void RegisterSymbolConstraintFromShapeAnalysis(
const ShapeConstraintIRAnalysis& other);

const std::string GetNextSymName();

const symbol::ShapeOrDataDimExprs& GetShapeOrDataForValue(Value val);
57 changes: 22 additions & 35 deletions paddle/pir/src/dialect/shape/transforms/shape_optimization_pass.cc
@@ -134,33 +134,7 @@ void DebugPrintOpInfo(pir::Operation* op,
<< "ShapeOrData: {";

if (infer_context != nullptr) {
auto shape_data = infer_context->GetShapeOrDataForValue(res);
if (shape_data.isa<symbol::TensorListShapeOrDataDimExprs>()) continue;
print_stream << "shape: [";

for (size_t i = 0; i < shape_data.shape().size(); ++i) {
if (i != shape_data.shape().size() - 1) {
print_stream << symbol::ToString(shape_data.shape()[i]) << ",";
} else {
print_stream << symbol::ToString(shape_data.shape()[i]);
}
}

print_stream << "], data: [";
if (shape_data.data().has_value()) {
for (size_t i = 0; i < shape_data.data().value().size(); ++i) {
if (i != shape_data.data().value().size() - 1) {
print_stream << symbol::ToString(shape_data.data().value()[i])
<< ",";
} else {
print_stream << symbol::ToString(shape_data.data().value()[i]);
}
}
} else {
print_stream << "nullopt";
}

print_stream << "]";
print_stream << infer_context->GetShapeOrDataForValue(res);
}
print_stream << " }\n";
}
@@ -318,16 +292,29 @@ void InferSymExprForBlock(const Block& block,
void InferSymExprForAllValues(ModuleOp module_op) {
ShapeConstraintIRAnalysis& shape_analysis =
ShapeAnalysisManager::Instance().Get(module_op.program());
auto* infer_context = shape_analysis.MutInferSymbolicShapeContext();

// Hold the kwargs' symbol shape info so it is not cleared by the Init() call below.
const std::unordered_map<pir::Value, symbol::ShapeOrDataDimExprs>
symbol_shape_map = [&] {
std::unordered_map<pir::Value, symbol::ShapeOrDataDimExprs>
symbol_shape_map;
for (const auto& [_, value] : module_op.block().kwargs()) {
if (infer_context->HasShapeOrDataForValue(value)) {
symbol_shape_map.emplace(
value, infer_context->GetShapeOrDataForValue(value));
}
}
return symbol_shape_map;
}();

shape_analysis.Init();
auto infer_context = shape_analysis.MutInferSymbolicShapeContext();
for (uint32_t i = 0; i < module_op->num_regions(); i++) {
for (auto& block : module_op->region(i)) {
for (auto& [_, value] : block.kwargs()) {
infer_context->SetSymbolForValueByStaticShape(value);
}
InferSymExprForBlock(block, infer_context);
}
// Restore the kwarg symbol shape info saved above.
for (const auto& kv : symbol_shape_map) {
infer_context->SetShapeOrDataForValue(kv.first, kv.second);
}

InferSymExprForBlock(module_op.block(), infer_context);
}

std::unique_ptr<Pass> CreateShapeOptimizationPass() {
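Two things change in this file. DebugPrintOpInfo now streams the symbol::ShapeOrDataDimExprs directly, relying on its operator<< to produce the bracketed shape/data form the deleted loop built by hand. InferSymExprForAllValues follows a save-and-restore pattern: shape_analysis.Init() clears the per-value map, so kwarg symbol info registered beforehand (e.g. copied from a forward program) is snapshotted first and written back after Init(). A toy Python sketch of that pattern (names are illustrative, not Paddle APIs):

def infer_sym_expr_for_all_values(shape_analysis, block):
    ctx = shape_analysis.context  # assumed accessor in this toy model

    # Snapshot kwarg shape info before Init() wipes the value map.
    saved = {
        value: ctx.get_shape_or_data(value)
        for _, value in block.kwargs()
        if ctx.has_shape_or_data(value)
    }

    shape_analysis.init()  # clears per-value info, keeps sym_idx_begin_

    # Write the kwarg info back, then infer over the block as before.
    for value, shape_or_data in saved.items():
        ctx.set_shape_or_data(value, shape_or_data)
    infer_sym_expr_for_block(block, ctx)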
40 changes: 39 additions & 1 deletion paddle/pir/src/dialect/shape/utils/shape_analysis.cc
@@ -33,13 +33,46 @@ static std::string GetValueId(Value val) {

void InferSymbolicShapeContext::Init() {
value_id_to_shape_or_data_.clear();
next_sym_idx_ = 0;
next_sym_idx_ = sym_idx_begin_;
constraints_manager_.SetEqualCallbackFunc(
[&](const symbol::DimExpr& lhs, const symbol::DimExpr& rhs) {
return SubstituteDimExpr(lhs, rhs);
});
}

void InferSymbolicShapeContext::RegisterSymbolConstraintFromContext(
const InferSymbolicShapeContext& other) {
PADDLE_ENFORCE_EQ(
next_sym_idx_,
0,
common::errors::PreconditionNotMet("next_sym_idx_ should be 0 when init "
"symbol constraint, but now get %d",
next_sym_idx_));
PADDLE_ENFORCE_EQ(value_id_to_shape_or_data_.size(),
0,
common::errors::PreconditionNotMet(
"value_id_to_shape_or_data_ should be empty when init "
"symbol constraint, but now get %d",
value_id_to_shape_or_data_.size()));
sym_idx_begin_ = other.next_sym_idx_;
next_sym_idx_ = sym_idx_begin_;
// init equal constraints
for (const auto& kv : other.constraints_manager_.equals().GetMap()) {
constraints_manager_.AddEqCstr(kv.first, kv.second);
}
// init broadcastable constraints
for (const auto& bc_item : other.constraints_manager_.broadcastables()) {
constraints_manager_.AddBroadcastableCstr(bc_item.data->lhs,
bc_item.data->rhs);
}
// init gtone constraints
for (const auto& gt_one : other.constraints_manager_.gtones()) {
constraints_manager_.AddGTOneCstr(gt_one);
}

substitution_pattern_ = other.substitution_pattern_;
}

const std::string InferSymbolicShapeContext::GetNextSymName() {
return "S" + std::to_string(next_sym_idx_++);
}
@@ -283,6 +316,11 @@ void InferSymbolicShapeContext::PrintShapeOrDatas() const {

void ShapeConstraintIRAnalysis::Init() { context_.Init(); }

void ShapeConstraintIRAnalysis::RegisterSymbolConstraintFromShapeAnalysis(
const ShapeConstraintIRAnalysis& other) {
context_.RegisterSymbolConstraintFromContext(other.context_);
}

const std::string ShapeConstraintIRAnalysis::GetNextSymName() {
return context_.GetNextSymName();
}
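The sym_idx_begin_ bookkeeping above exists so that fresh symbols minted by the receiving context never collide with symbol names referenced by the copied constraints: GetNextSymName() returns "S" + next_sym_idx_, and RegisterSymbolConstraintFromContext advances the starting index past the donor's. A self-contained toy model of just that counter handoff:

class Ctx:
    # Toy stand-in for InferSymbolicShapeContext's symbol naming.
    def __init__(self, sym_idx_begin=0):
        self.next_sym_idx = sym_idx_begin

    def get_next_sym_name(self):
        name = f"S{self.next_sym_idx}"
        self.next_sym_idx += 1
        return name

fwd = Ctx()
assert [fwd.get_next_sym_name() for _ in range(2)] == ["S0", "S1"]

# Backward context registered from forward: fresh names start after S1,
# so "S0"/"S1" in the copied constraints are never handed out again.
bwd = Ctx(sym_idx_begin=fwd.next_sym_idx)
assert bwd.get_next_sym_name() == "S2"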
58 changes: 55 additions & 3 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -303,7 +303,7 @@ def pass_fn(forward_program, backward_program):
# NOTE(dev): Add this line to trigger program_name_attr logic
program_name_attr = self.program_name_attr
self.forward_program, self.backward_program = pass_fn(
origin_fwd, origin_bwd
origin_fwd, origin_bwd, program_name_attr
)
prog_logger.log(
1,
@@ -736,7 +736,7 @@ def _get_scope(self, program_id=None, use_scope_cache=False):
def _create_program(self, is_infer_mode=False):
if is_infer_mode:

def pass_fn(forward_program, backward_program):
def pass_fn(forward_program, backward_program, program_name_attr):
# common pass
pm = paddle.base.libpaddle.pir.PassManager()
paddle.base.libpaddle.pir.infer_symbolic_shape_pass(
@@ -770,12 +770,64 @@ def pass_fn(forward_program, backward_program):
# Note: Only set grad type once after initializing train program. So we put it here.
self._set_grad_type(self._params, train_program)

def pass_fn(forward_program, backward_program):
def pass_fn(forward_program, backward_program, program_name_attr):
def init_backward_program_shape_analysis(
forward_program, backward_program
):
forward_shape_analysis = paddle.base.libpaddle.pir.get_shape_constraint_ir_analysis(
forward_program
)
backward_shape_analysis = paddle.base.libpaddle.pir.get_shape_constraint_ir_analysis(
backward_program
)
backward_shape_analysis.register_symbol_cstr_from_shape_analysis(
forward_shape_analysis
)
forward_name_value_map = {
item.name: item
for item in forward_program.list_vars()
if item.has_name
}

def share_symbol_shape_from_forward_to_backward(
forward_value, backward_value
):
backward_shape_analysis.set_shape_or_data_for_var(
backward_value,
forward_shape_analysis.get_shape_or_data_for_var(
forward_value
),
)

def get_kwargs_forward_matched_value(kw_name, kw_value):
if kw_name in program_name_attr['bo_g']:
idx = program_name_attr['bo_g'].index(kw_name)
return forward_name_value_map[
program_name_attr['fo'][idx]
]
elif kw_name in forward_name_value_map:
return forward_name_value_map[kw_name]
else:
raise Exception(f"kw_args: {kw_name} not found")

for [kw_name, kw_value] in (
backward_program.global_block().kwargs().items()
):
forward_matched_value = (
get_kwargs_forward_matched_value(kw_name, kw_value)
)
share_symbol_shape_from_forward_to_backward(
forward_matched_value, kw_value
)

if cse_is_enabled():
paddle.base.libpaddle.pir.apply_cse_pass(forward_program)
paddle.base.libpaddle.pir.apply_cse_pass(backward_program)
if cinn_is_enabled(self._build_strategy, self._backend):
paddle.base.libpaddle.pir.apply_cinn_pass(forward_program)
init_backward_program_shape_analysis(
forward_program, backward_program
)
paddle.base.libpaddle.pir.apply_cinn_pass(backward_program)
else:
paddle.base.libpaddle.pir.check_infer_symbolic_if_need(
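The kwarg matching in get_kwargs_forward_matched_value has two cases: a backward kwarg named in program_name_attr['bo_g'] (interpreted here as the gradient paired with a forward output; the exact naming contract is internal) maps to the forward output name at the same index in program_name_attr['fo'], and any other kwarg is looked up directly among the forward program's named values. A toy illustration with strings standing in for pir.Value objects (the names are made up; the name_attr layout follows the diff):

program_name_attr = {"bo_g": ["out_grad"], "fo": ["out"]}
forward_name_value_map = {"out": "<fwd out>", "w": "<fwd w>"}

def match(kw_name):
    if kw_name in program_name_attr["bo_g"]:
        idx = program_name_attr["bo_g"].index(kw_name)
        return forward_name_value_map[program_name_attr["fo"][idx]]
    if kw_name in forward_name_value_map:
        return forward_name_value_map[kw_name]
    raise Exception(f"kw_args: {kw_name} not found")

assert match("out_grad") == "<fwd out>"  # grad kwarg -> matching forward output
assert match("w") == "<fwd w>"           # plain name -> direct lookup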
