Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't infinitely recur #427

Merged
merged 2 commits into from
Mar 4, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
fix index div
  • Loading branch information
wsmoses committed Mar 4, 2025
commit 4526e638d8441c32e05aadd67297a854973b1a0f
112 changes: 73 additions & 39 deletions src/enzyme_ad/jax/Passes/AffineCFG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2766,6 +2766,38 @@ optimizeExprFloorDiv(llvm::ArrayRef<AffineDimDescriptor> dims, AffineExpr lhs,
return mlir::getAffineConstantExpr(0, lhs.getContext());
}

if (auto add = dyn_cast<AffineBinaryOpExpr>(lhs)) {
if (add.getKind() == AffineExprKind::Add) {
for (int i = 0; i < 2; i++) {
auto lhs = i == 0 ? add.getLHS() : add.getRHS();
auto rhs = i == 0 ? add.getRHS() : add.getLHS();
auto lhse = dyn_cast<AffineDimExpr>(lhs);
if (!lhse)
continue;
auto rhse = dyn_cast<AffineBinaryOpExpr>(rhs);
if (!rhse)
continue;
if (rhse.getKind() != AffineExprKind::Mul)
continue;
auto mulconst = dyn_cast<AffineConstantExpr>(rhse.getRHS());
if (!mulconst)
continue;
auto dim = dims[lhse.getPosition()];
if (!dim.known)
continue;

if (dim.step < 0)
continue;
if (dim.lb != 0)
continue;
if (dim.ub != mulconst.getValue())
continue;
if (constRhs.getValue() % mulconst.getValue() == 0)
return rhse.getLHS().floorDiv(constRhs.floorDiv(mulconst));
}
}
}

return std::nullopt;
}

Expand Down Expand Up @@ -2802,43 +2834,44 @@ optimizeExprMod(llvm::ArrayRef<AffineDimDescriptor> dims, AffineExpr lhs,
return std::nullopt;
}

