Skip to content

Commit

Permalink
planner/cascades: implement ImplementationRule for HashAggregation (p…
Browse files Browse the repository at this point in the history
  • Loading branch information
francis0407 authored and sre-bot committed Oct 29, 2019
1 parent 17379cd commit 536a9c6
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 17 deletions.
27 changes: 27 additions & 0 deletions planner/cascades/implementation_rules.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ var defaultImplementationMap = map[memo.Operand][]ImplementationRule{
memo.OperandSort: {
&ImplSort{},
},
memo.OperandAggregation: {
&ImplHashAgg{},
},
}

// ImplTableDual implements LogicalTableDual as PhysicalTableDual.
Expand Down Expand Up @@ -221,3 +224,27 @@ func (r *ImplSort) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalP
)
return impl.NewSortImpl(ps), nil
}

// ImplHashAgg is the implementation rule that turns a LogicalAggregation
// into a PhysicalHashAgg.
type ImplHashAgg struct{}

// Match implements the Match method of the ImplementationRule interface.
// The rule matches only when prop reports itself as empty.
func (r *ImplHashAgg) Match(expr *memo.GroupExpr, prop *property.PhysicalProperty) (matched bool) {
	// TODO: deal with the hints when we have implemented StreamAgg.
	matched = prop.IsEmpty()
	return matched
}

// OnImplement implements the OnImplement method of the ImplementationRule
// interface. It builds a PhysicalHashAgg from the LogicalAggregation held by
// expr and wraps it in a TiDBHashAggImpl.
func (r *ImplHashAgg) OnImplement(expr *memo.GroupExpr, reqProp *property.PhysicalProperty) (memo.Implementation, error) {
	logicalAgg := expr.ExprNode.(*plannercore.LogicalAggregation)
	// Scale the group's stats by the expected count carried in reqProp.
	scaledStats := expr.Group.Prop.Stats.ScaleByExpectCnt(reqProp.ExpectedCnt)
	aggProp := &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64}

	hashAgg := plannercore.NewPhysicalHashAgg(logicalAgg, scaledStats, aggProp)
	hashAgg.SetSchema(expr.Group.Prop.Schema.Clone())
	// TODO: Implement TiKVHashAgg
	return impl.NewTiDBHashAggImpl(hashAgg), nil
}
21 changes: 21 additions & 0 deletions planner/cascades/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,24 @@ func (s *testIntegrationSuite) TestSort(c *C) {
tk.MustQuery(sql).Check(testkit.Rows(output[i].Result...))
}
}

