Skip to content

Commit

Permalink
Revert predicates simplification to SimplifyExpressions
Browse files Browse the repository at this point in the history
This reverts commit b2ef00c.

We found a correctness issue introduced in this pull request. The query below
produces two rows when it should only produce one.

      SELECT
        *
      FROM (
        values
          ('1234', 'a', '2016-01-01'),
          ('1235', 'b', '2016-01-01'),
      ) t(app, s, ds)
      WHERE
        ds > '2015-01-01' AND
        (app like '%234' OR app like '%235' OR
        app like '234%' OR app like '235%') AND
        (app like '%235' OR app like '235%')
  • Loading branch information
haozhun committed Feb 13, 2016
1 parent 444d254 commit 431a91e
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 376 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,13 @@

import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.EQUAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.GREATER_THAN;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.GREATER_THAN_OR_EQUAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.IS_DISTINCT_FROM;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.LESS_THAN;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.LESS_THAN_OR_EQUAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.NOT_EQUAL;
import static com.facebook.presto.util.ImmutableCollectors.toImmutableList;
import static com.google.common.base.Predicates.not;
import static com.google.common.collect.Iterables.filter;
Expand All @@ -49,26 +56,24 @@ private ExpressionUtils() {}

public static List<Expression> extractConjuncts(Expression expression)
{
return extractPredicates(LogicalBinaryExpression.Type.AND, expression);
}

public static List<Expression> extractDisjuncts(Expression expression)
{
return extractPredicates(LogicalBinaryExpression.Type.OR, expression);
}
if (expression instanceof LogicalBinaryExpression && ((LogicalBinaryExpression) expression).getType() == LogicalBinaryExpression.Type.AND) {
LogicalBinaryExpression and = (LogicalBinaryExpression) expression;
return ImmutableList.<Expression>builder()
.addAll(extractConjuncts(and.getLeft()))
.addAll(extractConjuncts(and.getRight()))
.build();
}

public static List<Expression> extractPredicates(LogicalBinaryExpression expression)
{
return extractPredicates(expression.getType(), expression);
return ImmutableList.of(expression);
}

