Skip to content

Commit

Permalink
Compute case when operator and if function column statistics (StarRoc…
Browse files Browse the repository at this point in the history
  • Loading branch information
Youngwb authored Nov 23, 2021
1 parent 9ab0132 commit 45f89a2
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ private enum StatisticType {
ESTIMATE
}

// Placeholder statistic for a column whose statistics could not be read from the statistics
// storage, or whose actual statistics cannot be computed yet. Its value range is unbounded
// (NEGATIVE_INFINITY to POSITIVE_INFINITY) and it is tagged StatisticType.UNKNOWN, which is
// the tag isUnknown() checks for.
private static final ColumnStatistic
UNKNOWN = new ColumnStatistic(NEGATIVE_INFINITY, POSITIVE_INFINITY, 0, 1, 1, StatisticType.UNKNOWN);

Expand Down Expand Up @@ -76,6 +78,11 @@ public boolean isUnknown() {
return this.type == StatisticType.UNKNOWN;
}

// Returns true when the value range is open-ended on at least one side, i.e. the minimum is
// negative infinity or the maximum is positive infinity.
public boolean isInfiniteRange() {
    if (minValue == NEGATIVE_INFINITY) {
        return true;
    }
    return maxValue == POSITIVE_INFINITY;
}

// TODO(ywb): remove this after user can dump statistics with type
public boolean isUnknownValue() {
return this.minValue == NEGATIVE_INFINITY && this.maxValue == POSITIVE_INFINITY && this.nullsFraction == 0 &&
this.averageRowSize == 1 && this.distinctValuesCount == 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
package com.starrocks.sql.optimizer.statistics;

import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.starrocks.catalog.FunctionSet;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
import com.starrocks.sql.optimizer.operator.scalar.ScalarOperator;
Expand Down Expand Up @@ -58,11 +60,31 @@ public ColumnStatistic visitConstant(ConstantOperator operator, Void context) {
if (value.isPresent()) {
return new ColumnStatistic(value.getAsDouble(), value.getAsDouble(), 0,
operator.getType().getSlotSize(), 1);
} else if (operator.getType().isStringType()) {
return new ColumnStatistic(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0, 1, 1);
} else {
return ColumnStatistic.unknown();
}
}

// Estimates the statistics of a CASE WHEN expression from the statistics of its result
// expressions (the THEN clauses and, when present, the ELSE clause).
@Override
public ColumnStatistic visitCaseWhenOperator(CaseWhenOperator caseWhenOperator, Void context) {
    // Step 1: visit every expression the CASE can evaluate to, THEN clauses first (in order),
    // then the ELSE clause if there is one.
    final int thenCount = caseWhenOperator.getWhenClauseSize();
    List<ColumnStatistic> resultStatistics = Lists.newArrayList();
    for (int clause = 0; clause < thenCount; clause++) {
        ColumnStatistic thenStatistic = caseWhenOperator.getThenClause(clause).accept(this, context);
        resultStatistics.add(thenStatistic);
    }
    if (caseWhenOperator.hasElse()) {
        ColumnStatistic elseStatistic = caseWhenOperator.getElseClause().accept(this, context);
        resultStatistics.add(elseStatistic);
    }

    // Step 2: the output can only take values produced by one of those expressions, so the
    // sum of their distinct-value counts is used as the distinct-value count of the CASE.
    double totalDistinctValues = resultStatistics.stream()
            .mapToDouble(statistic -> statistic.getDistinctValuesCount())
            .sum();

    // The value range is left unbounded on both sides.
    return new ColumnStatistic(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0,
            caseWhenOperator.getType().getSlotSize(), totalDistinctValues);
}

@Override
public ColumnStatistic visitCall(CallOperator call, Void context) {
List<ColumnStatistic> childrenColumnStatistics =
Expand All @@ -81,8 +103,7 @@ public ColumnStatistic visitCall(CallOperator call, Void context) {
return binaryExpressionCalculate(call, childrenColumnStatistics.get(0),
childrenColumnStatistics.get(1));
} else {
// TODO: Multiple Arithmetic calculations support later
return childrenColumnStatistics.get(0);
return multiaryExpressionCalculate(call, childrenColumnStatistics);
}
}

Expand Down Expand Up @@ -189,6 +210,19 @@ private ColumnStatistic binaryExpressionCalculate(CallOperator callOperator, Col
}
}

// Estimates the column statistics of a function call from the statistics of its arguments.
// visitCall falls back to this method for calls not handled by its earlier arity-specific
// branches, so the IF case below assumes at least three children are present.
//
// callOperator: the call being estimated; its function name selects the estimation rule.
// childColumnStatisticList: statistics of the call's arguments, in argument order.
// Returns: for IF, a statistic with an unbounded value range whose distinct-value count is the
// sum of the counts of arguments 1 and 2 (the two possible results); for any other function,
// the statistics of the first argument unchanged.
private ColumnStatistic multiaryExpressionCalculate(CallOperator callOperator,
                                                    List<ColumnStatistic> childColumnStatisticList) {
    // Lower-case with Locale.ROOT rather than the JVM default locale: under Turkish/Azeri
    // locales the default-locale overload maps 'I' to a dotless 'ı', so an upper-case
    // function name containing 'I' would silently miss its case label.
    switch (callOperator.getFnName().toLowerCase(java.util.Locale.ROOT)) {
        case FunctionSet.IF:
            // if(cond, a, b) evaluates to either a or b, so only those two arguments
            // contribute distinct values; the condition (argument 0) is ignored.
            double distinctValues = childColumnStatisticList.get(1).getDistinctValuesCount() +
                    childColumnStatisticList.get(2).getDistinctValuesCount();
            return new ColumnStatistic(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY, 0,
                    callOperator.getType().getSlotSize(), distinctValues);
        default:
            return childColumnStatisticList.get(0);
    }
}

// Makes a value safe to divide by: a zero input is replaced with 1.0, anything else is
// returned unchanged.
private double divisorNotZero(double value) {
    if (value == 0) {
        return 1.0;
    }
    return value;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ public Statistics visitBinaryPredicate(BinaryPredicateOperator predicate, Void c
leftChildOpt = leftChild.isColumnRef() ? Optional.of((ColumnRefOperator) leftChild) : Optional.empty();

if (rightChild.isConstant()) {
OptionalDouble constant = rightColumnStatistic.isUnknown() ? OptionalDouble.empty() :
OptionalDouble constant = (rightColumnStatistic.isInfiniteRange()) ? OptionalDouble.empty() :
OptionalDouble.of(rightColumnStatistic.getMaxValue());
return BinaryPredicateStatisticCalculator.estimateColumnToConstantComparison(leftChildOpt,
leftColumnStatistic, predicate, constant, statistics);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ private boolean canGenerateOneStageAgg(GroupExpression childBestExpr) {
childBestExpr.getOp() instanceof PhysicalDistributionOperator) {
// 2. check default column statistics or child output row may not be accurate
if (groupExpression.getGroup().getStatistics().getColumnStatistics().values().stream()
.allMatch(ColumnStatistic::isUnknown) ||
.anyMatch(ColumnStatistic::isUnknown) ||
childBestExpr.getGroup().getStatistics().isTableRowCountMayInaccurate()) {
// 3. check child expr distribution, if it is shuffle or gather without limit, could disable this plan
PhysicalDistributionOperator distributionOperator =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,14 @@

package com.starrocks.sql.optimizer.statistics;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;

import com.google.common.collect.Maps;
import com.starrocks.catalog.Type;
import com.starrocks.sql.optimizer.operator.scalar.BinaryPredicateOperator;
import com.starrocks.sql.optimizer.operator.scalar.CallOperator;
import com.starrocks.sql.optimizer.operator.scalar.CaseWhenOperator;
import com.starrocks.sql.optimizer.operator.scalar.CastOperator;
import com.starrocks.sql.optimizer.operator.scalar.ColumnRefOperator;
import com.starrocks.sql.optimizer.operator.scalar.ConstantOperator;
Expand Down Expand Up @@ -48,7 +52,8 @@ public void testConstant() {

ConstantOperator constantOperator2 = ConstantOperator.createChar("123");
ColumnStatistic columnStatistic2 = ExpressionStatisticCalculator.calculate(constantOperator2, null);
Assert.assertTrue(columnStatistic2.isUnknown());
Assert.assertTrue(columnStatistic2.isInfiniteRange());
Assert.assertEquals(columnStatistic2.getDistinctValuesCount(), 1, 0.001);
}

@Test
Expand Down Expand Up @@ -139,4 +144,24 @@ public void testCastOperator() {
Assert.assertEquals(-100, columnStatistic.getMinValue(), 0.001);
Assert.assertEquals(100, columnStatistic.getMaxValue(), 0.001);
}

// Verifies that the distinct-value estimate of a CASE WHEN expression is the sum of the
// distinct-value counts of its result expressions.
@Test
public void testCaseWhenOperator() {
    // NOTE(review): presumably models
    //   CASE WHEN col = 1 THEN '1' WHEN col = 2 THEN '2' ELSE 'others' END
    // with the list holding alternating when/then operators - confirm against CaseWhenOperator.
    ColumnRefOperator columnRefOperator = new ColumnRefOperator(1, Type.INT, "", true);
    BinaryPredicateOperator whenOperator1 =
            new BinaryPredicateOperator(BinaryPredicateOperator.BinaryType.EQ, columnRefOperator,
                    ConstantOperator.createInt(1));
    ConstantOperator constantOperator1 = ConstantOperator.createChar("1");
    BinaryPredicateOperator whenOperator2 =
            new BinaryPredicateOperator(BinaryPredicateOperator.BinaryType.EQ, columnRefOperator,
                    ConstantOperator.createInt(2));
    ConstantOperator constantOperator2 = ConstantOperator.createChar("2");

    CaseWhenOperator caseWhenOperator =
            new CaseWhenOperator(Type.VARCHAR, null, ConstantOperator.createChar("others", Type.VARCHAR),
                    ImmutableList.of(whenOperator1, constantOperator1, whenOperator2, constantOperator2));
    ColumnStatistic columnStatistic = ExpressionStatisticCalculator.calculate(caseWhenOperator, new Statistics(100,
            Maps.newHashMap()));

    // Three constant results, one distinct value each. JUnit expects (expected, actual, delta);
    // keeping that order makes a failure report the right "expected" value.
    Assert.assertEquals(3, columnStatistic.getDistinctValuesCount(), 0.001);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ public void testJoinOnExpression() throws Exception {
sql = "SELECT COUNT(*) FROM lineitem JOIN orders ON l_orderkey * 2 = o_orderkey + 1 GROUP BY l_shipmode, l_shipinstruct, o_orderdate, o_orderstatus;";
plan = getCostExplain(sql);
Assert.assertTrue(plan.contains("equal join conjunct: [29: multiply, BIGINT, true] = [30: add, BIGINT, true]\n" +
" | cardinality: 600000000"));
" | cardinality: 600000000"));
}

@Test
Expand Down Expand Up @@ -622,6 +622,7 @@ public void testCountFunctionColumnStatistics() throws Exception {
Assert.assertTrue(plan.contains("count-->[0.0, 1000000.0, 0.0, 8.0, 1.0] ESTIMATE"));
}

@Test
public void testGenRuntimeFilterWhenRightJoin() throws Exception {
String sql = "select * from lineitem right anti join [shuffle] part on lineitem.l_partkey = part.p_partkey";
String plan = getVerboseExplain(sql);
Expand All @@ -645,4 +646,43 @@ public void testGenRuntimeFilterWhenRightJoin() throws Exception {
" | build runtime filters:\n" +
" | - filter_id = 0, build_expr = (18: P_PARTKEY), remote = true"));
}

// Checks the statistics reported in the cost explain output for a CASE WHEN expression used
// as a GROUP BY key, for three shapes of the CASE.
@Test
public void testCaseWhenCardinalityEstimate() throws Exception {
    // Two constant THEN results plus a constant ELSE: expected distinct count is 3.
    String withElseSql =
            "select (case when `O_ORDERKEY` = 0 then 'ALGERIA' when `O_ORDERKEY` = 1 then 'ARGENTINA' " +
                    "else 'others' end) a from orders group by 1";
    String withElsePlan = getCostExplain(withElseSql);
    Assert.assertTrue(withElsePlan.contains("cardinality: 3"));
    Assert.assertTrue(withElsePlan.contains("* case-->[-Infinity, Infinity, 0.0, 16.0, 3.0]"));

    // Same CASE without an ELSE clause: expected distinct count is 2.
    String withoutElseSql =
            "select (case when `O_ORDERKEY` = 0 then 'ALGERIA' when `O_ORDERKEY` = 1 then 'ARGENTINA' end) a " +
                    "from orders group by 1";
    String withoutElsePlan = getCostExplain(withoutElseSql);
    Assert.assertTrue(withoutElsePlan.contains("cardinality: 2"));
    Assert.assertTrue(withoutElsePlan.contains("* case-->[-Infinity, Infinity, 0.0, 16.0, 2.0]"));

    // One THEN result is the O_ORDERSTATUS column instead of a constant: expected total is 5.
    String columnBranchSql =
            "select (case when `O_ORDERKEY` = 0 then O_ORDERSTATUS when `O_ORDERKEY` = 1 then 'ARGENTINA' " +
                    "else 'other' end) a from orders group by 1";
    String columnBranchPlan = getCostExplain(columnBranchSql);
    Assert.assertTrue(columnBranchPlan.contains("* case-->[-Infinity, Infinity, 0.0, 16.0, 5.0] ESTIMATE"));
}

// Checks the statistics reported in the cost explain output for if() calls used as a
// GROUP BY key, including nested if() calls, alongside an equivalent CASE WHEN baseline.
@Test
public void testIFFunctionCardinalityEstimate() throws Exception {
    // Baseline: CASE WHEN with one constant THEN result and a constant ELSE.
    String caseSql = "select (case when `O_ORDERKEY` = 0 then 'ALGERIA' else 'others' end) a from orders group by 1";
    String casePlan = getCostExplain(caseSql);
    Assert.assertTrue(casePlan.contains("* case-->[-Infinity, Infinity, 0.0, 16.0, 2.0] ESTIMATE"));

    // Single if() with two constant results: expected distinct count is 2.
    String singleIfSql = "select if(`O_ORDERKEY` = 0, 'ALGERIA', 'others') a from orders group by 1";
    String singleIfPlan = getCostExplain(singleIfSql);
    Assert.assertTrue(singleIfPlan.contains("* if-->[-Infinity, Infinity, 0.0, 16.0, 2.0] ESTIMATE"));

    // if() nested one level deep: expected distinct count is 3.
    String nestedIfSql =
            "select if(`O_ORDERKEY` = 0, 'ALGERIA', if (`O_ORDERKEY` = 1, 'ARGENTINA', 'others')) a from orders group by 1";
    String nestedIfPlan = getCostExplain(nestedIfSql);
    Assert.assertTrue(nestedIfPlan.contains("* if-->[-Infinity, Infinity, 0.0, 16.0, 3.0] ESTIMATE"));

    // if() nested two levels deep: expected distinct count is 4.
    String deeplyNestedIfSql =
            "select if(`O_ORDERKEY` = 0, 'ALGERIA', if (`O_ORDERKEY` = 1, 'ARGENTINA', if(`O_ORDERKEY` = 2, 'BRAZIL', 'Others'))) a from orders group by 1";
    String deeplyNestedIfPlan = getCostExplain(deeplyNestedIfSql);
    Assert.assertTrue(deeplyNestedIfPlan.contains("* if-->[-Infinity, Infinity, 0.0, 16.0, 4.0] ESTIMATE"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ public void testTPCHQ8EnumPlan() throws Exception {
"order by\n" +
" o_year ;";
int planCount = getPlanCount(sql);
Assert.assertEquals(96, planCount);
Assert.assertEquals(48, planCount);
}

@Test
Expand Down

0 comments on commit 45f89a2

Please sign in to comment.