diff --git a/src/enzyme_ad/jax/Passes/AffineCFG.cpp b/src/enzyme_ad/jax/Passes/AffineCFG.cpp index b131b1471..2f6c27f97 100644 --- a/src/enzyme_ad/jax/Passes/AffineCFG.cpp +++ b/src/enzyme_ad/jax/Passes/AffineCFG.cpp @@ -2766,6 +2766,38 @@ optimizeExprFloorDiv(llvm::ArrayRef dims, AffineExpr lhs, return mlir::getAffineConstantExpr(0, lhs.getContext()); } + if (auto add = dyn_cast(lhs)) { + if (add.getKind() == AffineExprKind::Add) { + for (int i = 0; i < 2; i++) { + auto lhs = i == 0 ? add.getLHS() : add.getRHS(); + auto rhs = i == 0 ? add.getRHS() : add.getLHS(); + auto lhse = dyn_cast(lhs); + if (!lhse) + continue; + auto rhse = dyn_cast(rhs); + if (!rhse) + continue; + if (rhse.getKind() != AffineExprKind::Mul) + continue; + auto mulconst = dyn_cast(rhse.getRHS()); + if (!mulconst) + continue; + auto dim = dims[lhse.getPosition()]; + if (!dim.known) + continue; + + if (dim.step < 0) + continue; + if (dim.lb != 0) + continue; + if (dim.ub != mulconst.getValue()) + continue; + if (constRhs.getValue() % mulconst.getValue() == 0) + return rhse.getLHS().floorDiv(constRhs.floorDiv(mulconst)); + } + } + } + return std::nullopt; } @@ -2802,43 +2834,44 @@ optimizeExprMod(llvm::ArrayRef dims, AffineExpr lhs, return std::nullopt; } -static std::optional -optimizeExprWithBounds(AffineExpr expr, - llvm::ArrayRef dims) { +AffineExpr optimizeExprWithBounds(AffineExpr expr, + llvm::ArrayRef dims) { std::optional replacement; auto binExpr = dyn_cast(expr); if (!binExpr) - return std::nullopt; + return expr; - AffineExpr lhs = binExpr.getLHS(), rhs = binExpr.getRHS(); + AffineExpr lhs = optimizeExprWithBounds(binExpr.getLHS(), dims); + AffineExpr rhs = optimizeExprWithBounds(binExpr.getRHS(), dims); switch (expr.getKind()) { + case AffineExprKind::Add: + return lhs + rhs; + case AffineExprKind::Mul: + return lhs * rhs; case AffineExprKind::Mod: - replacement = optimizeExprMod(dims, lhs, rhs); - break; + if (auto replacement = optimizeExprMod(dims, lhs, rhs)) + return *replacement; + else + return lhs % rhs; case AffineExprKind::FloorDiv: - replacement = optimizeExprFloorDiv(dims, lhs, rhs); - break; - default: - break; + if (auto replacement = optimizeExprFloorDiv(dims, lhs, rhs)) + return *replacement; + else + return lhs.floorDiv(rhs); } - return replacement; + return expr; } static AffineMap optimizeMap(AffineMap map, llvm::ArrayRef dims) { llvm::DenseMap replacements; - SmallVector todo(map.getResults().begin(), - map.getResults().end()); - - map.walkExprs([&replacements, &dims](AffineExpr expr) { - auto val = optimizeExprWithBounds(expr, dims); - if (val.has_value()) - replacements[expr] = *val; - }); - - return map.replace(replacements); + SmallVector todo; + for (auto expr : map.getResults()) + todo.push_back(optimizeExprWithBounds(expr, dims)); + return AffineMap::get(map.getNumDims(), map.getNumSymbols(), todo, + map.getContext()); } // When all uses of an IV are of the form (%i % cst) or (%i // cst), replace @@ -2893,6 +2926,7 @@ struct SplitParallelInductions for (auto U : iv.getUsers()) { users.push_back(U); } + bool hasRemainder = false; while (!users.empty()) { auto U = users.pop_back_val(); SmallVector exprs; @@ -2963,6 +2997,9 @@ struct SplitParallelInductions continue; } else if (isa( U)) { + if (isa(U)) { + hasRemainder |= isa(U); + } Value newBase = U->getOperand(1); if (base.isValue && !base.v_val) @@ -2993,10 +3030,9 @@ struct SplitParallelInductions legal = false; break; } - auto findBasePattern = [](Value iv, AffineExpr root, ValueRange operands, ValueOrInt &base, - bool &legal) { + bool &legal, bool &hasRemainder) { SmallVector todo = {root}; while (!todo.empty()) { auto subExpr = todo.back(); @@ -3026,6 +3062,10 @@ struct SplitParallelInductions return; } + if (kind == AffineExprKind::Mod) { + hasRemainder = true; + } + if (base.isValue && base.v_val == nullptr) { base = newBase; } else if (base.isValue && newBase.isValue && @@ -3053,7 +3093,7 @@ struct SplitParallelInductions }; for (auto expr : exprs) { - findBasePattern(iv, expr, operands, base, legal); + findBasePattern(iv, expr, operands, base, legal, hasRemainder); if (!legal) break; } @@ -3065,6 +3105,8 @@ struct SplitParallelInductions if (base.isValue && !base.v_val) { legal = false; } + if (!hasRemainder) + legal = false; // We can add an extra iv if (legal) { @@ -3096,6 +3138,15 @@ struct SplitParallelInductions AffineExpr ubound0 = op.getUpperBoundsMap().getResult(idx).floorDiv(baseExpr); + + if (ubound0 * baseExpr != op.getUpperBoundsMap().getResult(idx)) { + continue; + } + + if (ubound0 == mlir::getAffineConstantExpr(0, op.getContext())) { + continue; + } + AffineExpr ubound1 = op.getUpperBoundsMap().getResult(idx).floorDiv(ubound0); @@ -3140,7 +3191,7 @@ struct SplitParallelInductions SmallVector users(iv.getUsers().begin(), iv.getUsers().end()); - auto getDimExpr = [](Value iv, ValueRange operands) { + auto getDimExpr = [](Value iv, ValueRange operands) -> AffineDimExpr { unsigned ivPos = 0; for (unsigned i = 0; i < operands.size(); ++i) { if (operands[i] == iv) { @@ -3148,7 +3199,8 @@ struct SplitParallelInductions break; } } - return mlir::getAffineDimExpr(ivPos, iv.getContext()); + return cast( + mlir::getAffineDimExpr(ivPos, iv.getContext())); }; auto getNewMap = [getDimExpr, ubound0, base](Value iv, AffineMap oldMap, @@ -3196,10 +3248,9 @@ struct SplitParallelInductions auto operands = AI.getOperands(); auto is = AI.getIntegerSet(); - AffineExpr majorExpr = getDimExpr(iv, operands), - minorExpr = mlir::getAffineDimExpr(is.getNumDims(), - iv.getContext()); - + AffineDimExpr majorExpr = getDimExpr(iv, operands); + auto minorExpr = + mlir::getAffineDimExpr(is.getNumDims(), iv.getContext()); SmallVector dimDescriptors( is.getNumDims() + 1, AffineDimDescriptor()); @@ -3211,18 +3262,15 @@ struct SplitParallelInductions SmallVector newConstraints; for (auto constraint : is.getConstraints()) { + if (!constraint.isFunctionOfDim(majorExpr.getPosition())) { + newConstraints.push_back(constraint); + continue; + } auto E = constraint.replace(majorExpr, majorExpr * baseExpr + minorExpr); + E = optimizeExprWithBounds(E, dimDescriptors); - DenseMap replacements; - E.walk([&](AffineExpr subExpr) { - auto replacement = - optimizeExprWithBounds(subExpr, dimDescriptors); - if (replacement.has_value()) - replacements[subExpr] = *replacement; - }); - - newConstraints.push_back(E.replace(replacements)); + newConstraints.push_back(E); } auto newIntegerSet = @@ -3678,8 +3726,8 @@ void AffineCFGPass::runOnOperation() { AffineIfSimplification, CombineAffineIfs, MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops, MergeNestedAffineParallelIf, MergeParallelInductions, - SplitParallelInductions, CanonicalieForBounds, AddAddCstEnd>( - getOperation()->getContext()); + CanonicalieForBounds, AddAddCstEnd>(getOperation()->getContext(), 2); + rpl.add(getOperation()->getContext(), 1); GreedyRewriteConfig config; (void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config); } diff --git a/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp b/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp index 24c9544e8..d8584d026 100644 --- a/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp +++ b/src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp @@ -336,12 +336,32 @@ AffineExpr mlir::enzyme::recreateExpr(AffineExpr expr) { return sortSum(lhs) * sortSum(rhs); case AffineExprKind::Mod: return sortSum(lhs) % sortSum(rhs); - case AffineExprKind::FloorDiv: - return sortSum(lhs).floorDiv(sortSum(rhs)); - case AffineExprKind::CeilDiv: - return sortSum(lhs).ceilDiv(sortSum(rhs)); - default: - return expr; + case AffineExprKind::FloorDiv: { + rhs = sortSum(rhs); + SmallVector toDivide; + SmallVector alreadyDivided; + if (auto cst = dyn_cast(rhs)) { + for (auto expr : getSumOperands(lhs)) { + if (expr.isMultipleOf(cst.getValue())) + alreadyDivided.push_back(expr.floorDiv(cst)); + else + toDivide.push_back(expr); + } + } else { + toDivide.push_back(sortSum(lhs)); + } + llvm::sort(toDivide, affineCmp); + AffineExpr out = getAffineConstantExpr(0, expr.getContext()); + for (auto expr : toDivide) + out = out + expr; + out = out.floorDiv(rhs); + alreadyDivided.push_back(out); + out = getAffineConstantExpr(0, expr.getContext()); + llvm::sort(alreadyDivided, affineCmp); + for (auto expr : alreadyDivided) + out = out + expr; + return out; + } } } return expr;