// TestAggregation runs the aggregation cases from the suite's test data with
// the cascades planner enabled and checks each query's rows against the
// recorded output.
func (s *testIntegrationSuite) TestAggregation(c *C) {
	tk := testkit.NewTestKitWithInit(c, s.store)
	setupStmts := []string{
		"drop table if exists t",
		"create table t(a int primary key, b int)",
		"insert into t values (1, 11), (4, 44), (2, 22), (3, 33)",
		"set session tidb_enable_cascades_planner = 1",
	}
	for _, stmt := range setupStmts {
		tk.MustExec(stmt)
	}

	var (
		input  []string
		output []struct {
			SQL    string
			Result []string
		}
	)
	s.testData.GetTestCases(c, &input, &output)
	for idx := range input {
		query := input[idx]
		// When recording, refresh the stored SQL and result for this case.
		recordCase := func() {
			output[idx].SQL = query
			output[idx].Result = s.testData.ConvertRowsToStrings(tk.MustQuery(query).Rows())
		}
		s.testData.OnRecord(recordCase)
		tk.MustQuery(query).Check(testkit.Rows(output[idx].Result...))
	}
}
17 changes: 17 additions & 0 deletions planner/cascades/testdata/integration_suite_in.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,22 @@
"explain select b from t order by b, a+b, a",
"select b from t order by b, a+b, a"
]
},
{
"name": "TestAggregation",
"cases": [
"explain select sum(a) from t",
"select a from t",
"explain select max(a), min(b) from t",
"select max(a), min(b) from t",
"explain select b, avg(a) from t group by b order by b",
"select b, avg(a) from t group by b order by b",
"explain select b, avg(a) from t group by b having sum(a) > 1 order by b",
"select b, avg(a) from t group by b having sum(a) > 1 order by b",
"explain select max(a+b) from t",
"select max(a+b) from t",
"explain select sum(a) from t group by a, a+b order by a",
"select sum(a) from t group by a, a+b order by a"
]
}
]
117 changes: 117 additions & 0 deletions planner/cascades/testdata/integration_suite_out.json
Original file line number Diff line number Diff line change
Expand Up @@ -158,5 +158,122 @@
]
}
]
},
{
"Name": "TestAggregation",
"Cases": [
{
"SQL": "explain select sum(a) from t",
"Result": [
"Projection_6 1.00 root Column#3",
"└─HashAgg_7 1.00 root funcs:sum(Column#5)",
" └─Projection_10 10000.00 root cast(Column#1)",
" └─TableReader_8 10000.00 root data:TableScan_9",
" └─TableScan_9 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo"
]
},
{
"SQL": "select a from t",
"Result": [
"1",
"2",
"3",
"4"
]
},
{
"SQL": "explain select max(a), min(b) from t",
"Result": [
"Projection_6 1.00 root Column#3, Column#4",
"└─HashAgg_7 1.00 root funcs:max(Column#1), min(Column#2)",
" └─TableReader_8 10000.00 root data:TableScan_9",
" └─TableScan_9 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo"
]
},
{
"SQL": "select max(a), min(b) from t",
"Result": [
"4 11"
]
},
{
"SQL": "explain select b, avg(a) from t group by b order by b",
"Result": [
"Projection_8 8000.00 root Column#2, Column#3",
"└─Sort_14 8000.00 root Column#2:asc",
" └─HashAgg_10 8000.00 root group by:Column#8, funcs:avg(Column#6), firstrow(Column#7)",
" └─Projection_13 10000.00 root cast(Column#1), Column#2, Column#2",
" └─TableReader_11 10000.00 root data:TableScan_12",
" └─TableScan_12 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo"
]
},
{
"SQL": "select b, avg(a) from t group by b order by b",
"Result": [
"11 1.0000",
"22 2.0000",
"33 3.0000",
"44 4.0000"
]
},
{
"SQL": "explain select b, avg(a) from t group by b having sum(a) > 1 order by b",
"Result": [
"Projection_9 6400.00 root Column#5, Column#6",
"└─Sort_24 6400.00 root Column#5:asc",
" └─Selection_23 6400.00 root gt(Column#8, 1)",
" └─Projection_20 8000.00 root Column#2, Column#3, Column#4",
" └─HashAgg_14 8000.00 root group by:Column#14, funcs:avg(Column#11), sum(Column#12), firstrow(Column#13)",
" └─Projection_17 10000.00 root cast(Column#1), cast(Column#1), Column#2, Column#2",
" └─TableReader_15 10000.00 root data:TableScan_16",
" └─TableScan_16 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo"
]
},
{
"SQL": "select b, avg(a) from t group by b having sum(a) > 1 order by b",
"Result": [
"22 2.0000",
"33 3.0000",
"44 4.0000"
]
},
{
"SQL": "explain select max(a+b) from t",
"Result": [
"Projection_6 1.00 root Column#3",
"└─HashAgg_7 1.00 root funcs:max(Column#5)",
" └─Projection_10 10000.00 root plus(Column#1, Column#2)",
" └─TableReader_8 10000.00 root data:TableScan_9",
" └─TableScan_9 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo"
]
},
{
"SQL": "select max(a+b) from t",
"Result": [
"48"
]
},
{
"SQL": "explain select sum(a) from t group by a, a+b order by a",
"Result": [
"Projection_8 8000.00 root Column#4",
"└─Projection_10 8000.00 root Column#3, Column#1",
" └─Sort_16 8000.00 root Column#1:asc",
" └─HashAgg_12 8000.00 root group by:Column#10, Column#9, funcs:sum(Column#7), firstrow(Column#8)",
" └─Projection_15 10000.00 root cast(Column#1), Column#1, Column#1, plus(Column#1, Column#2)",
" └─TableReader_13 10000.00 root data:TableScan_14",
" └─TableScan_14 10000.00 cop[tikv] table:t, range:[-inf,+inf], keep order:false, stats:pseudo"
]
},
{
"SQL": "select sum(a) from t group by a, a+b order by a",
"Result": [
"1",
"2",
"3",
"4"
]
}
]
}
]
30 changes: 17 additions & 13 deletions planner/core/exhaust_physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -1467,16 +1467,28 @@ func (la *LogicalAggregation) getHashAggs(prop *property.PhysicalProperty) []Phy
taskTypes = append(taskTypes, property.RootTaskType)
}
for _, taskTp := range taskTypes {
agg := basePhysicalAgg{
GroupByItems: la.GroupByItems,
AggFuncs: la.AggFuncs,
}.initForHash(la.ctx, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), la.blockOffset, &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, TaskTp: taskTp})
agg := NewPhysicalHashAgg(la, la.stats.ScaleByExpectCnt(prop.ExpectedCnt), &property.PhysicalProperty{ExpectedCnt: math.MaxFloat64, TaskTp: taskTp})
agg.SetSchema(la.schema.Clone())
hashAggs = append(hashAggs, agg)
}
return hashAggs
}