public static List<Expression> extractPredicates(LogicalBinaryExpression.Type type, Expression expression)
public static List<Expression> extractDisjuncts(Expression expression)
{
if (expression instanceof LogicalBinaryExpression && ((LogicalBinaryExpression) expression).getType() == type) {
LogicalBinaryExpression logicalBinaryExpression = (LogicalBinaryExpression) expression;
if (expression instanceof LogicalBinaryExpression && ((LogicalBinaryExpression) expression).getType() == LogicalBinaryExpression.Type.OR) {
LogicalBinaryExpression or = (LogicalBinaryExpression) expression;
return ImmutableList.<Expression>builder()
.addAll(extractPredicates(type, logicalBinaryExpression.getLeft()))
.addAll(extractPredicates(type, logicalBinaryExpression.getRight()))
.addAll(extractDisjuncts(or.getLeft()))
.addAll(extractDisjuncts(or.getRight()))
.build();
}

Expand Down Expand Up @@ -109,20 +114,6 @@ public static Expression binaryExpression(LogicalBinaryExpression.Type type, Ite
return queue.remove();
}

public static Expression combinePredicates(LogicalBinaryExpression.Type type, Expression... expressions)
{
return combinePredicates(type, Arrays.asList(expressions));
}

public static Expression combinePredicates(LogicalBinaryExpression.Type type, Iterable<Expression> expressions)
{
if (type == LogicalBinaryExpression.Type.AND) {
return combineConjuncts(expressions);
}

return combineDisjuncts(expressions);
}

public static Expression combineConjuncts(Expression... expressions)
{
return combineConjuncts(Arrays.asList(expressions));
Expand Down Expand Up @@ -182,6 +173,28 @@ public static Expression stripDeterministicConjuncts(Expression expression)
.collect(toImmutableList()));
}

public static ComparisonExpression.Type flipComparison(ComparisonExpression.Type type)
{
switch (type) {
case EQUAL:
return EQUAL;
case NOT_EQUAL:
return NOT_EQUAL;
case LESS_THAN:
return GREATER_THAN;
case LESS_THAN_OR_EQUAL:
return GREATER_THAN_OR_EQUAL;
case GREATER_THAN:
return LESS_THAN;
case GREATER_THAN_OR_EQUAL:
return LESS_THAN_OR_EQUAL;
case IS_DISTINCT_FROM:
return IS_DISTINCT_FROM;
default:
throw new IllegalArgumentException("Unsupported comparison: " + type);
}
}

public static Function<Expression, Expression> expressionOrNullSymbols(final Predicate<Symbol>... nullSymbolScopes)
{
return expression -> {
Expand Down Expand Up @@ -217,13 +230,33 @@ private static Iterable<Expression> removeDuplicates(Iterable<Expression> expres
return Iterables.concat(nonDeterministicDisjuncts, deterministicDisjuncts);
}

private static ComparisonExpression.Type negate(ComparisonExpression.Type type)
{
switch (type) {
case EQUAL:
return NOT_EQUAL;
case NOT_EQUAL:
return EQUAL;
case LESS_THAN:
return GREATER_THAN_OR_EQUAL;
case LESS_THAN_OR_EQUAL:
return GREATER_THAN;
case GREATER_THAN:
return LESS_THAN_OR_EQUAL;
case GREATER_THAN_OR_EQUAL:
return LESS_THAN;
default:
throw new IllegalArgumentException("Unsupported comparison: " + type);
}
}

public static Expression normalize(Expression expression)
{
if (expression instanceof NotExpression) {
NotExpression not = (NotExpression) expression;
if (not.getValue() instanceof ComparisonExpression) {
ComparisonExpression comparison = (ComparisonExpression) not.getValue();
return new ComparisonExpression(comparison.getType().negate(), comparison.getLeft(), comparison.getRight());
return new ComparisonExpression(negate(comparison.getType()), comparison.getLeft(), comparison.getRight());
}
}
return expression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
import static com.facebook.presto.sql.ExpressionUtils.and;
import static com.facebook.presto.sql.ExpressionUtils.combineConjuncts;
import static com.facebook.presto.sql.ExpressionUtils.combineDisjunctsWithDefault;
import static com.facebook.presto.sql.ExpressionUtils.flipComparison;
import static com.facebook.presto.sql.ExpressionUtils.or;
import static com.facebook.presto.sql.planner.LiteralInterpreter.toExpression;
import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
Expand Down Expand Up @@ -533,7 +534,7 @@ private static Optional<NormalizedSimpleComparison> toNormalizedSimpleComparison
return Optional.of(new NormalizedSimpleComparison((QualifiedNameReference) left, comparison.getType(), new NullableValue(expressionTypes.get(comparison.getRight()), right)));
}
if (right instanceof QualifiedNameReference && !(left instanceof Expression)) {
return Optional.of(new NormalizedSimpleComparison((QualifiedNameReference) right, comparison.getType().flip(), new NullableValue(expressionTypes.get(comparison.getLeft()), left)));
return Optional.of(new NormalizedSimpleComparison((QualifiedNameReference) right, flipComparison(comparison.getType()), new NullableValue(expressionTypes.get(comparison.getLeft()), left)));
}
return Optional.empty();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@

import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.ExpressionUtils.flipComparison;
import static com.facebook.presto.sql.analyzer.SemanticErrorCode.NOT_SUPPORTED;
import static com.facebook.presto.sql.planner.ExpressionInterpreter.evaluateConstantExpression;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.EQUAL;
Expand Down Expand Up @@ -257,7 +258,7 @@ protected RelationPlan visitJoin(Join node, Void context)
else if (firstDependencies.stream().allMatch(right.canResolvePredicate()) && secondDependencies.stream().allMatch(left.canResolvePredicate())) {
leftExpression = comparison.getRight();
rightExpression = comparison.getLeft();
comparisonType = comparisonType.flip();
comparisonType = flipComparison(comparisonType);
}
else {
// must have a complex expression that involves both tuples on one side of the comparison expression (e.g., coalesce(left.x, right.x) = 1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.facebook.presto.metadata.Metadata;
import com.facebook.presto.spi.type.Type;
import com.facebook.presto.sql.parser.SqlParser;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.ExpressionInterpreter;
import com.facebook.presto.sql.planner.LiteralInterpreter;
import com.facebook.presto.sql.planner.NoOpSymbolResolver;
Expand All @@ -30,36 +29,19 @@
import com.facebook.presto.sql.planner.plan.SimplePlanRewriter;
import com.facebook.presto.sql.planner.plan.TableScanNode;
import com.facebook.presto.sql.planner.plan.ValuesNode;
import com.facebook.presto.sql.tree.ComparisonExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.facebook.presto.sql.tree.LogicalBinaryExpression;
import com.facebook.presto.sql.tree.NotExpression;
import com.facebook.presto.sql.tree.NullLiteral;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.Collection;
import java.util.HashSet;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static com.facebook.presto.sql.ExpressionUtils.combinePredicates;
import static com.facebook.presto.sql.ExpressionUtils.extractPredicates;
import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.tree.BooleanLiteral.FALSE_LITERAL;
import static com.facebook.presto.sql.tree.BooleanLiteral.TRUE_LITERAL;
import static com.facebook.presto.sql.tree.ComparisonExpression.Type.IS_DISTINCT_FROM;
import static com.facebook.presto.sql.tree.LogicalBinaryExpression.Type.AND;
import static java.util.Collections.emptySet;
import static java.util.Objects.requireNonNull;
import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet;

public class SimplifyExpressions
extends PlanOptimizer
Expand Down Expand Up @@ -144,147 +126,11 @@ public PlanNode visitTableScan(TableScanNode node, RewriteContext<Void> context)
originalConstraint);
}

private Expression simplifyExpression(Expression expression)
private Expression simplifyExpression(Expression input)
{
expression = ExpressionTreeRewriter.rewriteWith(new PushDownNegationsExpressionRewriter(), expression);
expression = ExpressionTreeRewriter.rewriteWith(new ExtractCommonPredicatesExpressionRewriter(), expression, NodeContext.ROOT_NODE);
IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, expression);
ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(expression, metadata, session, expressionTypes);
return LiteralInterpreter.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(expression));
}
}

