[SPARK-49565][SQL] Improve auto-generated expression aliases with pipe SQL operators

### What changes were proposed in this pull request?

This PR improves auto-generated expression aliases with pipe SQL operators.

For example, consider the pipe SQL syntax query:

```
table t
|> extend 1
```

Previously, the analyzed plan was:

```
Project [x#x, y#x, pipeexpression(1, false, EXTEND) AS pipeexpression(1)#x]
+- SubqueryAlias spark_catalog.default.t
   +- Relation spark_catalog.default.t[x#x,y#x] csv
```

After this PR, it is:

```
Project [x#x, y#x, pipeexpression(1, false, EXTEND) AS 1#x]
+- SubqueryAlias spark_catalog.default.t
   +- Relation spark_catalog.default.t[x#x,y#x] csv
```

Note that the output aliases visible in the resulting DataFrame for the query derive from the `AS <alias>` part of the analyzed plans shown.
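To see the effect end to end, here is a minimal sketch (assuming a local Spark build where pipe SQL syntax is available, and a table `t(x, y)` as in the plans above; the object name is just for illustration):

```scala
import org.apache.spark.sql.SparkSession

object PipeAliasDemo extends App {
  val spark = SparkSession.builder()
    .appName("pipe-alias-demo")
    .master("local[*]")
    .getOrCreate()

  spark.sql("CREATE TABLE IF NOT EXISTS t(x INT, y INT)")

  // Before this change the extra column was named `pipeexpression(1)`;
  // with it, the auto-generated alias is simply `1`.
  spark.sql("table t |> extend 1").printSchema()
}
```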

### Why are the changes needed?

Auto-generated aliases now read as the expression the user wrote (for example, `1` instead of `pipeexpression(1)`), which improves the user experience with pipe SQL syntax.

### Does this PR introduce _any_ user-facing change?

Yes, see above.

### How was this patch tested?

Existing golden file tests were updated to show the improved aliases.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#49245 from dtenedor/fix-pipe-output-aliases.

Authored-by: Daniel Tenedorio <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
dtenedor authored and cloud-fan committed Jan 10, 2025
1 parent e638f6d commit a8bec11
Showing 6 changed files with 161 additions and 138 deletions.
In class `Analyzer`, the new rule is registered in the resolution rule list just before `ResolveAliases`, so auto-generated aliases derive from the stripped child expression rather than from the `PipeExpression` wrapper:

```
@@ -373,6 +373,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       ResolveProcedures ::
       BindProcedures ::
       ResolveTableSpec ::
+      ValidateAndStripPipeExpressions ::
       ResolveAliases ::
       ResolveSubquery ::
       ResolveSubqueryColumnAliases ::
```
In the file defining `PipeExpression` and the pipe operator rules, `PipeExpression` changes from a `RuntimeReplaceable` expression into an explicitly unresolved `Unevaluable` placeholder, and the validation logic moves out of the expression itself:

```
@@ -20,8 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode}
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{PIPE_OPERATOR, TreePattern}
+import org.apache.spark.sql.catalyst.trees.TreePattern.{PIPE_EXPRESSION, PIPE_OPERATOR, TreePattern}
 import org.apache.spark.sql.errors.QueryCompilationErrors
+import org.apache.spark.sql.types.DataType
 
 /**
  * Represents an expression when used with a SQL pipe operator.
@@ -33,31 +34,12 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
  * @param clause The clause of the pipe operator. This is used to generate error messages.
  */
 case class PipeExpression(child: Expression, isAggregate: Boolean, clause: String)
-  extends UnaryExpression with RuntimeReplaceable {
+  extends UnaryExpression with Unevaluable {
+  final override val nodePatterns = Seq(PIPE_EXPRESSION)
+  final override lazy val resolved = false
   override def withNewChildInternal(newChild: Expression): Expression =
     PipeExpression(newChild, isAggregate, clause)
-  override lazy val replacement: Expression = {
-    val firstAggregateFunction: Option[AggregateFunction] = findFirstAggregate(child)
-    if (isAggregate && firstAggregateFunction.isEmpty) {
-      throw QueryCompilationErrors.pipeOperatorAggregateExpressionContainsNoAggregateFunction(child)
-    } else if (!isAggregate) {
-      firstAggregateFunction.foreach { a =>
-        throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, clause)
-      }
-    }
-    child
-  }
-
-  /** Returns the first aggregate function in the given expression, or None if not found. */
-  private def findFirstAggregate(e: Expression): Option[AggregateFunction] = e match {
-    case a: AggregateFunction =>
-      Some(a)
-    case _: WindowExpression =>
-      // Window functions are allowed in these pipe operators, so do not traverse into children.
-      None
-    case _ =>
-      e.children.flatMap(findFirstAggregate).headOption
-  }
+  override def dataType: DataType = child.dataType
 }
```
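Why the trait change matters here: as a `RuntimeReplaceable`, the wrapper stayed visible when aliases were generated, which is why the old alias read `pipeexpression(1)`; as an unresolved `Unevaluable` placeholder, it holds off alias generation until the new rule strips it, so the alias derives from the bare child (`1`). A minimal self-contained sketch of this placeholder pattern, with hypothetical names (`MyMarker` and `StripMyMarker` are illustrations, not part of the PR):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression, Unevaluable}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.types.DataType

// A marker that can never be evaluated and must be stripped before execution.
case class MyMarker(child: Expression) extends UnaryExpression with Unevaluable {
  // Stay unresolved so downstream rules (e.g. alias generation) wait for the
  // stripping rule rather than consuming the marker itself.
  override lazy val resolved: Boolean = false
  override def dataType: DataType = child.dataType
  override def withNewChildInternal(newChild: Expression): Expression =
    copy(child = newChild)
}

// Analyzer-style rule that replaces each marker with its resolved child.
object StripMyMarker extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressions {
    case MyMarker(child) if child.resolved => child
  }
}
```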
The validation and stripping logic reappears as a dedicated analyzer rule:

```
@@ -79,6 +61,43 @@ object EliminatePipeOperators extends Rule[LogicalPlan] {
   }
 }
 
+/**
+ * Validates and strips PipeExpression nodes from a logical plan once the child expressions are
+ * resolved.
+ */
+object ValidateAndStripPipeExpressions extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUpWithPruning(
+    _.containsPattern(PIPE_EXPRESSION), ruleId) {
+    case node: LogicalPlan =>
+      node.resolveExpressions {
+        case p: PipeExpression if p.child.resolved =>
+          // Once the child expression is resolved, we can perform the necessary invariant checks
+          // and then remove this expression, replacing it with the child expression instead.
+          val firstAggregateFunction: Option[AggregateFunction] = findFirstAggregate(p.child)
+          if (p.isAggregate && firstAggregateFunction.isEmpty) {
+            throw QueryCompilationErrors
+              .pipeOperatorAggregateExpressionContainsNoAggregateFunction(p.child)
+          } else if (!p.isAggregate) {
+            firstAggregateFunction.foreach { a =>
+              throw QueryCompilationErrors.pipeOperatorContainsAggregateFunction(a, p.clause)
+            }
+          }
+          p.child
+      }
+  }
+
+  /** Returns the first aggregate function in the given expression, or None if not found. */
+  private def findFirstAggregate(e: Expression): Option[AggregateFunction] = e match {
+    case a: AggregateFunction =>
+      Some(a)
+    case _: WindowExpression =>
+      // Window functions are allowed in these pipe operators, so do not traverse into children.
+      None
+    case _ =>
+      e.children.flatMap(findFirstAggregate).headOption
+  }
+}
+
 object PipeOperators {
   // These are definitions of query result clauses that can be used with the pipe operator.
   val aggregateClause = "AGGREGATE"
```
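The rule enforces the same two invariants as before, just later in analysis, once the child expression is fully resolved. A hedged sketch of queries that would trip each check (hypothetical examples assuming the table `t(x, y)` from the description above; these are not test cases from this PR):

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object PipeInvariantDemo extends App {
  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  spark.sql("CREATE TABLE IF NOT EXISTS t(x INT, y INT)")

  def expectAnalysisError(query: String): Unit =
    try {
      spark.sql(query).collect()
      println(s"unexpectedly succeeded: $query")
    } catch {
      case e: AnalysisException => println(s"rejected as expected: ${e.getMessage}")
    }

  // 1. An AGGREGATE pipe clause must contain at least one aggregate function.
  expectAnalysisError("table t |> aggregate x + 1")

  // 2. Non-aggregate clauses such as EXTEND must not contain aggregate
  //    functions (window functions remain allowed, hence the WindowExpression
  //    case in findFirstAggregate above).
  expectAnalysisError("table t |> extend sum(x)")
}
```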
In `RuleIdCollection`, the new rule is registered so it can run with a rule id for pruning:

```
@@ -109,6 +109,7 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.UpdateAttributeNullability" ::
       "org.apache.spark.sql.catalyst.analysis.ResolveUpdateEventTimeWatermarkColumn" ::
       "org.apache.spark.sql.catalyst.expressions.EliminatePipeOperators" ::
+      "org.apache.spark.sql.catalyst.expressions.ValidateAndStripPipeExpressions" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
       "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
```
In `TreePattern`, the new enum value for the pattern is added:

```
@@ -79,6 +79,7 @@ object TreePattern extends Enumeration {
   val OUTER_REFERENCE: Value = Value
   val PARAMETER: Value = Value
   val PARAMETERIZED_QUERY: Value = Value
+  val PIPE_EXPRESSION: Value = Value
   val PIPE_OPERATOR: Value = Value
   val PIVOT: Value = Value
   val PLAN_EXPRESSION: Value = Value
```
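The new `PIPE_EXPRESSION` tree pattern is what lets `ValidateAndStripPipeExpressions` call `resolveOperatorsUpWithPruning(_.containsPattern(PIPE_EXPRESSION), ruleId)` and skip plans that contain no pipe expressions at all. A self-contained toy model of pattern-based pruning (simplified; not Spark's actual `TreeNode` machinery):

```scala
// Toy model: each node caches the set of patterns occurring in its subtree,
// so a rewrite can skip entire subtrees that cannot contain the target.
object PruningDemo extends App {
  sealed trait Pattern
  case object PIPE_EXPRESSION extends Pattern

  case class Node(name: String, children: Seq[Node], selfPatterns: Set[Pattern]) {
    // Union of this node's own patterns and all descendants' patterns,
    // computed once and reused across rule invocations.
    lazy val subtreePatterns: Set[Pattern] =
      selfPatterns ++ children.flatMap(_.subtreePatterns)
    def containsPattern(p: Pattern): Boolean = subtreePatterns.contains(p)
  }

  // Apply `rewrite` bottom-up, but only descend where the pattern can occur.
  def rewriteUpWithPruning(n: Node, p: Pattern)(rewrite: Node => Node): Node =
    if (!n.containsPattern(p)) n
    else rewrite(n.copy(children = n.children.map(rewriteUpWithPruning(_, p)(rewrite))))

  val plan = Node("Project", Seq(Node("Relation", Nil, Set.empty)), Set(PIPE_EXPRESSION))
  println(rewriteUpWithPruning(plan, PIPE_EXPRESSION)(n => n.copy(name = n.name + "'")))
}
```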