Skip to content

Commit

Permalink
[SPARK-42815][SQL] Subexpression elimination support shortcut expression
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Add a new config that lets subexpression elimination skip the operand of `and` / `or` expressions that may be short-circuited.

A subexpression may not need to be evaluated even if it appears more than once.
For example, in `if(or(a, and(b, b)))`, the expression `b` is skipped when `a` is true.

### Why are the changes needed?

To avoid evaluating expressions unnecessarily.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

Added a unit test.

Closes apache#40446 from ulysses-you/shortcut.

Authored-by: ulysses-you <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
ulysses-you authored and cloud-fan committed Mar 22, 2023
1 parent d679dab commit 6f7403b
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@ import scala.collection.mutable

import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.util.Utils

/**
* This class is used to compute equality of (sub)expression trees. Expressions can be added
* to this class and subsequently queried for expression equality. Expression trees are
* considered equal if, for the same input(s), they produce the same result.
*/
class EquivalentExpressions {
class EquivalentExpressions(
skipForShortcutEnable: Boolean = SQLConf.get.subexpressionEliminationSkipForShotcutExpr) {

// For each expression, the set of equivalent expressions.
private val equivalenceMap = mutable.HashMap.empty[ExpressionEquals, ExpressionStats]

Expand Down Expand Up @@ -129,13 +132,27 @@ class EquivalentExpressions {
}
}

// Picks the expression that should be inspected in place of `expr`.
// When `skipForShortcutEnable` is set, an `And` / `Or` is narrowed to its left operand:
// the subexpression may not need to eval even if it appears more than once,
// e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true.
// Every other expression, and every expression when the flag is off, is returned unchanged.
private def skipForShortcut(expr: Expression): Expression = expr match {
  case conjunction: And if skipForShortcutEnable => conjunction.left
  case disjunction: Or if skipForShortcutEnable => disjunction.left
  case _ => expr
}

// Selects which children of `expr` should be recursed into.
// There are some special expressions whose children we should not all recurse into:
// 1. CodegenFallback: its children will not be used to generate code (call eval() instead)
// 2. ConditionalExpression: use its children that will always be evaluated.
// NOTE(review): this listing comes from a rendered diff that shows removed and added lines
// together. The first `ConditionalExpression` / `other` pair below looks like the pre-change
// lines and the second pair like their replacements - confirm against the actual file.
private def childrenToRecurse(expr: Expression): Seq[Expression] = expr match {
case _: CodegenFallback => Nil
case c: ConditionalExpression => c.alwaysEvaluatedInputs
case other => other.children
// Each always-evaluated input is first passed through `skipForShortcut`.
case c: ConditionalExpression => c.alwaysEvaluatedInputs.map(skipForShortcut)
// Recurse into the children of whatever `skipForShortcut` returns for this expression.
case other => skipForShortcut(other).children
}

// For some special expressions we cannot just recurse into all of its children, but we can
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -864,6 +864,16 @@ object SQLConf {
.checkValue(_ >= 0, "The maximum must not be negative")
.createWithDefault(100)

// Internal boolean flag, off by default, registered under
// `spark.sql.subexpressionElimination.skipForShortcutExpr`.
val SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR =
  buildConf("spark.sql.subexpressionElimination.skipForShortcutExpr")
    .internal()
    .doc(
      // The sentences are joined with single spaces, yielding one doc string.
      Seq(
        "When true, shortcut eliminate subexpression with `AND`, `OR`.",
        "The subexpression may not need to eval even if it appears more than once.",
        "e.g., `if(or(a, and(b, b)))`, the expression `b` would be skipped if `a` is true."
      ).mkString(" "))
    .version("3.5.0")
    .booleanConf
    .createWithDefault(false)

val CASE_SENSITIVE = buildConf("spark.sql.caseSensitive")
.internal()
.doc("Whether the query analyzer should be case sensitive or not. " +
Expand Down Expand Up @@ -4610,6 +4620,9 @@ class SQLConf extends Serializable with Logging {
// Returns the current value of SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES as an Int.
def subexpressionEliminationCacheMaxEntries: Int =
getConf(SUBEXPRESSION_ELIMINATION_CACHE_MAX_ENTRIES)

// Returns the current value of SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR.
def subexpressionEliminationSkipForShortcutExpr: Boolean =
  getConf(SUBEXPRESSION_ELIMINATION_SKIP_FOR_SHORTCUT_EXPR)

// Misspelled ("Shotcut") original name of the accessor above. It is kept as a plain alias so
// that existing callers keep compiling; new code should use the correctly spelled accessor.
def subexpressionEliminationSkipForShotcutExpr: Boolean =
  subexpressionEliminationSkipForShortcutExpr

// Returns the current value of AUTO_BROADCASTJOIN_THRESHOLD as a Long.
def autoBroadcastJoinThreshold: Long = getConf(AUTO_BROADCASTJOIN_THRESHOLD)

// Returns the current value of LIMIT_INITIAL_NUM_PARTITIONS as an Int.
def limitInitialNumPartitions: Int = getConf(LIMIT_INITIAL_NUM_PARTITIONS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
*/
package org.apache.spark.sql.catalyst.expressions

import java.util.Properties

import org.apache.spark.{SparkFunSuite, TaskContext, TaskContextImpl}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen._
Expand Down Expand Up @@ -424,7 +426,7 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
test("SPARK-38333: PlanExpression expression should skip addExprTree function in Executor") {
try {
// suppose we are in executor
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, null, null, cpus = 0)
val context1 = new TaskContextImpl(0, 0, 0, 0, 0, 1, null, new Properties, null, cpus = 0)
TaskContext.setTaskContext(context1)

val equivalence = new EquivalentExpressions
Expand Down Expand Up @@ -465,6 +467,33 @@ class SubexpressionEliminationSuite extends SparkFunSuite with ExpressionEvalHel
val cseState = equivalence.getExprState(expr)
assert(hasMatching == cseState.isDefined)
}

test("SPARK-42815: Subexpression elimination support shortcut conditional expression") {
  // `sum` is used twice inside `comparison`.
  val sum = Add(Literal(1), Literal(0))
  val comparison = EqualTo(sum, sum)

  // With shortcut skipping enabled, checks the number of common subexpressions found both
  // when `expr` is the predicate of an `If` and when `expr` is the root expression itself.
  def checkShortcut(expr: Expression, numCommonExpr: Int): Unit = {
    val roots: Seq[Expression] = Seq(If(expr, Literal(1), Literal(2)), expr)
    roots.foreach { root =>
      val equivalence = new EquivalentExpressions(true)
      equivalence.addExprTree(root)
      assert(equivalence.getCommonSubexpressions.size == numCommonExpr)
    }
  }

  // shortcut right child
  val repeatedOnRight: Seq[Expression] = Seq(
    And(Literal(false), comparison),
    Or(Literal(true), comparison),
    Not(And(Literal(true), comparison)))
  repeatedOnRight.foreach(checkShortcut(_, 0))

  // always eliminate subexpression for left child
  val repeatedOnLeft: Seq[Expression] = Seq(
    And(comparison, Literal(false)),
    Or(comparison, Literal(true)),
    Not(And(comparison, Literal(false))))
  repeatedOnLeft.foreach(checkShortcut(_, 1))
}
}

case class CodegenFallbackExpression(child: Expression)
Expand Down

0 comments on commit 6f7403b

Please sign in to comment.