private static class PushDownNegationsExpressionRewriter
extends ExpressionRewriter<Void>
{
@Override
public Expression rewriteNotExpression(NotExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
if (node.getValue() instanceof LogicalBinaryExpression) {
LogicalBinaryExpression child = (LogicalBinaryExpression) node.getValue();
List<Expression> predicates = extractPredicates(child);
List<Expression> negatedPredicates = predicates.stream()
.map(predicate -> treeRewriter.rewrite((Expression) new NotExpression(predicate), context))
.collect(toList());
return combinePredicates(child.getType().flip(), negatedPredicates);
}
else if (node.getValue() instanceof ComparisonExpression && ((ComparisonExpression) node.getValue()).getType() != IS_DISTINCT_FROM) {
ComparisonExpression child = (ComparisonExpression) node.getValue();
return new ComparisonExpression(
child.getType().negate(),
treeRewriter.rewrite(child.getLeft(), context),
treeRewriter.rewrite(child.getRight(), context));
}
else if (node.getValue() instanceof NotExpression) {
NotExpression child = (NotExpression) node.getValue();
return treeRewriter.rewrite(child.getValue(), context);
}

return new NotExpression(treeRewriter.rewrite(node.getValue(), context));
}
}

private enum NodeContext
{
ROOT_NODE,
NOT_ROOT_NODE;

boolean isRootNode()
{
return this == ROOT_NODE;
}
}

