Skip to content

Commit

Permalink
[CINN] Dump example input tensor meta (PaddlePaddle#64689)
Browse files Browse the repository at this point in the history
* fix bugs about dumping pir block kwargs into py_code

* fix typo: self.t_null -> self.t_null()

* dump example input tensor meta to FLAGS_logging_pir_py_code_dir

* dump types of block args/kwargs into pir py code

* Update paddle/fluid/inference/api/analysis_predictor.cc

Co-authored-by: Yuanle Liu <[email protected]>

* Update paddle/fluid/inference/api/analysis_predictor.cc

Co-authored-by: Yuanle Liu <[email protected]>

* format code

* rename Program::random_logging_id to Program::id

* Program::id_ should not be cloned

* fix CI/CE complaints

---------

Co-authored-by: jiahy0825 <[email protected]>
Co-authored-by: Yuanle Liu <[email protected]>
  • Loading branch information
3 people authored Jun 4, 2024
1 parent e736b24 commit 452612c
Show file tree
Hide file tree
Showing 9 changed files with 302 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -866,27 +866,52 @@ struct PirToPyCodeConverterHelper {
}

std::string ConvertInputTypes(const pir::Operation* op) {
std::stringstream ss;
ss << "[";
for (int i = 0; i < op->num_operands(); ++i) {
if (i > 0) {
ss << ", ";
const auto& VisitValue = [&](const auto& DoEachValue) {
for (int i = 0; i < op->num_operands(); ++i) {
DoEachValue(op->operand_source(i));
}
ss << ConvertType(op->operand_source(i).type());
}
ss << "]";
return ss.str();
};
return ConvertValueTypes(VisitValue);
}

std::string ConvertBlockArgTypes(const pir::Block& block) {
const auto& VisitValue = [&](const auto& DoEachValue) {
for (const auto& arg : block.args()) {
DoEachValue(arg);
}
};
return ConvertValueTypes(VisitValue);
}

std::string ConvertBlockKwArgTypes(const pir::Block& block) {
const auto& VisitValue = [&](const auto& DoEachValue) {
for (const auto& [_, arg] : block.kwargs()) {
DoEachValue(arg);
}
};
return ConvertValueTypes(VisitValue);
}

std::string ConvertOutputTypes(const pir::Operation* op) {
const auto& VisitValue = [&](const auto& DoEachValue) {
for (int i = 0; i < op->num_results(); ++i) {
DoEachValue(op->result(i));
}
};
return ConvertValueTypes(VisitValue);
}

template <typename VisitValueT>
std::string ConvertValueTypes(const VisitValueT& VisitValue) {
std::stringstream ss;
ss << "[";
for (int i = 0; i < op->num_results(); ++i) {
if (i > 0) {
int i = 0;
VisitValue([&](pir::Value value) {
if (i++ > 0) {
ss << ", ";
}
ss << ConvertType(op->result(i).type());
}
ss << ConvertType(value.type());
});
ss << "]";
return ss.str();
}
Expand Down Expand Up @@ -1098,7 +1123,45 @@ struct PirToPyCodeConverterHelper {
}
ss << "]";
}
ss << "]";
ss << "], ";
}
{
int i = 0;
ss << "block_positional_arg_types=[";
for (const auto& region : *op) {
if (i++ > 0) {
ss << ",";
}
int j = 0;
ss << "[";
for (const auto& block : region) {
if (j++ > 0) {
ss << ",";
}
ss << ConvertBlockArgTypes(block);
}
ss << "]";
}
ss << "], ";
}
{
int i = 0;
ss << "block_keyword_arg_types=[";
for (const auto& region : *op) {
if (i++ > 0) {
ss << ",";
}
int j = 0;
ss << "[";
for (const auto& block : region) {
if (j++ > 0) {
ss << ",";
}
ss << ConvertBlockKwArgTypes(block);
}
ss << "]";
}
ss << "], ";
}
return ss.str();
}
Expand Down Expand Up @@ -1138,18 +1201,10 @@ struct PirToPyCodeConverterHelper {

std::string GetPyClassName() {
std::ostringstream ss;
ss << "PirProgram_" << RandomInt();
ss << "PirProgram_" << program_->id();
return ss.str();
}

int64_t RandomInt() {
std::random_device rd{};
std::mt19937_64 gen(rd());
std::uniform_int_distribution<int64_t> dis(
0, std::numeric_limits<int64_t>::max());
return dis(gen);
}

std::string ConvertIStringsToString(const IStrings& istrings) {
std::stringstream ss;
for (const auto& istring : istrings) {
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/eager/to_static/run_program_op_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "paddle/fluid/eager/grad_node_info.h"
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/fluid/framework/executor_cache.h"
#include "paddle/fluid/framework/feed_hook.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/tensor_ref_array.h"
#include "paddle/fluid/framework/variable_helper.h"
Expand Down Expand Up @@ -583,6 +584,7 @@ inline void PirRunProgramAPI(
//}
}

paddle::framework::RunFeedHooks(*forward_program, *global_inner_scope);
// interpretercore run
if (!forward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
Expand Down Expand Up @@ -869,7 +871,6 @@ inline void RunProgramGradAPI(
auto *backward_global_block = PADDLE_GET_CONST(
paddle::framework::BlockDesc *, attrs.at("backward_global_block"));
auto *backward_program = backward_global_block->Program();

details::Trans2ContiguousTensorsInplace(out_grad);

auto out_grad_names = details::GetTensorsName(out_grad);
Expand Down Expand Up @@ -1155,6 +1156,7 @@ inline void PirRunProgramGradAPI(
}
}

paddle::framework::RunFeedHooks(*backward_program, *global_inner_scope);
if (!backward_program->block()->empty()) {
paddle::platform::RecordEvent record_event(
"interpreter_core_run",
Expand Down
14 changes: 13 additions & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,12 @@ cc_library(
feed_fetch_method
SRCS feed_fetch_method.cc
DEPS lod_tensor scope glog)

cc_library(
feed_hook
SRCS feed_hook.cc
DEPS lod_tensor scope glog pir)

cc_library(
variable_helper
SRCS variable_helper.cc
Expand All @@ -529,6 +535,7 @@ set(NAIVE_EXECUTOR_DEPS
glog
lod_rank_table
feed_fetch_method
feed_hook
graph_to_program_pass
standalone_executor
variable_helper)
Expand Down Expand Up @@ -598,6 +605,7 @@ if(WITH_DISTRIBUTE)
lodtensor_printer
lod_rank_table
feed_fetch_method
feed_hook
collective_helper
${GLOB_DISTRIBUTE_DEPS}
graph_to_program_pass
Expand Down Expand Up @@ -628,7 +636,7 @@ if(WITH_DISTRIBUTE)
# pull_dense_worker.cc section_worker.cc heter_section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
# device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
# index_sampler index_wrapper sampler index_dataset_proto
# lod_rank_table framework_io fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method
# lod_rank_table framework_io fleet_wrapper heter_wrapper box_wrapper metrics lodtensor_printer feed_fetch_method feed_hook
# graph_to_program_pass variable_helper timer monitor
# heter_service_proto fleet heter_server brpc fleet_executor
# graph_gpu_wrapper)
Expand Down Expand Up @@ -677,6 +685,7 @@ if(WITH_DISTRIBUTE)
metrics
lodtensor_printer
feed_fetch_method
feed_hook
graph_to_program_pass
variable_helper
timer
Expand Down Expand Up @@ -750,6 +759,7 @@ if(WITH_DISTRIBUTE)
metrics
lodtensor_printer
feed_fetch_method
feed_hook
graph_to_program_pass
variable_helper
timer
Expand Down Expand Up @@ -808,6 +818,7 @@ elseif(WITH_PSLIB)
box_wrapper
lodtensor_printer
feed_fetch_method
feed_hook
graph_to_program_pass
variable_helper
timer
Expand Down Expand Up @@ -854,6 +865,7 @@ else()
box_wrapper
lodtensor_printer
feed_fetch_method
feed_hook
graph_to_program_pass
variable_helper
timer
Expand Down
130 changes: 130 additions & 0 deletions paddle/fluid/framework/feed_hook.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/feed_hook.h"

#include <fstream>
#include <iomanip>
#include <mutex>
#include <optional>
#include <sstream>
#include <string>
#include <unordered_map>

#include "paddle/common/flags.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/pir/include/core/program.h"

COMMON_DECLARE_string(logging_pir_py_code_dir);
COMMON_DECLARE_bool(logging_trunc_pir_py_code);

namespace paddle::framework {

namespace {

// Returns the full path of the logging file, or std::nullopt when
// FLAGS_logging_pir_py_code_dir is unset (i.e. logging is disabled).
std::optional<std::string> GetLoggingFilePath() {
  const std::string& dir = FLAGS_logging_pir_py_code_dir;
  if (dir.empty()) {
    return std::nullopt;
  }
  return dir + "/programs_example_input_tensor_meta.py";
}

// Truncates the logging file at most once per process, and only when
// FLAGS_logging_trunc_pir_py_code is enabled; otherwise a no-op.
void TryTruncateLoggingFile() {
  if (!FLAGS_logging_trunc_pir_py_code) return;
  const auto& path = GetLoggingFilePath();
  if (!path.has_value()) return;
  static std::once_flag truncated;
  std::call_once(truncated, [&] {
    // Opening with std::ios::trunc empties the file; the stream closes
    // automatically when it goes out of scope.
    std::ofstream ofs(path->c_str(), std::ios::out | std::ios::trunc);
  });
}

// Invokes DoEachFeedName once for every feed name of `program`:
// first the "name" attribute of each top-level pd_op.data op, then every
// keyword-argument name of the module's top-level block.
// NOTE: fixes the original "Fead" typo in the (internal-only) template
// parameter and callback name; call sites are positional, so this is safe.
template <typename DoEachFeedNameT>
void VisitFeedName(const pir::Program& program,
                   const DoEachFeedNameT& DoEachFeedName) {
  auto module_op = program.module_op();
  const auto& block = module_op.block();
  const auto& IsDataOp = [](const pir::Operation& op) -> bool {
    return op.isa<paddle::dialect::DataOp>();
  };
  const auto& GetDataOpName = [](const pir::Operation& op) -> std::string {
    return op.attributes().at("name").dyn_cast<pir::StrAttribute>().AsString();
  };
  for (const auto& op : block) {
    if (IsDataOp(op)) {
      DoEachFeedName(GetDataOpName(op));
    }
  }
  for (const auto& [name, _] : block.kwargs()) {
    DoEachFeedName(name);
  }
}

// Renders a small python "class" literal describing one example input
// tensor: the owning program id, the input name, and the tensor's shape.
// The emitted text is appended to programs_example_input_tensor_meta.py.
std::string GetLoggingShapeOrDataForName(int64_t program_id,
                                         const std::string& name,
                                         const phi::DenseTensor& tensor) {
  std::ostringstream ss;
  ss << "class PirProgram_example_input_tensor_meta_" << program_id << ":";
  ss << "\n\tprogram_id = " << program_id;
  ss << "\n\tinput_name = " << std::quoted(name);
  ss << "\n\tshape = [";
  int i = 0;
  // Iterate as int64_t: dims are materialized as int64_t, so binding them
  // to `int` (as the original did) would silently narrow large dims.
  for (int64_t dim : ::common::vectorize<int64_t>(tensor.dims())) {
    if (i++ > 0) {
      ss << ", ";
    }
    ss << dim;
  }
  ss << "]";
  ss << "\n\n";
  return ss.str();
}

// Appends `logging_str` plus a trailing newline to the logging file.
// Silently does nothing when logging is disabled or the file cannot be
// opened (best-effort logging must never fail the run).
void AppendToLoggingFile(const std::string& logging_str) {
  const auto& path = GetLoggingFilePath();
  if (!path.has_value()) return;
  std::ofstream ofs(path->c_str(), std::ios::out | std::ios::app);
  if (!ofs.is_open()) return;
  // The stream flushes and closes in its destructor.
  ofs << logging_str << std::endl;
}

// Logs the tensor meta for the pair (uid, name) at most once per process.
// A static map of std::once_flag keyed by program id then input name
// deduplicates repeated executions of the same program; the static mutex
// serializes both the map mutation (operator[] may insert) and the file
// append performed inside call_once.
void AppendLoggingShapeOrDataForName(int64_t uid,
                                     const std::string& name,
                                     const phi::DenseTensor& tensor) {
  static std::mutex mutex;
  std::unique_lock<std::mutex> lock(mutex);
  // once_flag is neither copyable nor movable; operator[] default-constructs
  // it in place inside the node, which is the only valid way to store it.
  using Name2OnceFlag = std::unordered_map<std::string, std::once_flag>;
  static std::unordered_map<int64_t, Name2OnceFlag> once_flags;
  std::call_once(once_flags[uid][name], [&] {
    AppendToLoggingFile(GetLoggingShapeOrDataForName(uid, name, tensor));
  });
}

// Dumps the shape meta of every feed tensor of `program` that is present in
// `scope`. No-op unless FLAGS_logging_pir_py_code_dir is set.
void SaveLoggingShapeOrData(const pir::Program& program, const Scope& scope) {
  if (FLAGS_logging_pir_py_code_dir.empty()) return;
  TryTruncateLoggingFile();
  VisitFeedName(program, [&](const std::string& name) {
    auto* var = scope.FindVar(name);
    // Skip feeds that are absent from the scope or are not dense tensors.
    if (var == nullptr || !var->IsType<phi::DenseTensor>()) return;
    AppendLoggingShapeOrDataForName(
        program.id(), name, var->Get<phi::DenseTensor>());
  });
}

} // namespace

// Public entry point, called right before interpreter execution of a PIR
// program. Currently the only hook dumps example input tensor metas to
// FLAGS_logging_pir_py_code_dir (see SaveLoggingShapeOrData above).
void RunFeedHooks(const pir::Program& program, const Scope& scope) {
  SaveLoggingShapeOrData(program, scope);
}

} // namespace paddle::framework
29 changes: 29 additions & 0 deletions paddle/fluid/framework/feed_hook.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

// Forward declaration only: this header is included widely, so it avoids
// pulling in the full pir::Program definition
// (paddle/pir/include/core/program.h).
namespace pir {

class Program;

}  // namespace pir

namespace paddle::framework {

class Scope;

// Runs the feed hooks for `program` against `scope` before the interpreter
// executes the program (e.g. dumping example input tensor meta to the
// directory named by FLAGS_logging_pir_py_code_dir).
void RunFeedHooks(const pir::Program& program, const Scope& scope);

}  // namespace paddle::framework
Loading

0 comments on commit 452612c

Please sign in to comment.