[SPARK-31519][SQL] Cast in having aggregate expressions returns the wrong result

### What changes were proposed in this pull request?
Add a new logical node, AggregateWithHaving, which the parser now creates for a HAVING clause on top of an Aggregate. The analyzer then resolves it to Filter(..., Aggregate(...)).
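
As a rough illustration of that flow, here is a minimal, self-contained Scala sketch. It does not use Catalyst: `Plan`, `Relation`, the string-typed conditions, the `resolved` flag, and the `HavingSketch` helper are simplified stand-ins invented for this example. Only the node names `Filter`, `Aggregate`, and `AggregateWithHaving` come from this PR.

```scala
// Toy stand-ins for Catalyst plan nodes (hypothetical, heavily simplified).
sealed trait Plan
case class Relation(name: String) extends Plan
case class Aggregate(
    grouping: Seq[String],
    aggExprs: Seq[String],
    child: Plan,
    resolved: Boolean = false) extends Plan
case class Filter(condition: String, child: Plan) extends Plan
case class AggregateWithHaving(havingCondition: String, child: Aggregate) extends Plan

object HavingSketch {
  // Parser side: HAVING over an Aggregate becomes AggregateWithHaving,
  // anything else stays a plain Filter.
  def withHaving(predicate: String, plan: Plan): Plan = plan match {
    case agg: Aggregate => AggregateWithHaving(predicate, agg)
    case other => Filter(predicate, other)
  }

  // Analyzer side: rewrite to Filter only once the child Aggregate is resolved;
  // until then the node is left untouched.
  def resolve(plan: Plan): Plan = plan match {
    case AggregateWithHaving(cond, agg) if agg.resolved => Filter(cond, agg)
    case other => other
  }
}
```

In the real analyzer the rewrite goes through `resolveHaving`, which may also push extra aggregate expressions into the Aggregate. The sketch only shows the parse-then-rewrite shape.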

### Why are the changes needed?
The SQL parser in Spark creates Filter(..., Aggregate(...)) for a HAVING query, and Spark has a special analyzer rule, ResolveAggregateFunctions, to resolve the aggregate functions and grouping columns in the Filter operator.

This works for simple cases, but only in a fragile way, because it relies on rule execution order:
1. Rule ResolveReferences hits the Aggregate operator and resolves attributes inside aggregate functions, but the function itself is still unresolved as it's an UnresolvedFunction. This stops resolving the Filter operator, as the child Aggregate operator is still unresolved.
2. Rule ResolveFunctions resolves UnresolvedFunction. This makes the Aggregate operator resolved.
3. Rule ResolveAggregateFunctions resolves the Filter operator if its child is a resolved Aggregate. This rule can correctly resolve the grouping columns.

The example query below contains a CAST, which must be resolved by rule ResolveTimeZone, and that rule runs after ResolveAggregateFunctions. This breaks step 3, because the Aggregate operator is still unresolved at that point. The analyzer then starts the next round, in which ResolveReferences resolves the Filter operator and wrongly resolves the grouping columns.

See the demo below:
```
SELECT SUM(a) AS b, '2020-01-01' AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```
The query's result is
```
+---+----------+
|  b|      fake|
+---+----------+
|  2|2020-01-01|
+---+----------+
```
But if we add a CAST, the query returns an empty result.
```
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
```
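
The difference between the two results comes down to which `b` the HAVING clause binds to. One reading consistent with the empty result is that `b` gets bound to the SELECT alias `SUM(a) AS b` instead of the grouping column `T.b`. The plain-Scala sketch below (no Spark involved; `HavingBindingDemo` and its field names are made up for illustration) evaluates both readings over the same two rows.

```scala
object HavingBindingDemo {
  // Rows of T(a, b) from the example query.
  val rows: Seq[(Int, Int)] = Seq((1, 10), (2, 20))

  // GROUP BY b with SUM(a) per group, as (grouping column b, SUM(a)) pairs.
  val groups: Seq[(Int, Int)] =
    rows.groupBy(_._2).toSeq
      .map { case (b, rs) => (b, rs.map(_._1).sum) }
      .sortBy(_._1)

  // Reading 1: `b` in HAVING is the grouping column T.b.
  val byGroupingColumn: Seq[Int] =
    groups.collect { case (b, sumA) if b > 10 => sumA }

  // Reading 2: `b` in HAVING is the SELECT alias for SUM(a).
  val byAlias: Seq[Int] =
    groups.collect { case (_, sumA) if sumA > 10 => sumA }

  def main(args: Array[String]): Unit = {
    println(s"HAVING on grouping column: $byGroupingColumn")
    println(s"HAVING on alias: $byAlias")
  }
}
```

Reading 1 keeps the single group with `SUM(a) = 2`, matching the result table above, while reading 2 keeps no groups at all.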

### Does this PR introduce any user-facing change?
Yes, bug fix for cast in having aggregate expressions.

### How was this patch tested?
New UT added.

Closes apache#28294 from xuanyuanking/SPARK-31519.

Authored-by: Yuanjian Li
Signed-off-by: Wenchen Fan
xuanyuanking authored and cloud-fan committed Apr 28, 2020
1 parent 079b362 commit 6ed2dfb
Showing 9 changed files with 135 additions and 80 deletions.
@@ -1239,13 +1239,13 @@ class Analyzer(
/**
* Resolves the attribute and extract value expressions(s) by traversing the
* input expression in top down manner. The traversal is done in top-down manner as
* we need to skip over unbound lamda function expression. The lamda expressions are
* we need to skip over unbound lambda function expression. The lambda expressions are
* resolved in a different rule [[ResolveLambdaVariables]]
*
* Example :
* SELECT transform(array(1, 2, 3), (x, i) -> x + i)"
*
* In the case above, x and i are resolved as lamda variables in [[ResolveLambdaVariables]]
* In the case above, x and i are resolved as lambda variables in [[ResolveLambdaVariables]]
*
* Note : In this routine, the unresolved attributes are resolved from the input plan's
* children attributes.
@@ -1400,6 +1400,9 @@ class Analyzer(
notMatchedActions = newNotMatchedActions)
}

// Skip the having clause here, this will be handled in ResolveAggregateFunctions.
case h: AggregateWithHaving => h

case q: LogicalPlan =>
logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}")
q.mapExpressions(resolveExpressionTopDown(_, q))
@@ -2040,62 +2043,14 @@ class Analyzer(
*/
object ResolveAggregateFunctions extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
case f @ Filter(cond, agg @ Aggregate(grouping, originalAggExprs, child)) if agg.resolved =>
// Resolve aggregate with having clause to Filter(..., Aggregate()). Note, to avoid wrongly
// resolve the having condition expression, here we skip resolving it in ResolveReferences
// and transform it to Filter after aggregate is resolved. See more details in SPARK-31519.
case AggregateWithHaving(cond, agg: Aggregate) if agg.resolved =>
resolveHaving(Filter(cond, agg), agg)

// Try resolving the condition of the filter as though it is in the aggregate clause
try {
val aggregatedCondition =
Aggregate(
grouping,
Alias(cond, "havingCondition")() :: Nil,
child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if grouping.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
ne.toAttribute
case _ =>
val alias = Alias(e, e.toString)()
aggregateExpressions += alias
alias.toAttribute
}
}

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(agg.output,
Filter(transformedAggregateFilter,
agg.copy(aggregateExpressions = originalAggExprs ++ aggregateExpressions)))
} else {
f
}
} else {
f
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => f
}
case f @ Filter(_, agg: Aggregate) if agg.resolved =>
resolveHaving(f, agg)

case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved =>

@@ -2166,6 +2121,63 @@ class Analyzer(
def containsAggregate(condition: Expression): Boolean = {
condition.find(_.isInstanceOf[AggregateExpression]).isDefined
}

def resolveHaving(filter: Filter, agg: Aggregate): LogicalPlan = {
// Try resolving the condition of the filter as though it is in the aggregate clause
try {
val aggregatedCondition =
Aggregate(
agg.groupingExpressions,
Alias(filter.condition, "havingCondition")() :: Nil,
agg.child)
val resolvedOperator = executeSameContext(aggregatedCondition)
def resolvedAggregateFilter =
resolvedOperator
.asInstanceOf[Aggregate]
.aggregateExpressions.head

// If resolution was successful and we see the filter has an aggregate in it, add it to
// the original aggregate operator.
if (resolvedOperator.resolved) {
// Try to replace all aggregate expressions in the filter by an alias.
val aggregateExpressions = ArrayBuffer.empty[NamedExpression]
val transformedAggregateFilter = resolvedAggregateFilter.transform {
case ae: AggregateExpression =>
val alias = Alias(ae, ae.toString)()
aggregateExpressions += alias
alias.toAttribute
// Grouping functions are handled in the rule [[ResolveGroupingAnalytics]].
case e: Expression if agg.groupingExpressions.exists(_.semanticEquals(e)) &&
!ResolveGroupingAnalytics.hasGroupingFunction(e) &&
!agg.output.exists(_.semanticEquals(e)) =>
e match {
case ne: NamedExpression =>
aggregateExpressions += ne
ne.toAttribute
case _ =>
val alias = Alias(e, e.toString)()
aggregateExpressions += alias
alias.toAttribute
}
}

// Push the aggregate expressions into the aggregate (if any).
if (aggregateExpressions.nonEmpty) {
Project(agg.output,
Filter(transformedAggregateFilter,
agg.copy(aggregateExpressions = agg.aggregateExpressions ++ aggregateExpressions)))
} else {
filter
}
} else {
filter
}
} catch {
// Attempting to resolve in the aggregate can result in ambiguity. When this happens,
// just return the original plan.
case ae: AnalysisException => filter
}
}
}

/**
@@ -2590,11 +2602,14 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperatorsDown {

case Filter(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside WHERE and HAVING clauses")
failAnalysis("It is not allowed to use window functions inside WHERE clause")

case AggregateWithHaving(condition, _) if hasWindowFunction(condition) =>
failAnalysis("It is not allowed to use window functions inside HAVING clause")

// Aggregate with Having clause. This rule works with an unresolved Aggregate because
// a resolved Aggregate will not have Window Functions.
case f @ Filter(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
case f @ AggregateWithHaving(condition, a @ Aggregate(groupingExprs, aggregateExprs, child))
if child.resolved &&
hasWindowFunction(aggregateExprs) &&
a.expressions.forall(_.resolved) =>
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.parser.ParserUtils
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, UnaryNode}
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.catalyst.util.quoteIdentifier
import org.apache.spark.sql.connector.catalog.{Identifier, TableCatalog}
@@ -538,3 +538,14 @@ case class UnresolvedOrdinal(ordinal: Int)
override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
override lazy val resolved = false
}

/**
* Represents unresolved aggregate with having clause, it is turned by the analyzer into a Filter.
*/
case class AggregateWithHaving(
havingCondition: Expression,
child: Aggregate)
extends UnaryNode {
override lazy val resolved: Boolean = false
override def output: Seq[Attribute] = child.output
}
@@ -364,6 +364,14 @@ package object dsl {
Aggregate(groupingExprs, aliasedExprs, logicalPlan)
}

def having(
groupingExprs: Expression*)(
aggregateExprs: Expression*)(
havingCondition: Expression): LogicalPlan = {
AggregateWithHaving(havingCondition,
groupBy(groupingExprs: _*)(aggregateExprs: _*).asInstanceOf[Aggregate])
}

def window(
windowExpressions: Seq[NamedExpression],
partitionSpec: Seq[Expression],
@@ -629,7 +629,12 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging
case p: Predicate => p
case e => Cast(e, BooleanType)
}
Filter(predicate, plan)
plan match {
case aggregate: Aggregate =>
AggregateWithHaving(predicate, aggregate)
case _ =>
Filter(predicate, plan)
}
}

/**
@@ -208,7 +208,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b))
assertEqual(
"select a, b from db.c having x < 1",
table("db", "c").groupBy()('a, 'b).where('x < 1))
table("db", "c").having()('a, 'b)('x < 1))
assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b)))
assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b))
assertEqual("select from tbl", OneRowRelation().select('from.as("tbl")))
@@ -574,8 +574,7 @@ class PlanParserSuite extends AnalysisTest {
assertEqual(
"select g from t group by g having a > (select b from s)",
table("t")
.groupBy('g)('g)
.where('a > ScalarSubquery(table("s").select('b))))
.having('g)('g)('a > ScalarSubquery(table("s").select('b))))
}

test("table reference") {
3 changes: 3 additions & 0 deletions sql/core/src/test/resources/sql-tests/inputs/having.sql
@@ -16,3 +16,6 @@ SELECT MIN(t.v) FROM (SELECT * FROM hav WHERE v > 0) t HAVING(COUNT(1) > 0);

-- SPARK-20329: make sure we handle timezones correctly
SELECT a + b FROM VALUES (1L, 2), (3L, 4) AS T(a, b) GROUP BY a + b HAVING a + b > 1;

-- SPARK-31519: Cast in having aggregate expressions returns the wrong result
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
10 changes: 9 additions & 1 deletion sql/core/src/test/resources/sql-tests/results/having.sql.out
@@ -1,5 +1,5 @@
-- Automatically generated by SQLQueryTestSuite
-- Number of queries: 5
-- Number of queries: 6


-- !query
@@ -47,3 +47,11 @@ struct<(a + CAST(b AS BIGINT)):bigint>
-- !query output
3
7


-- !query
SELECT SUM(a) AS b, CAST('2020-01-01' AS DATE) AS fake FROM VALUES (1, 10), (2, 20) AS T(a, b) GROUP BY b HAVING b > 10
-- !query schema
struct<b:bigint,fake:date>
-- !query output
2 2020-01-01
@@ -294,7 +294,7 @@ SELECT * FROM empsalary WHERE row_number() OVER (ORDER BY salary) < 10
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
It is not allowed to use window functions inside WHERE and HAVING clauses;
It is not allowed to use window functions inside WHERE clause;


-- !query
@@ -341,7 +341,7 @@ SELECT * FROM empsalary WHERE (rank() OVER (ORDER BY random())) > 10
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
It is not allowed to use window functions inside WHERE and HAVING clauses;
It is not allowed to use window functions inside WHERE clause;


-- !query
@@ -350,7 +350,7 @@ SELECT * FROM empsalary WHERE rank() OVER (ORDER BY random())
struct<>
-- !query output
org.apache.spark.sql.AnalysisException
It is not allowed to use window functions inside WHERE and HAVING clauses;
It is not allowed to use window functions inside WHERE clause;


-- !query
@@ -665,40 +665,46 @@ class DataFrameWindowFunctionsSuite extends QueryTest
}

test("SPARK-24575: Window functions inside WHERE and HAVING clauses") {
def checkAnalysisError(df: => DataFrame): Unit = {
def checkAnalysisError(df: => DataFrame, clause: String): Unit = {
val thrownException = the[AnalysisException] thrownBy {
df.queryExecution.analyzed
}
assert(thrownException.message.contains("window functions inside WHERE and HAVING clauses"))
assert(thrownException.message.contains(s"window functions inside $clause clause"))
}

checkAnalysisError(testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1))
checkAnalysisError(testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1))
checkAnalysisError(
testData2.select("a").where(rank().over(Window.orderBy($"b")) === 1), "WHERE")
checkAnalysisError(
testData2.where($"b" === 2 && rank().over(Window.orderBy($"b")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(avg($"b").as("avgb"))
.where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1))
.where($"a" > $"avgb" && rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(max($"b").as("maxb"), sum($"b").as("sumb"))
.where(rank().over(Window.orderBy($"a")) === 1))
.where(rank().over(Window.orderBy($"a")) === 1), "WHERE")
checkAnalysisError(
testData2.groupBy($"a")
.agg(max($"b").as("maxb"), sum($"b").as("sumb"))
.where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1))
.where($"sumb" === 5 && rank().over(Window.orderBy($"a")) === 1), "WHERE")

checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"))
checkAnalysisError(sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"))
checkAnalysisError(sql("SELECT a FROM testData2 WHERE RANK() OVER(ORDER BY b) = 1"), "WHERE")
checkAnalysisError(
sql("SELECT * FROM testData2 WHERE b = 2 AND RANK() OVER(ORDER BY b) = 1"), "WHERE")
checkAnalysisError(
sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"))
sql("SELECT * FROM testData2 GROUP BY a HAVING a > AVG(b) AND RANK() OVER(ORDER BY a) = 1"),
"HAVING")
checkAnalysisError(
sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"))
sql("SELECT a, MAX(b), SUM(b) FROM testData2 GROUP BY a HAVING RANK() OVER(ORDER BY a) = 1"),
"HAVING")
checkAnalysisError(
sql(
s"""SELECT a, MAX(b)
|FROM testData2
|GROUP BY a
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin))
|HAVING SUM(b) = 5 AND RANK() OVER(ORDER BY a) = 1""".stripMargin),
"HAVING")
}

test("window functions in multiple selects") {