private static class ExtractCommonPredicatesExpressionRewriter
extends ExpressionRewriter<NodeContext>
{
@Override
public Expression rewriteExpression(Expression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter)
{
if (context.isRootNode()) {
return treeRewriter.rewrite(node, NodeContext.NOT_ROOT_NODE);
}

return null;
}

@Override
public Expression rewriteLogicalBinaryExpression(LogicalBinaryExpression node, NodeContext context, ExpressionTreeRewriter<NodeContext> treeRewriter)
{
List<Expression> predicates = extractPredicates(node.getType(), node).stream()
.map(expression -> treeRewriter.rewrite(expression, NodeContext.NOT_ROOT_NODE))
.collect(toList());

List<List<Expression>> subPredicates = getSubPredicates(predicates);
List<Expression> leafPredicates = getLeafPredicates(predicates);

List<Set<Expression>> deterministicSubPredicates = subPredicates.stream()
.map(this::filterDeterministicPredicates)
.collect(toList());

Set<Expression> commonPredicates = new HashSet<>(deterministicSubPredicates.stream()
.reduce(Sets::intersection)
.orElse(emptySet()));

List<List<Expression>> uncorrelatedSubPredicates = subPredicates.stream()
.map(predicateList -> removeAll(predicateList, commonPredicates))
.filter(predicateList -> !predicateList.isEmpty())
.collect(toList());

if (commonPredicates.isEmpty() || uncorrelatedSubPredicates.isEmpty()) {
return combinePredicates(node.getType(), predicates);
}

// Do not simplify top level conjuncts if it would result in top level disjuncts.
// Conjuncts are easier to process when pushing down predicates.
if (context.isRootNode() && node.getType() == AND && leafPredicates.isEmpty()) {
return combinePredicates(node.getType(), predicates);
}

LogicalBinaryExpression.Type flippedNodeType = node.getType().flip();

List<Expression> uncorrelatedPredicates = uncorrelatedSubPredicates.stream()
.map(predicate -> combinePredicates(flippedNodeType, predicate))
.collect(toList());

Expression result = combinePredicates(flippedNodeType,
combinePredicates(flippedNodeType, commonPredicates),
combinePredicates(node.getType(), uncorrelatedPredicates));

if (!leafPredicates.isEmpty()) {
result = combinePredicates(node.getType(), result, combinePredicates(node.getType(), leafPredicates));
}

return treeRewriter.rewrite(result, context);
}

private List<List<Expression>> getSubPredicates(List<Expression> predicates)
{
return predicates.stream()
.filter(predicate -> predicate instanceof LogicalBinaryExpression)
.map(predicate -> extractPredicates((LogicalBinaryExpression) predicate))
.collect(toList());
}

private List<Expression> getLeafPredicates(List<Expression> predicates)
{
return predicates.stream()
.filter(predicate -> !(predicate instanceof LogicalBinaryExpression))
.collect(toList());
}

private Set<Expression> filterDeterministicPredicates(List<Expression> predicates)
{
return predicates.stream()
.filter(DeterminismEvaluator::isDeterministic)
.collect(toSet());
}

private static <T> List<T> removeAll(Collection<T> collection, Collection<T> elementsToRemove)
{
return collection.stream()
.filter(element -> !elementsToRemove.contains(element))
.collect(toList());
IdentityHashMap<Expression, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, types, input);
ExpressionInterpreter interpreter = ExpressionInterpreter.expressionOptimizer(input, metadata, session, expressionTypes);
return LiteralInterpreter.toExpression(interpreter.optimize(NoOpSymbolResolver.INSTANCE), expressionTypes.get(input));
}
}
}
Loading

0 comments on commit 431a91e

Please sign in to comment.