// ResetHintIfConflicted reports whether the aggregation hints prefer hash
// aggregation and/or stream aggregation. When both are set the hints conflict:
// a warning is appended to the statement context, aggHints.preferAggType is
// cleared, and both results are false.
func (la *LogicalAggregation) ResetHintIfConflicted() (preferHash bool, preferStream bool) {
	hint := la.aggHints.preferAggType
	preferHash = (hint & preferHashAgg) > 0
	preferStream = (hint & preferStreamAgg) > 0
	if !preferHash || !preferStream {
		// At most one preference is set, so there is nothing to reset.
		return preferHash, preferStream
	}

	la.ctx.GetSessionVars().StmtCtx.AppendWarning(
		ErrInternal.GenWithStack("Optimizer aggregation hints are conflicted"),
	)
	la.aggHints.preferAggType = 0
	return false, false
}

func (la *LogicalAggregation) exhaustPhysicalPlans(prop *property.PhysicalProperty) []PhysicalPlan {
if la.aggHints.preferAggToCop {
if !la.canPushToCop() {
Expand All @@ -1487,15 +1499,7 @@ func (la *LogicalAggregation) exhaustPhysicalPlans(prop *property.PhysicalProper
}
}

preferHash := (la.aggHints.preferAggType & preferHashAgg) > 0
preferStream := (la.aggHints.preferAggType & preferStreamAgg) > 0
if preferHash && preferStream {
errMsg := "Optimizer aggregation hints are conflicted"
warning := ErrInternal.GenWithStack(errMsg)
la.ctx.GetSessionVars().StmtCtx.AppendWarning(warning)
la.aggHints.preferAggType = 0
preferHash, preferStream = false, false
}
preferHash, preferStream := la.ResetHintIfConflicted()

hashAggs := la.getHashAggs(prop)
if hashAggs != nil && preferHash {
Expand Down
9 changes: 9 additions & 0 deletions planner/core/physical_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,15 @@ type PhysicalHashAgg struct {
basePhysicalAgg
}

// NewPhysicalHashAgg creates a new PhysicalHashAgg from a LogicalAggregation,
// reusing its GroupByItems and AggFuncs and initializing it with the given
// stats and physical property.
func NewPhysicalHashAgg(la *LogicalAggregation, newStats *property.StatsInfo, prop *property.PhysicalProperty) *PhysicalHashAgg {
	base := basePhysicalAgg{
		GroupByItems: la.GroupByItems,
		AggFuncs:     la.AggFuncs,
	}
	return base.initForHash(la.ctx, newStats, la.blockOffset, prop)
}

// PhysicalStreamAgg is stream operator of aggregate.
type PhysicalStreamAgg struct {
basePhysicalAgg
Expand Down
8 changes: 4 additions & 4 deletions planner/core/rule_inject_extra_projection.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ func (pe *projInjector) inject(plan PhysicalPlan) PhysicalPlan {

switch p := plan.(type) {
case *PhysicalHashAgg:
plan = injectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems)
plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems)
case *PhysicalStreamAgg:
plan = injectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems)
plan = InjectProjBelowAgg(plan, p.AggFuncs, p.GroupByItems)
case *PhysicalSort:
plan = InjectProjBelowSort(p, p.ByItems)
case *PhysicalTopN:
Expand All @@ -67,10 +67,10 @@ func wrapCastForAggFuncs(sctx sessionctx.Context, aggFuncs []*aggregation.AggFun
}
}

// injectProjBelowAgg injects a ProjOperator below AggOperator. If all the args
// InjectProjBelowAgg injects a ProjOperator below AggOperator. If all the args
of `aggFuncs`, and all the items of `groupByItems` are columns or constants,
// we do not need to build the `proj`.
func injectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) PhysicalPlan {
func InjectProjBelowAgg(aggPlan PhysicalPlan, aggFuncs []*aggregation.AggFuncDesc, groupByItems []expression.Expression) PhysicalPlan {
hasScalarFunc := false

wrapCastForAggFuncs(aggPlan.SCtx(), aggFuncs)
Expand Down
27 changes: 27 additions & 0 deletions planner/implementation/simple_plans.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,30 @@ func (sel *TiKVSelectionImpl) CalcCost(outCount float64, children ...memo.Implem
func NewTiKVSelectionImpl(sel *plannercore.PhysicalSelection) *TiKVSelectionImpl {
return &TiKVSelectionImpl{baseImpl{plan: sel}}
}

// TiDBHashAggImpl is the implementation of PhysicalHashAgg in TiDB layer.
// It embeds baseImpl, which carries the wrapped plan and its computed cost.
type TiDBHashAggImpl struct {
baseImpl
}

// CalcCost implements the CalcCost method of the Implementation interface.
// The result is the hash aggregation's own cost, computed from the row count
// of the first child's plan, plus that child's cost. It is also stored in
// agg.cost.
func (agg *TiDBHashAggImpl) CalcCost(outCount float64, children ...memo.Implementation) float64 {
	hashAgg := agg.plan.(*plannercore.PhysicalHashAgg)
	child := children[0]
	inputRows := child.GetPlan().Stats().RowCount
	agg.cost = hashAgg.GetCost(inputRows, true) + child.GetCost()
	return agg.cost
}

// AttachChildren implements the AttachChildren method of the Implementation
// interface. It sets the first child's plan as the child of the hash
// aggregation and returns the receiver.
func (agg *TiDBHashAggImpl) AttachChildren(children ...memo.Implementation) memo.Implementation {
	hashAgg := agg.plan.(*plannercore.PhysicalHashAgg)
	childPlan := children[0].GetPlan()
	hashAgg.SetChildren(childPlan)
	// An extra projection is injected when the AggFuncs or GroupByItems
	// contain a ScalarFunction.
	plannercore.InjectProjBelowAgg(hashAgg, hashAgg.AggFuncs, hashAgg.GroupByItems)
	return agg
}

// NewTiDBHashAggImpl creates a new TiDBHashAggImpl that wraps the given
// PhysicalHashAgg.
func NewTiDBHashAggImpl(agg *plannercore.PhysicalHashAgg) *TiDBHashAggImpl {
	base := baseImpl{plan: agg}
	return &TiDBHashAggImpl{baseImpl: base}
}

0 comments on commit 536a9c6

Please sign in to comment.