Skip to content

Commit

Permalink
Use upstream MLIR's getMappingId method
Browse files Browse the repository at this point in the history
  • Loading branch information
grypp committed Nov 22, 2022
1 parent 8941a06 commit 404e306
Showing 1 changed file with 6 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -599,16 +599,13 @@ static LogicalResult lowerWorkgroupCountComputingRegion(
}
workgroupCount.resize(3, rewriter.getIndexAttr(1));
permutedWorkgroupCount.resize(3, rewriter.getIndexAttr(1));
int dimId = 0;
int mappingId = 0;
for (auto map : mapping->getValue()) {
auto id = map.cast<mlir::gpu::GPUBlockMappingAttr>().getBlock();
if (id == mlir::gpu::Blocks::DimX)
permutedWorkgroupCount[0] = workgroupCount[dimId];
if (id == mlir::gpu::Blocks::DimY)
permutedWorkgroupCount[1] = workgroupCount[dimId];
if (id == mlir::gpu::Blocks::DimZ)
permutedWorkgroupCount[2] = workgroupCount[dimId];
dimId++;
int64_t dimId = map.cast<DeviceMappingAttrInterface>().getMappingId();
permutedWorkgroupCount[dimId] = workgroupCount[mappingId];
permutedWorkgroupCount[dimId] = workgroupCount[mappingId];
permutedWorkgroupCount[dimId] = workgroupCount[mappingId];
mappingId++;
}
rewriter.replaceOp(
workgroupCountOp,
Expand Down

0 comments on commit 404e306

Please sign in to comment.