Skip to content

Commit

Permalink
QL: Simplify arithmetic operations in binary comps (elastic#66022)
Browse files Browse the repository at this point in the history
* Simplify arithmetic operations in binary comps

This commit adds an optimizer rule to simplify the arithmetic operations
in binary comparison expressions, which in turn will allow for further
expression compounding by the optimiser.

Only the negation and plus, minus, multiplication and division are
currently considered and only when two of the operands are a literal.

For instance `((a + 1) / 2 - 3) * 4 >= 14` becomes `a >= 12`.
  • Loading branch information
bpintea authored Feb 1, 2021
1 parent 480795c commit f5c2982
Show file tree
Hide file tree
Showing 33 changed files with 881 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.OptimizerRule;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PropagateEquals;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.PruneLiteralsInOrderBy;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SimplifyComparisonsArithmetics;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceRegexMatch;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.ReplaceSurrogateFunction;
import org.elasticsearch.xpack.ql.optimizer.OptimizerRules.SetAsOptimized;
Expand Down Expand Up @@ -87,6 +88,7 @@ protected Iterable<RuleExecutor<LogicalPlan>.Batch> batches() {
new CombineBinaryComparisons(),
new CombineDisjunctionsToIn(),
new PushDownAndCombineFilters(),
new SimplifyComparisonsArithmetics(DataTypes::areCompatible),
// prune/elimination
new PruneFilters(),
new PruneLiteralsInOrderBy(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.elasticsearch.xpack.ql.expression.predicate.nulls.IsNull;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.Equals;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThan;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.GreaterThanOrEqual;
import org.elasticsearch.xpack.ql.expression.predicate.operator.comparison.LessThan;
import org.elasticsearch.xpack.ql.expression.predicate.regex.Like;
import org.elasticsearch.xpack.ql.index.EsIndex;
Expand Down Expand Up @@ -665,6 +666,28 @@ public void testDifferentKeyFromDisjunction() {
assertEquals(filter, filterCondition(child2.children().get(0)));
}

// ((a + 1) - 3) * 4 >= 16 -> a >= 6.
public void testReduceBinaryComparisons() {
LogicalPlan plan = accept("foo where ((pid + 1) - 3) * 4 >= 16");
assertNotNull(plan);
List<LogicalPlan> filters = plan.collectFirstChildren(x -> x instanceof Filter);
assertNotNull(filters);
assertEquals(1, filters.size());
assertTrue(filters.get(0) instanceof Filter);
Filter filter = (Filter) filters.get(0);

assertTrue(filter.condition() instanceof And);
And and = (And) filter.condition();
assertTrue(and.right() instanceof GreaterThanOrEqual);
GreaterThanOrEqual gte = (GreaterThanOrEqual) and.right();

assertTrue(gte.left() instanceof FieldAttribute);
assertEquals("pid", ((FieldAttribute) gte.left()).name());

assertTrue(gte.right() instanceof Literal);
assertEquals(6, ((Literal) gte.right()).value());
}

private static Attribute timestamp() {
return new FieldAttribute(EMPTY, "test", new EsField("field", INTEGER, emptyMap(), true));
}
Expand Down
32 changes: 16 additions & 16 deletions x-pack/plugin/eql/src/test/resources/queryfolder_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -528,35 +528,35 @@ process where wildcard~(process_path, "*\\red_ttp\\wininit.*", "*\\abc\\*", "*de


addOperator
process where serial_event_id + 2 == 41
process where serial_event_id + 2 == -2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))",
"params":{"v0":"serial_event_id","v1":2,"v2":41}
"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647}
;

addOperatorReversed
process where 2 + serial_event_id == 41
process where 2 + serial_event_id == -2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))",
"params":{"v0":"serial_event_id","v1":2,"v2":41}
"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647}
;

addFunction
process where add(serial_event_id, 2) == 41
process where add(serial_event_id, 2) == -2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))",
"params":{"v0":"serial_event_id","v1":2,"v2":41}
"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647}
;

addFunctionReversed
process where add(2, serial_event_id) == 41
process where add(2, serial_event_id) == -2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.add(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))",
"params":{"v0":"serial_event_id","v1":2,"v2":41}
"params":{"v0":"serial_event_id","v1":2,"v2":-2147483647}
;

divideOperator
Expand Down Expand Up @@ -656,35 +656,35 @@ InternalQlScriptUtils.mul(InternalQlScriptUtils.docValue(doc,params.v0),params.v
;

subtractOperator
process where serial_event_id - 2 == 41
process where serial_event_id - 2 == 2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.sub(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))",
"params":{"v0":"serial_event_id","v1":2,"v2":41}
"params":{"v0":"serial_event_id","v1":2,"v2":2147483647}
;

subtractOperatorReversed
process where 43 - serial_event_id == 41
process where 43 - serial_event_id == -2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.sub(params.v0,InternalQlScriptUtils.docValue(doc,params.v1)),params.v2))",
"params":{"v0":43,"v1":"serial_event_id","v2":41}
"params":{"v0":43,"v1":"serial_event_id","v2":-2147483647}
;

subtractFunction
process where subtract(serial_event_id, 2) == 41
process where subtract(serial_event_id, 2) == 2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.sub(InternalQlScriptUtils.docValue(doc,params.v0),params.v1),params.v2))",
"params":{"v0":"serial_event_id","v1":2,"v2":41}
"params":{"v0":"serial_event_id","v1":2,"v2":2147483647}
;

