Skip to content

Commit

Permalink
!601 code check: tiling_strategy_manager_npu
Browse files Browse the repository at this point in the history
Merge pull request !601 from yangsijia/code-check3
  • Loading branch information
it-is-a-robot authored and gitee-org committed Mar 17, 2022
2 parents e3fa523 + e031624 commit 1f0dad0
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 167 deletions.
10 changes: 5 additions & 5 deletions src/poly/tiling/schtree_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -262,13 +262,13 @@ bool ScheduleTreeAnalyzer::GetPosShiftedTileRange(const std::string &vname, cons
std::string post = sp_add[1];
if (pre.empty() || post.empty()) return false;
if (pre == actual_name && post != actual_name) {
auto add_range = static_cast<int>(std::strtol(post.c_str(), nullptr, 10));
auto add_range = StrToDecimalInt(post);
ranges.first += add_range;
ranges.second += add_range;
old_ranges = ranges;
return true;
} else if (post == actual_name && pre != actual_name) {
auto add_range = static_cast<int>(std::strtol(pre.c_str(), nullptr, 10));
auto add_range = StrToDecimalInt(pre);
ranges.first += add_range;
ranges.second += add_range;
old_ranges = ranges;
Expand All @@ -286,13 +286,13 @@ bool ScheduleTreeAnalyzer::GetNegShiftedTileRange(const std::string &vname, cons
std::string post = sp_sub[1];
if (pre.empty() || post.empty()) return false;
if (pre == actual_name && post != actual_name) {
auto sub_range = static_cast<int>(std::strtol(post.c_str(), nullptr, 10));
auto sub_range = StrToDecimalInt(post);
ranges.first -= sub_range;
ranges.second -= sub_range;
old_ranges = ranges;
return true;
} else if (post == actual_name && pre != actual_name) {
auto sub_range = static_cast<int>(std::strtol(pre.c_str(), nullptr, 10));
auto sub_range = StrToDecimalInt(pre);
std::pair<int, int> res;
res.second = sub_range - ranges.first;
res.first = sub_range - ranges.second;
Expand Down Expand Up @@ -564,7 +564,7 @@ int ScheduleTreeAnalyzer::GetLayerIndex(const std::string &var_name) {
layer_s += i;
}
}
return layer_s.empty() ? -1 : static_cast<int>(std::strtol(layer_s.c_str(), nullptr, 10));
return layer_s.empty() ? -1 : StrToDecimalInt(layer_s);
}

bool ScheduleTreeAnalyzer::MatchNodeWithDynamicLoop(std::unordered_set<const For *> &matched, TileNode &node,
Expand Down
2 changes: 1 addition & 1 deletion src/poly/tiling/space_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ void SpaceAnalyzer::MarkDmaAlign(const TensorEntry &dst_tensor, std::vector<Tens
// B((cc0 + 126794), (cc1 + 12), (cc2 + 1), 0) = input_1(cc0, cc1, cc2, 0)
// Or B((cc0 + 126794), (cc1 + 12), (cc2 + 1), 7) = input_1(cc0, cc1, cc2, 0)
VarNames last_names = dst_tensor.var_names.back();
if (last_names.size() == 1U && !last_names[0].empty() && StrToInt64(last_names[0]) < ALIGN_BYTES) {
if (last_names.size() == 1U && !last_names[0].empty() && StrToDecimalInt64(last_names[0]) < ALIGN_BYTES) {
analyzer_->RootAxis()->MarkWithAttr(AttrInfo{AT_TRANSFORM, dst_tensor.name});
}
}
Expand Down
8 changes: 5 additions & 3 deletions src/poly/tiling/tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -312,19 +312,21 @@ void TilingGenerator::ConvertShiftBoundToDims() {
if (!bound_value.empty()) {
CHECK_EQ(bound_value.size(), 1U);
CHECK_NE(bound_value[0], "");
auto bound = static_cast<int>(std::strtol(bound_value[0].c_str(), nullptr, 10));
auto bound = StrToDecimalInt(bound_value[0]);
DimensionInfo bound_info = ConvertDefaultInfo(axis);
bound_info.c1_tiling_size = bound;
bound_info.c1_var = axis->range_extent;
for (const auto &d : this->dims_) {
if (d.dim_seq != bound_info.dim_seq) continue;
if (d.dim_seq != bound_info.dim_seq) {
continue;
}
bound_info.c0_tiling_size = d.c1_tiling_size;
bound_info.c0_var = d.c1_var;
}
std::vector<std::string> shift_value = axis->GetAttrValue(AT_DYNAMIC_SHIFT);
CHECK_EQ(shift_value.size(), 1U) << "Empty shift_time for dynamic bound " << bound;
CHECK_NE(shift_value[0], "");
auto shift = static_cast<int>(std::strtol(shift_value[0].c_str(), nullptr, 10));
auto shift = StrToDecimalInt(shift_value[0]);
bound_info.pragma = shift;
CHECK_NE(bound_info.c0_tiling_size, -1);
this->dims_.push_back(bound_info);
Expand Down
2 changes: 1 addition & 1 deletion src/poly/tiling/tiling_analyzer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1084,7 +1084,7 @@ class LinearAccessPatternBuilder : public IRVisitor {
std::vector<std::string> info = akg::common::Split(attr.attr_value, "->");
CHECK_EQ(info.size(), 2U);
std::string buffer = info[0];
auto times = static_cast<int>(std::strtol(info[1].c_str(), nullptr, 10));
auto times = StrToDecimalInt(info[1]);
expanded_buf_[buffer] = times;
}
}
Expand Down
10 changes: 3 additions & 7 deletions src/poly/tiling/tiling_analyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ inline Expr CastToExpr(const std::string &value) {
return Expr(Var(value));
}
}
return Expr(static_cast<int>(std::strtol(value.c_str(), nullptr, 10)));
return Expr(StrToDecimalInt(value));
}

inline Expr CastInt64ToExpr(const int64_t value) { return air::ir::IntImm::make(Int(32), value); }
Expand Down Expand Up @@ -422,12 +422,8 @@ class TileCandidate {
void UpdateFixTileAxis(TileLevel level);

std::vector<TileAxis *> GetTileAxis() { return this->tile_axis_; }
void ResetTileAxis() {
this->tile_axis_.clear();
}
void ResetTileVal() {
this->tile_val_.clear();
}
void ResetTileAxis() { this->tile_axis_.clear(); }
void ResetTileVal() { this->tile_val_.clear(); }
void UpdateConstTile(const TileAxis *a, int64_t c1_val, int64_t c0_val = -1);
void UpdateC1Tile(const TileAxis *a, const Expr &c1_val);
void UpdateC0Tile(const TileAxis *a, const Expr &c0_val);
Expand Down
6 changes: 3 additions & 3 deletions src/poly/tiling/tiling_solver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ void GpuSolver::TotSpeedup() {
if (shared_memory_limit < total_middle_tensor_size / outermost_tile) return;
ss << "total middle tensor size: " << total_middle_tensor_size
<< ", exceeds shared_memory_limit: " << shared_memory_limit;

auto current_elem_num = 1;
for (unsigned int i = 0; i < tile_axes.size(); ++i) {
if (i != OUTERMOST_AXIS) {
Expand Down Expand Up @@ -661,7 +661,7 @@ Expr InequalitySolver::SolveByInferBound(const Array<Expr> &cons_on_var, const V
if (axis->HasAttr("DYN_SHAPE_LIMIT")) {
auto res = axis->GetAttrValue("DYN_SHAPE_LIMIT");
CHECK_EQ(res.size(), 1U);
auto range_limit = static_cast<int>(std::strtol(res[0].c_str(), nullptr, 10));
auto range_limit = StrToDecimalInt(res[0]);
new_constraints.push_back(axis->range_extent <= CastIntToExpr(range_limit));
}
});
Expand Down Expand Up @@ -833,7 +833,7 @@ Expr InequalitySolver::DetermineTileForDynamic(TileAxis *axis, const Expr &mem_c
if (axis->HasAttr("DYN_SHAPE_LIMIT")) {
auto shape_limit = axis->GetAttrValue("DYN_SHAPE_LIMIT");
CHECK_EQ(shape_limit.size(), 1U);
auto range_limit = static_cast<int>(std::strtol(shape_limit[0].c_str(), nullptr, 10));
auto range_limit = StrToDecimalInt(shape_limit[0]);
if (analyzer_.arith_ana_.CanProve(range_limit <= GetConstIntUpBound(max_final_factor))) {
final_factor = axis->range_extent;
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/poly/tiling/tiling_solver.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ class GpuSolver : TilingSolver {
if (s.empty()) {
continue;
}
alloced_slot.emplace_back(static_cast<int>(std::strtol(s.c_str(), nullptr, 10)));
alloced_slot.emplace_back(StrToDecimalInt64(s));
}
}
for (size_t i = alloced_slot.size(); i < resource_limit.size(); ++i) {
Expand Down
90 changes: 78 additions & 12 deletions src/poly/tiling/tiling_strategy_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,6 @@ class TilingStrategy {
int64_t vector_size_ = 4;

const static int binary_factor_{2};
const static int decimal_factor_{10};
const static int double_warp_size_{64};
const static int quadruple_warp_size_{128};
};
Expand Down Expand Up @@ -123,7 +122,76 @@ class CustomTilingStrategy : public TilingStrategy {
explicit CustomTilingStrategy(const TilingAnalyzer *a) : TilingStrategy(a) { interested_attr_key = "CUSTOM"; }
void AddNpuConstraint() override;
void AddGpuConstraint() override;
void ApplyCustomConstraints(TileAxis *axis, std::string con, TileLevel lv);

private:
void ParseConstraintStr(const std::string &attr_key, const std::string &attr_value) {
std::vector<std::string> modes = akg::common::Split(attr_key, ":");
CHECK_EQ(modes.size(), 2U);
constraint_str_ = attr_value;
if (constraint_str_.find("->") != std::string::npos) {
std::vector<std::string> res = akg::common::Split(constraint_str_, "->");
related_buf_ = res[0];
constraint_str_ = res[1];
}
}

void ParseLevel() {
constraints_ = akg::common::Split(constraint_str_, "_");
CHECK_GE(constraints_.size(), 1U);
std::vector<std::string> level = akg::common::Split(constraints_[0], ":");
CHECK(level.size() == 2U && level[0] == "LEVEL");
CHECK(level[1] == kDsaC1 || level[1] == kDsaC0);
lv_ = level[1] == kDsaC1 ? CACHE1 : CACHE0;
void(constraints_.erase(constraints_.cbegin()));
}

void ApplyEachCustomConstraint(TileAxis *axis, const std::string &con) {
std::vector<std::string> items = akg::common::Split(con, ":");
CHECK_EQ(items.size(), 2U);
CHECK_NE(items[0], "");
CHECK_NE(items[1], "");
if (items[0] == "MIN") {
if (items[1] == "MIN") {
if (lv_ == CACHE1) {
axis->TileRestrainUpper(axis->c1_constraints.tile_min_, lv_);
} else if (lv_ == CACHE0) {
axis->TileRestrainUpper(axis->c0_constraints.tile_min_, lv_);
}
} else {
axis->TileRestrainLower(CastToExpr(items[1]), lv_);
}
} else if (items[0] == "FACTOR") {
axis->TileRestrainToSingleValue(CastToExpr(items[1]), lv_);
} else if (items[0] == "CANDIDATE") {
if (lv_ == CACHE1) {
axis->InsertC1CandFactor(CastToExpr(items[1]));
} else {
axis->InsertC0CandFactor(CastToExpr(items[1]));
}
} else if (items[0] == "MAX") {
if (items[1] == "FULL") {
axis->TileRestrainEntire(lv_);
} else {
axis->TileRestrainUpper(CastToExpr(items[1]), lv_);
}
} else if (items[0] == AT_MOD) {
axis->TileRestrainMod(CastToExpr(items[1]), lv_);
} else if (items[0] == "FORBIDISO") {
axis->forbid_iso = true;
} else if (items[0] == "PRIORITY") {
axis->priority = StrToDecimalInt(items[1]);
} else if (items[0] == "EXPANSION") {
std::string info = related_buf_ + "->" + items[1];
analyzer_->RootAxis()->MarkWithAttr(AttrInfo{"EXPANSION", info});
} else if (items[0] == "AXISINFO") {
axis->axis_type_ = items[1];
}
}

std::string constraint_str_;
std::string related_buf_;
TileLevel lv_{TileLevel::CACHE1};
std::vector<std::string> constraints_;
};

class ConflictTreeRangeStrategy : public TilingStrategy {
Expand Down Expand Up @@ -161,15 +229,13 @@ class CastStrategy : public TilingStrategy {
std::vector<std::string> src_info = akg::common::Split(src, ":");
CHECK_EQ(src_info.size(), 2U);
CHECK_NE(src_info[1], "");
axis->data_size[src_info[0]].emplace_back(
static_cast<int>(std::strtol(src_info[1].c_str(), nullptr, decimal_factor_)));
axis->data_size[src_info[0]].emplace_back(StrToDecimalInt(src_info[1]));
}

std::vector<std::string> dst_info = akg::common::Split(src_dst[1], ":");
CHECK_EQ(dst_info.size(), 2U);
CHECK_NE(dst_info[1], "");
axis->data_size[dst_info[0]].emplace_back(
static_cast<int>(std::strtol(dst_info[1].c_str(), nullptr, decimal_factor_)));
axis->data_size[dst_info[0]].emplace_back(StrToDecimalInt(dst_info[1]));
}
}
}
Expand Down Expand Up @@ -273,7 +339,7 @@ class ShiftAxisStrategy : public TilingStrategy {
shifted_axes_.insert(axis);
for (const auto &attr : it.second) {
CHECK_NE(attr.attr_value, "");
auto share_time = static_cast<int>(std::strtol(attr.attr_value.c_str(), nullptr, decimal_factor_));
auto share_time = StrToDecimalInt(attr.attr_value);
axis->TileRestrainToSingleValue(const_extent * (share_time + 1), CACHE1);
break;
}
Expand Down Expand Up @@ -316,8 +382,8 @@ class ConvStrategy : public TilingStrategy {
void SetFinalConfig(const MmaConv &macro_mma, const Mma &mma);

// Return a combination of total factor that can be divisible by shape_m and shape_n.
const std::pair<int64_t, int64_t> GetDivisibleFactorForMN(
int64_t shape_m, int64_t shape_n, int64_t total_factor, const Mma &mma);
const std::pair<int64_t, int64_t> GetDivisibleFactorForMN(int64_t shape_m, int64_t shape_n, int64_t total_factor,
const Mma &mma);

int w0_for_m_{1};
int w1_for_n_{1};
Expand Down Expand Up @@ -361,8 +427,8 @@ class GemmStrategy : public TilingStrategy {
int EstimateSharedSize(const Mma &alloc, int dtype);
int EstimateRegisterSize(const Mma &alloc, int dtype);
// Return a combination of total factor that can be divisible by shape_m and shape_n.
const std::pair<int64_t, int64_t> GetDivisibleFactorForMN(
int64_t shape_m, int64_t shape_n, int64_t total_factor, const Mma &mma);
const std::pair<int64_t, int64_t> GetDivisibleFactorForMN(int64_t shape_m, int64_t shape_n, int64_t total_factor,
const Mma &mma);

int w0_for_m_{1};
int w1_for_n_{1};
Expand Down Expand Up @@ -479,7 +545,7 @@ class GpuStrategy : public TilingStrategy {
int possible_threads_;
int coalesced_size_;
int total_injective_size_;
int64_t total_vectorized_bytes_ = 16; // The default total number of bytes for vectorization is 16.
int64_t total_vectorized_bytes_ = 16; // The default total number of bytes for vectorization is 16.
};

class CpuStrategy : public TilingStrategy {
Expand Down
Loading

0 comments on commit 1f0dad0

Please sign in to comment.