static std::optional<AffineExpr>
optimizeExprWithBounds(AffineExpr expr,
llvm::ArrayRef<AffineDimDescriptor> dims) {
AffineExpr optimizeExprWithBounds(AffineExpr expr,
llvm::ArrayRef<AffineDimDescriptor> dims) {
std::optional<AffineExpr> replacement;
auto binExpr = dyn_cast<AffineBinaryOpExpr>(expr);
if (!binExpr)
return std::nullopt;
return expr;

AffineExpr lhs = binExpr.getLHS(), rhs = binExpr.getRHS();
AffineExpr lhs = optimizeExprWithBounds(binExpr.getLHS(), dims);
AffineExpr rhs = optimizeExprWithBounds(binExpr.getRHS(), dims);

switch (expr.getKind()) {
case AffineExprKind::Add:
return lhs + rhs;
case AffineExprKind::Mul:
return lhs * rhs;
case AffineExprKind::Mod:
replacement = optimizeExprMod(dims, lhs, rhs);
break;
if (auto replacement = optimizeExprMod(dims, lhs, rhs))
return *replacement;
else
return lhs % rhs;
case AffineExprKind::FloorDiv:
replacement = optimizeExprFloorDiv(dims, lhs, rhs);
break;
default:
break;
if (auto replacement = optimizeExprFloorDiv(dims, lhs, rhs))
return *replacement;
else
return lhs.floorDiv(rhs);
}

return replacement;
return expr;
}

/// Simplify every result expression of `map` with optimizeExprWithBounds and
/// rebuild the map from the simplified results.
///
/// \param map  The affine map whose results should be simplified.
/// \param dims Known bound descriptors for the map's dimensions.
/// \return A map with identical dim/symbol counts whose results have been
///         individually simplified.
static AffineMap optimizeMap(AffineMap map,
                             llvm::ArrayRef<AffineDimDescriptor> dims) {
  SmallVector<AffineExpr> results;
  results.reserve(map.getNumResults());
  // Rewrite each result directly; optimizeExprWithBounds already recurses
  // into sub-expressions, so no additional walk/replace pass is needed.
  for (AffineExpr expr : map.getResults())
    results.push_back(optimizeExprWithBounds(expr, dims));
  return AffineMap::get(map.getNumDims(), map.getNumSymbols(), results,
                        map.getContext());
}

// When all uses of an IV are of the form (%i % cst) or (%i // cst), replace
Expand Down Expand Up @@ -3106,6 +3139,10 @@ struct SplitParallelInductions
AffineExpr ubound0 =
op.getUpperBoundsMap().getResult(idx).floorDiv(baseExpr);

if (ubound0 * baseExpr != op.getUpperBoundsMap().getResult(idx)) {
continue;
}

if (ubound0 == mlir::getAffineConstantExpr(0, op.getContext())) {
continue;
}
Expand Down Expand Up @@ -3154,15 +3191,16 @@ struct SplitParallelInductions
SmallVector<Operation *> users(iv.getUsers().begin(),
iv.getUsers().end());

auto getDimExpr = [](Value iv, ValueRange operands) {
auto getDimExpr = [](Value iv, ValueRange operands) -> AffineDimExpr {
unsigned ivPos = 0;
for (unsigned i = 0; i < operands.size(); ++i) {
if (operands[i] == iv) {
ivPos = i;
break;
}
}
return mlir::getAffineDimExpr(ivPos, iv.getContext());
return cast<AffineDimExpr>(
mlir::getAffineDimExpr(ivPos, iv.getContext()));
};

auto getNewMap = [getDimExpr, ubound0, base](Value iv, AffineMap oldMap,
Expand Down Expand Up @@ -3210,10 +3248,9 @@ struct SplitParallelInductions
auto operands = AI.getOperands();
auto is = AI.getIntegerSet();

AffineExpr majorExpr = getDimExpr(iv, operands),
minorExpr = mlir::getAffineDimExpr(is.getNumDims(),
iv.getContext());

AffineDimExpr majorExpr = getDimExpr(iv, operands);
auto minorExpr =
mlir::getAffineDimExpr(is.getNumDims(), iv.getContext());
SmallVector<AffineDimDescriptor> dimDescriptors(
is.getNumDims() + 1, AffineDimDescriptor());

Expand All @@ -3225,18 +3262,15 @@ struct SplitParallelInductions

SmallVector<AffineExpr> newConstraints;
for (auto constraint : is.getConstraints()) {
if (!constraint.isFunctionOfDim(majorExpr.getPosition())) {
newConstraints.push_back(constraint);
continue;
}
auto E = constraint.replace(majorExpr,
majorExpr * baseExpr + minorExpr);
E = optimizeExprWithBounds(E, dimDescriptors);

DenseMap<AffineExpr, AffineExpr> replacements;
E.walk([&](AffineExpr subExpr) {
auto replacement =
optimizeExprWithBounds(subExpr, dimDescriptors);
if (replacement.has_value())
replacements[subExpr] = *replacement;
});

newConstraints.push_back(E.replace(replacements));
newConstraints.push_back(E);
}

auto newIntegerSet =
Expand Down Expand Up @@ -3692,8 +3726,8 @@ void AffineCFGPass::runOnOperation() {
AffineIfSimplification, CombineAffineIfs,
MergeNestedAffineParallelLoops, PrepMergeNestedAffineParallelLoops,
MergeNestedAffineParallelIf, MergeParallelInductions,
SplitParallelInductions, CanonicalieForBounds, AddAddCstEnd>(
getOperation()->getContext());
CanonicalieForBounds, AddAddCstEnd>(getOperation()->getContext(), 2);
rpl.add<SplitParallelInductions>(getOperation()->getContext(), 1);
GreedyRewriteConfig config;
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(rpl), config);
}
Expand Down
32 changes: 26 additions & 6 deletions src/enzyme_ad/jax/Passes/SimplifyAffineExprs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,12 +336,32 @@ AffineExpr mlir::enzyme::recreateExpr(AffineExpr expr) {
return sortSum(lhs) * sortSum(rhs);
case AffineExprKind::Mod:
return sortSum(lhs) % sortSum(rhs);
case AffineExprKind::FloorDiv:
return sortSum(lhs).floorDiv(sortSum(rhs));
case AffineExprKind::CeilDiv:
return sortSum(lhs).ceilDiv(sortSum(rhs));
default:
return expr;
case AffineExprKind::FloorDiv: {
rhs = sortSum(rhs);
SmallVector<AffineExpr> toDivide;
SmallVector<AffineExpr> alreadyDivided;
if (auto cst = dyn_cast<AffineConstantExpr>(rhs)) {
for (auto expr : getSumOperands(lhs)) {
if (expr.isMultipleOf(cst.getValue()))
alreadyDivided.push_back(expr.floorDiv(cst));
else
toDivide.push_back(expr);
}
} else {
toDivide.push_back(sortSum(lhs));
}
llvm::sort(toDivide, affineCmp);
AffineExpr out = getAffineConstantExpr(0, expr.getContext());
for (auto expr : toDivide)
out = out + expr;
out = out.floorDiv(rhs);
alreadyDivided.push_back(out);
out = getAffineConstantExpr(0, expr.getContext());
llvm::sort(alreadyDivided, affineCmp);
for (auto expr : alreadyDivided)
out = out + expr;
return out;
}
}
}
return expr;
Expand Down
Loading