subtractFunctionReversed
process where subtract(43, serial_event_id) == 41
process where subtract(43, serial_event_id) == -2147483647
;
"script":{"source":"InternalQlScriptUtils.nullSafeFilter(InternalQlScriptUtils.eq(
InternalQlScriptUtils.sub(params.v0,InternalQlScriptUtils.docValue(doc,params.v1)),params.v2))",
"params":{"v0":43,"v1":"serial_event_id","v2":41}
"params":{"v0":43,"v1":"serial_event_id","v2":-2147483647}
;

eventQueryDefaultLimit
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/**
* Addition function ({@code a + b}).
*/
public class Add extends DateTimeArithmeticOperation {
public class Add extends DateTimeArithmeticOperation implements BinaryComparisonInversible {
public Add(Source source, Expression left, Expression right) {
super(source, left, right, DefaultBinaryArithmeticOperation.ADD);
}
Expand All @@ -30,4 +30,9 @@ protected Add replaceChildren(Expression left, Expression right) {
public Add swapLeftAndRight() {
return new Add(source(), right(), left());
}

@Override
public ArithmeticOperationFactory binaryComparisonInverse() {
return Sub::new;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License;
* you may not use this file except in compliance with the Elastic License.
*/

package org.elasticsearch.xpack.ql.expression.predicate.operator.arithmetic;

import org.elasticsearch.xpack.ql.expression.Expression;
import org.elasticsearch.xpack.ql.tree.Source;

/*
* Factory interface for arithmetic operations that have an inverse in reference to a binary comparison.
* For instance the division is multiplication's inverse, substitution addition's, log exponentiation's a.s.o.
* Not all operations - like modulo - are invertible.
*/
public interface BinaryComparisonInversible {

interface ArithmeticOperationFactory {
ArithmeticOperation create(Source source, Expression left, Expression right);
}

ArithmeticOperationFactory binaryComparisonInverse();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
/**
* Division function ({@code a / b}).
*/
public class Div extends ArithmeticOperation {
public class Div extends ArithmeticOperation implements BinaryComparisonInversible {

public Div(Source source, Expression left, Expression right) {
super(source, left, right, DefaultBinaryArithmeticOperation.DIV);
Expand All @@ -34,4 +34,9 @@ protected Div replaceChildren(Expression newLeft, Expression newRight) {
public DataType dataType() {
return DataTypeConverter.commonType(left().dataType(), right().dataType());
}

@Override
public ArithmeticOperationFactory binaryComparisonInverse() {
return Mul::new;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
/**
* Multiplication function ({@code a * b}).
*/
public class Mul extends ArithmeticOperation {
public class Mul extends ArithmeticOperation implements BinaryComparisonInversible {

public Mul(Source source, Expression left, Expression right) {
super(source, left, right, DefaultBinaryArithmeticOperation.MUL);
Expand Down Expand Up @@ -52,4 +52,9 @@ protected Mul replaceChildren(Expression newLeft, Expression newRight) {
public Mul swapLeftAndRight() {
return new Mul(source(), right(), left());
}

@Override
public ArithmeticOperationFactory binaryComparisonInverse() {
return Div::new;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/**
* Subtraction function ({@code a - b}).
*/
public class Sub extends DateTimeArithmeticOperation {
public class Sub extends DateTimeArithmeticOperation implements BinaryComparisonInversible {

public Sub(Source source, Expression left, Expression right) {
super(source, left, right, DefaultBinaryArithmeticOperation.SUB);
Expand All @@ -27,4 +27,9 @@ protected NodeInfo<Sub> info() {
protected Sub replaceChildren(Expression newLeft, Expression newRight) {
return new Sub(source(), newLeft, newRight);
}

@Override
public ArithmeticOperationFactory binaryComparisonInverse() {
return Add::new;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,11 @@ protected Pipe makePipe() {
public static Integer compare(Object left, Object right) {
return Comparisons.compare(left, right);
}

/**
* Reverses the direction of this comparison on the comparison axis.
* Some operations like Greater/LessThan/OrEqual will behave as if the operands of a numerical comparison get multiplied with a
* negative number. Others like Not/Equal can be immutable to this operation.
*/
public abstract BinaryComparison reverse();
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,9 @@ public Equals swapLeftAndRight() {
public BinaryComparison negate() {
return new NotEquals(source(), left(), right(), zoneId());
}

@Override
public BinaryComparison reverse() {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,9 @@ public LessThan swapLeftAndRight() {
public LessThanOrEqual negate() {
return new LessThanOrEqual(source(), left(), right(), zoneId());
}

@Override
public BinaryComparison reverse() {
return new LessThan(source(), left(), right(), zoneId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,9 @@ public LessThanOrEqual swapLeftAndRight() {
public LessThan negate() {
return new LessThan(source(), left(), right(), zoneId());
}

@Override
public BinaryComparison reverse() {
return new LessThanOrEqual(source(), left(), right(), zoneId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,9 @@ public GreaterThan swapLeftAndRight() {
public GreaterThanOrEqual negate() {
return new GreaterThanOrEqual(source(), left(), right(), zoneId());
}

@Override
public BinaryComparison reverse() {
return new GreaterThan(source(), left(), right(), zoneId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,9 @@ public GreaterThanOrEqual swapLeftAndRight() {
public GreaterThan negate() {
return new GreaterThan(source(), left(), right(), zoneId());
}

@Override
public BinaryComparison reverse() {
return new GreaterThanOrEqual(source(), left(), right(), zoneId());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,9 @@ public NotEquals swapLeftAndRight() {
public BinaryComparison negate() {
return new Equals(source(), left(), right(), zoneId());
}

@Override
public BinaryComparison reverse() {
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,9 @@ public NullEquals swapLeftAndRight() {
public Nullability nullable() {
return Nullability.FALSE;
}

@Override
public BinaryComparison reverse() {
return this;
}
}
Loading

0 comments on commit f5c2982

Please sign in to comment.