Skip to content

Commit

Permalink
Disable AGG(IF()) rewrite for AGG applied on NULLs
Browse files Browse the repository at this point in the history
This rewrite will filter out the null values. It could change the
behavior if the aggregation is also applied on NULLs.

Also change SetAggregationFunction::isCalledOnNullInput and
SetUnionFunction::isCalledOnNullInput to return true,
as they are actually applied on NULL inputs.

Add more tests for AGG(IF()) rewrites.
  • Loading branch information
yuanzhanhku authored and highker committed Aug 6, 2021
1 parent ce9298e commit 72644f6
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ public InternalAggregationFunction specialize(BoundVariables boundVariables, int
return generateAggregation(type, outputType);
}

@Override
public boolean isCalledOnNullInput()
{
return true;
}

private static InternalAggregationFunction generateAggregation(Type type, ArrayType outputType)
{
DynamicClassLoader classLoader = new DynamicClassLoader(SetAggregationFunction.class.getClassLoader());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ else if (state.get() == null) {
}
}

@Override
public boolean isCalledOnNullInput()
{
return true;
}

public static void output(SetAggregationState state, BlockBuilder out)
{
SetOfValues set = state.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RewriteAggregationIfToFilter())),
ImmutableSet.of(new RewriteAggregationIfToFilter(metadata.getFunctionAndTypeManager()))),
predicatePushDown,
new IterativeOptimizer(
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
Expand Down Expand Up @@ -47,6 +48,7 @@
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
Expand All @@ -68,6 +70,13 @@ public class RewriteAggregationIfToFilter
private static final Pattern<AggregationNode> PATTERN = aggregation()
.with(source().matching(project().capturedAs(CHILD)));

private final FunctionAndTypeManager functionAndTypeManager;

public RewriteAggregationIfToFilter(FunctionAndTypeManager functionAndTypeManager)
{
this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionManager is null");
}

@Override
public boolean isEnabled(Session session)
{
Expand Down Expand Up @@ -178,6 +187,10 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context

private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject)
{
if (functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()).isCalledOnNullInput()) {
// This rewrite will filter out the null values. It could change the behavior if the aggregation is also applied on NULLs.
return false;
}
if (!(aggregation.getArguments().size() == 1 && aggregation.getArguments().get(0) instanceof VariableReferenceExpression)) {
// Currently we only handle aggregation with a single VariableReferenceExpression. The detailed expressions are in a project node below this aggregation.
return false;
Expand All @@ -191,7 +204,7 @@ private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode so
return false;
}
SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;
// Only rewrite the aggregation if the else branch is not present.
// Only rewrite the aggregation if the else branch is not present or the else result is NULL.
return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public class TestRewriteAggregationIfToFilter
public void testDoesNotFireForNonIf()
{
// The aggregation expression is not an if expression.
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a", BooleanType.BOOLEAN);
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand All @@ -60,7 +60,7 @@ public void testDoesNotFireForNonIf()
public void testDoesNotFireForIfWithElse()
{
// The if expression has an else branch. We cannot rewrite it.
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand All @@ -75,7 +75,7 @@ public void testDoesNotFireForIfWithElse()
@Test
public void testFireOneAggregation()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand Down Expand Up @@ -104,7 +104,7 @@ public void testFireOneAggregation()
@Test
public void testFireTwoAggregations()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression b = p.variable("b");
Expand Down Expand Up @@ -145,7 +145,7 @@ public void testFireTwoAggregations()
@Test
public void testFireTwoAggregationsWithSharedInput()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand Down Expand Up @@ -181,7 +181,7 @@ public void testFireTwoAggregationsWithSharedInput()
@Test
public void testFireForOneOfTwoAggregations()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression b = p.variable("b");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,61 @@ public void testGroupAll()
"VALUES (BIGINT '3', BIGINT '6')");
}

@Test
public void testSetAggWithNulls()
{
assertions.assertQuery(
"SELECT y, set_agg(y) FILTER (WHERE x = 1) FROM (SELECT 1 x, 2 y UNION ALL SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
"VALUES (INTEGER '2', ARRAY[INTEGER '2']), (INTEGER '20', CAST(NULL AS ARRAY<INTEGER>)), (CAST(NULL AS INTEGER), ARRAY[CAST(NULL AS INTEGER)])");
assertions.assertQuery(
"SELECT y, set_agg(IF(x = 1,y)) FROM (SELECT 1 x, 2 y UNION ALL SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
"VALUES (INTEGER '2', ARRAY[INTEGER '2']), (INTEGER '20', ARRAY[CAST(NULL AS INTEGER)]), (CAST(NULL AS INTEGER), ARRAY[CAST(NULL AS INTEGER)])");
}

@Test
public void testApproxSet()
{
assertions.assertQuery(
"SELECT y, approx_set(y) FILTER (WHERE x = 1) FROM (SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
"VALUES (INTEGER '20', CAST(NULL AS HyperLogLog)), (CAST(NULL AS INTEGER), CAST(NULL AS HyperLogLog))");
assertions.assertQuery(
"SELECT y, approx_set(IF(x = 1,y)) FROM (SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
"VALUES (INTEGER '20', CAST(NULL AS HyperLogLog)), (CAST(NULL AS INTEGER), CAST(NULL AS HyperLogLog))");
}

@Test
public void testSetUnion()
{
assertions.assertQuery(
"SELECT set_union(x) FILTER (WHERE y > 1) FROM (SELECT ARRAY[1] x, 1 y UNION ALL SELECT NULL x, 1 y)",
"VALUES (CAST (NULL AS ARRAY<INTEGER>))");
assertions.assertQuery(
"SELECT set_union(IF(y > 1, x)) FROM (SELECT ARRAY[1] x, 1 y UNION ALL SELECT NULL x, 1 y)",
"VALUES (CAST (ARRAY[] AS ARRAY<INTEGER>))");
}

@Test
public void testMapUnion()
{
assertions.assertQuery(
"SELECT map_union(x) FILTER (WHERE y > 1) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
"VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
assertions.assertQuery(
"SELECT map_union(IF(y > 1, x)) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
"VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
}

@Test
public void testMapUnionSum()
{
assertions.assertQuery(
"SELECT map_union_sum(x) FILTER (WHERE y > 1) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
"VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
assertions.assertQuery(
"SELECT map_union_sum(IF(y > 1, x)) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)",
"VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))");
}

@Test
public void testGroupingSets()
{
Expand Down

0 comments on commit 72644f6

Please sign in to comment.