Skip to content

Commit

Permalink
executor: migrate test-infra to testify for executor_required_rows_te…
Browse files Browse the repository at this point in the history
…st.go (pingcap#32680)

close pingcap#28576
  • Loading branch information
feitian124 authored Mar 1, 2022
1 parent d7d6afc commit 886c8a7
Showing 1 changed file with 72 additions and 80 deletions.
152 changes: 72 additions & 80 deletions executor/executor_required_rows_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ import (
"fmt"
"math"
"math/rand"
"testing"
"time"

"github.com/cznic/mathutil"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/aggregation"
"github.com/pingcap/tidb/parser/ast"
Expand All @@ -37,18 +37,10 @@ import (
"github.com/pingcap/tidb/util/disk"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/mock"
"github.com/stretchr/testify/require"
"github.com/tikv/client-go/v2/oracle"
)

var _ = SerialSuites(&testExecSuite{})
var _ = SerialSuites(&testExecSerialSuite{})

type testExecSuite struct {
}

type testExecSerialSuite struct {
}

type requiredRowsDataSource struct {
baseExecutor
totalRows int
Expand Down Expand Up @@ -133,7 +125,7 @@ func (r *requiredRowsDataSource) checkNumNextCalled() error {
return nil
}

func (s *testExecSuite) TestLimitRequiredRows(c *C) {
func TestLimitRequiredRows(t *testing.T) {
maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
testCases := []struct {
totalRows int
Expand Down Expand Up @@ -190,15 +182,15 @@ func (s *testExecSuite) TestLimitRequiredRows(c *C) {
ctx := context.Background()
ds := newRequiredRowsDataSource(sctx, testCase.totalRows, testCase.expectedRowsDS)
exec := buildLimitExec(sctx, ds, testCase.limitOffset, testCase.limitCount)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], sctx.GetSessionVars().MaxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

Expand All @@ -224,7 +216,7 @@ func defaultCtx() sessionctx.Context {
return ctx
}

func (s *testExecSuite) TestSortRequiredRows(c *C) {
func TestSortRequiredRows(t *testing.T) {
maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
testCases := []struct {
totalRows int
Expand Down Expand Up @@ -273,15 +265,15 @@ func (s *testExecSuite) TestSortRequiredRows(c *C) {
byItems = append(byItems, &util.ByItems{Expr: col})
}
exec := buildSortExec(sctx, byItems, ds)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

Expand All @@ -294,7 +286,7 @@ func buildSortExec(sctx sessionctx.Context, byItems []*util.ByItems, src Executo
return &sortExec
}

func (s *testExecSuite) TestTopNRequiredRows(c *C) {
func TestTopNRequiredRows(t *testing.T) {
maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
testCases := []struct {
totalRows int
Expand Down Expand Up @@ -380,15 +372,15 @@ func (s *testExecSuite) TestTopNRequiredRows(c *C) {
byItems = append(byItems, &util.ByItems{Expr: col})
}
exec := buildTopNExec(sctx, testCase.topNOffset, testCase.topNCount, byItems, ds)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

Expand All @@ -404,7 +396,7 @@ func buildTopNExec(ctx sessionctx.Context, offset, count int, byItems []*util.By
}
}

func (s *testExecSuite) TestSelectionRequiredRows(c *C) {
func TestSelectionRequiredRows(t *testing.T) {
gen01 := func() func(valType *types.FieldType) interface{} {
closureCount := 0
return func(valType *types.FieldType) interface{} {
Expand Down Expand Up @@ -469,19 +461,19 @@ func (s *testExecSuite) TestSelectionRequiredRows(c *C) {
Value: types.NewDatum(testCase.filtersOfCol1),
RetType: types.NewFieldType(mysql.TypeTiny),
})
c.Assert(err, IsNil)
require.NoError(t, err)
filters = append(filters, f)
}
exec := buildSelectionExec(sctx, filters, ds)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

Expand All @@ -492,7 +484,7 @@ func buildSelectionExec(ctx sessionctx.Context, filters []expression.Expression,
}
}

func (s *testExecSuite) TestProjectionUnparallelRequiredRows(c *C) {
func TestProjectionUnparallelRequiredRows(t *testing.T) {
maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
testCases := []struct {
totalRows int
Expand Down Expand Up @@ -531,20 +523,20 @@ func (s *testExecSuite) TestProjectionUnparallelRequiredRows(c *C) {
}
}
exec := buildProjectionExec(sctx, exprs, ds, 0)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

func (s *testExecSuite) TestProjectionParallelRequiredRows(c *C) {
c.Skip("not stable because of goroutine schedule")
func TestProjectionParallelRequiredRows(t *testing.T) {
t.Skip("not stable because of goroutine schedule")
maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
testCases := []struct {
totalRows int
Expand Down Expand Up @@ -587,19 +579,19 @@ func (s *testExecSuite) TestProjectionParallelRequiredRows(c *C) {
}
}
exec := buildProjectionExec(sctx, exprs, ds, testCase.numWorkers)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())

// wait projectionInputFetcher blocked on fetching data
// from child in the background.
time.Sleep(time.Millisecond * 25)
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

Expand Down Expand Up @@ -630,7 +622,7 @@ func divGenerator(factor int) func(valType *types.FieldType) interface{} {
}
}

func (s *testExecSuite) TestStreamAggRequiredRows(c *C) {
func TestStreamAggRequiredRows(t *testing.T) {
maxChunkSize := defaultCtx().GetSessionVars().MaxChunkSize
testCases := []struct {
totalRows int
Expand Down Expand Up @@ -674,22 +666,22 @@ func (s *testExecSuite) TestStreamAggRequiredRows(c *C) {
schema := expression.NewSchema(childCols...)
groupBy := []expression.Expression{childCols[1]}
aggFunc, err := aggregation.NewAggFuncDesc(sctx, testCase.aggFunc, []expression.Expression{childCols[0]}, true)
c.Assert(err, IsNil)
require.NoError(t, err)
aggFuncs := []*aggregation.AggFuncDesc{aggFunc}
exec := buildStreamAggExecutor(sctx, ds, schema, aggFuncs, groupBy, 1, true)
c.Assert(exec.Open(ctx), IsNil)
require.NoError(t, exec.Open(ctx))
chk := newFirstChunk(exec)
for i := range testCase.requiredRows {
chk.SetRequiredRows(testCase.requiredRows[i], maxChunkSize)
c.Assert(exec.Next(ctx, chk), IsNil)
c.Assert(chk.NumRows(), Equals, testCase.expectedRows[i])
require.NoError(t, exec.Next(ctx, chk))
require.Equal(t, testCase.expectedRows[i], chk.NumRows())
}
c.Assert(exec.Close(), IsNil)
c.Assert(ds.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, ds.checkNumNextCalled())
}
}

func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {
func TestMergeJoinRequiredRows(t *testing.T) {
justReturn1 := func(valType *types.FieldType) interface{} {
switch valType.Tp {
case mysql.TypeLong, mysql.TypeLonglong:
Expand All @@ -711,15 +703,15 @@ func (s *testExecSuite) TestMergeJoinRequiredRows(c *C) {
innerSrc := newRequiredRowsDataSourceWithGenerator(ctx, 1, nil, justReturn1) // just return one row: (1, 1)
outerSrc := newRequiredRowsDataSourceWithGenerator(ctx, 10000000, required, justReturn1) // always return (1, 1)
exec := buildMergeJoinExec(ctx, joinType, innerSrc, outerSrc)
c.Assert(exec.Open(context.Background()), IsNil)
require.NoError(t, exec.Open(context.Background()))

chk := newFirstChunk(exec)
for i := range required {
chk.SetRequiredRows(required[i], ctx.GetSessionVars().MaxChunkSize)
c.Assert(exec.Next(context.Background(), chk), IsNil)
require.NoError(t, exec.Next(context.Background(), chk))
}
c.Assert(exec.Close(), IsNil)
c.Assert(outerSrc.checkNumNextCalled(), IsNil)
require.NoError(t, exec.Close())
require.NoError(t, outerSrc.checkNumNextCalled())
}
}

Expand Down Expand Up @@ -771,7 +763,7 @@ func genTestChunk4VecGroupChecker(chkRows []int, sameNum int) (expr []expression
return
}

func (s *testExecSuite) TestVecGroupChecker(c *C) {
func TestVecGroupChecker4GroupCount(t *testing.T) {
testCases := []struct {
chunkRows []int
expectedGroups int
Expand Down Expand Up @@ -823,15 +815,15 @@ func (s *testExecSuite) TestVecGroupChecker(c *C) {
groupNum := 0
for i, inputChk := range inputChks {
flag, err := groupChecker.splitIntoGroups(inputChk)
c.Assert(err, IsNil)
c.Assert(flag, Equals, testCase.expectedFlag[i])
require.NoError(t, err)
require.Equal(t, testCase.expectedFlag[i], flag)
if flag {
groupNum += groupChecker.groupCount - 1
} else {
groupNum += groupChecker.groupCount
}
}
c.Assert(groupNum, Equals, testCase.expectedGroups)
require.Equal(t, testCase.expectedGroups, groupNum)
}
}

Expand Down Expand Up @@ -871,7 +863,7 @@ func (mp *mockPlan) Schema() *expression.Schema {
return mp.exec.Schema()
}

func (s *testExecSuite) TestVecGroupCheckerDATARACE(c *C) {
func TestVecGroupCheckerDATARACE(t *testing.T) {
ctx := mock.NewContext()

mTypes := []byte{mysql.TypeVarString, mysql.TypeNewDecimal, mysql.TypeJSON}
Expand Down Expand Up @@ -900,37 +892,37 @@ func (s *testExecSuite) TestVecGroupCheckerDATARACE(c *C) {
case mysql.TypeJSON:
chk.Column(0).ReserveJSON(1)
j := new(json.BinaryJSON)
c.Assert(j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 123, 123))), IsNil)
require.NoError(t, j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 123, 123))))
chk.Column(0).AppendJSON(*j)
}

_, err := vgc.splitIntoGroups(chk)
c.Assert(err, IsNil)
require.NoError(t, err)

switch mType {
case mysql.TypeVarString:
c.Assert(vgc.firstRowDatums[0].GetString(), Equals, "abc")
c.Assert(vgc.lastRowDatums[0].GetString(), Equals, "abc")
require.Equal(t, "abc", vgc.firstRowDatums[0].GetString())
require.Equal(t, "abc", vgc.lastRowDatums[0].GetString())
chk.Column(0).ReserveString(1)
chk.Column(0).AppendString("edf")
c.Assert(vgc.firstRowDatums[0].GetString(), Equals, "abc")
c.Assert(vgc.lastRowDatums[0].GetString(), Equals, "abc")
require.Equal(t, "abc", vgc.firstRowDatums[0].GetString())
require.Equal(t, "abc", vgc.lastRowDatums[0].GetString())
case mysql.TypeNewDecimal:
c.Assert(vgc.firstRowDatums[0].GetMysqlDecimal().String(), Equals, "123")
c.Assert(vgc.lastRowDatums[0].GetMysqlDecimal().String(), Equals, "123")
require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String())
require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String())
chk.Column(0).ResizeDecimal(1, false)
chk.Column(0).Decimals()[0] = *types.NewDecFromInt(456)
c.Assert(vgc.firstRowDatums[0].GetMysqlDecimal().String(), Equals, "123")
c.Assert(vgc.lastRowDatums[0].GetMysqlDecimal().String(), Equals, "123")
require.Equal(t, "123", vgc.firstRowDatums[0].GetMysqlDecimal().String())
require.Equal(t, "123", vgc.lastRowDatums[0].GetMysqlDecimal().String())
case mysql.TypeJSON:
c.Assert(vgc.firstRowDatums[0].GetMysqlJSON().String(), Equals, `{"123": 123}`)
c.Assert(vgc.lastRowDatums[0].GetMysqlJSON().String(), Equals, `{"123": 123}`)
require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String())
require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String())
chk.Column(0).ReserveJSON(1)
j := new(json.BinaryJSON)
c.Assert(j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 456, 456))), IsNil)
require.NoError(t, j.UnmarshalJSON([]byte(fmt.Sprintf(`{"%v":%v}`, 456, 456))))
chk.Column(0).AppendJSON(*j)
c.Assert(vgc.firstRowDatums[0].GetMysqlJSON().String(), Equals, `{"123": 123}`)
c.Assert(vgc.lastRowDatums[0].GetMysqlJSON().String(), Equals, `{"123": 123}`)
require.Equal(t, `{"123": 123}`, vgc.firstRowDatums[0].GetMysqlJSON().String())
require.Equal(t, `{"123": 123}`, vgc.lastRowDatums[0].GetMysqlJSON().String())
}
}
}

0 comments on commit 886c8a7

Please sign in to comment.