Skip to content

Commit

Permalink
plan: keep error message consistent with MySQL (pingcap#3962) (pingca…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackysp authored and zz-jason committed Aug 9, 2017
1 parent 06c4f6e commit 171aea2
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 67 deletions.
19 changes: 11 additions & 8 deletions plan/planbuilder.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ import (
var (
ErrUnsupportedType = terror.ClassOptimizerPlan.New(CodeUnsupportedType, "Unsupported type")
SystemInternalErrorType = terror.ClassOptimizerPlan.New(SystemInternalError, "System internal error")
ErrUnknownColumn = terror.ClassOptimizerPlan.New(CodeUnknownColumn, "Unknown column '%s' in '%s'")
ErrUnknownColumn = terror.ClassOptimizerPlan.New(CodeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadField])
ErrUnknownTable = terror.ClassOptimizerPlan.New(CodeUnknownColumn, mysql.MySQLErrName[mysql.ErrBadTable])
ErrWrongArguments = terror.ClassOptimizerPlan.New(CodeWrongArguments, "Incorrect arguments to EXECUTE")
ErrAmbiguous = terror.ClassOptimizerPlan.New(CodeAmbiguous, "Column '%s' in field list is ambiguous")
ErrAnalyzeMissIndex = terror.ClassOptimizerPlan.New(CodeAnalyzeMissIndex, "Index '%s' in field list does not exist in table '%s'")
Expand All @@ -44,18 +45,20 @@ var (
// Error codes.
const (
CodeUnsupportedType terror.ErrCode = 1
SystemInternalError terror.ErrCode = 2
CodeAlterAutoID terror.ErrCode = 3
CodeAnalyzeMissIndex terror.ErrCode = 4
CodeAmbiguous terror.ErrCode = 1052
CodeUnknownColumn terror.ErrCode = 1054
CodeWrongArguments terror.ErrCode = 1210
CodeBadGeneratedColumn terror.ErrCode = mysql.ErrBadGeneratedColumn
SystemInternalError = 2
CodeAlterAutoID = 3
CodeAnalyzeMissIndex = 4
CodeAmbiguous = 1052
CodeUnknownColumn = mysql.ErrBadField
CodeUnknownTable = mysql.ErrBadTable
CodeWrongArguments = 1210
CodeBadGeneratedColumn = mysql.ErrBadGeneratedColumn
)

func init() {
tableMySQLErrCodes := map[terror.ErrCode]uint16{
CodeUnknownColumn: mysql.ErrBadField,
CodeUnknownTable: mysql.ErrBadTable,
CodeAmbiguous: mysql.ErrNonUniq,
CodeWrongArguments: mysql.ErrWrongArguments,
CodeBadGeneratedColumn: mysql.ErrBadGeneratedColumn,
Expand Down
70 changes: 47 additions & 23 deletions plan/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,17 @@ import (
"github.com/pingcap/tidb/util/types"
)

const (
unknownClause = ""
fieldList = "field list"
havingClause = "having clause"
onClause = "on clause"
orderByClause = "order clause"
whereClause = "where clause"
groupByStatement = "group statement"
showStatement = "show statement"
)

// ResolveName resolves table name and column name.
// It generates ResultFields for ResultSetNode and resolves ColumnNameExpr to a ResultField.
func ResolveName(node ast.Node, info infoschema.InfoSchema, ctx context.Context) error {
Expand Down Expand Up @@ -502,8 +513,11 @@ func (nr *nameResolver) handleColumnName(cn *ast.ColumnNameExpr) {
}

// Try to resolve the column name from top to bottom in the context stack.
var where string
var ok bool
for i := len(nr.contextStack) - 1; i >= 0; i-- {
if nr.resolveColumnNameInContext(nr.contextStack[i], cn) {
where, ok = nr.resolveColumnNameInContext(nr.contextStack[i], cn)
if ok {
// Column is already resolved or encountered an error.
if i < len(nr.contextStack)-1 {
// If in subselect, the query use outer query.
Expand All @@ -512,18 +526,23 @@ func (nr *nameResolver) handleColumnName(cn *ast.ColumnNameExpr) {
return
}
}
nr.Err = errors.Errorf("unknown column %s", cn.Name.Name.L)
fieldName := cn.Name.Name.String()
if len(cn.Name.Table.String()) != 0 {
fieldName = fmt.Sprintf("%s.%s", cn.Name.Table.String(), fieldName)

}
nr.Err = ErrUnknownColumn.GenByArgs(fieldName, where)
}

// resolveColumnNameInContext looks up and sets ResultField for a column with the ctx.
func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast.ColumnNameExpr) bool {
func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast.ColumnNameExpr) (string, bool) {
if ctx.inTableRefs {
// In TableRefsClause, column reference only in join on condition which is handled before.
return false
return unknownClause, false
}
if ctx.inFieldList {
// only resolve column using tables.
return nr.resolveColumnInTableSources(cn, ctx.tables)
return fieldList, nr.resolveColumnInTableSources(cn, ctx.tables)
}
if ctx.inGroupBy {
// From tables first, then field list.
Expand All @@ -532,7 +551,7 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
if ctx.inByItemExpression {
// From table first, then field list.
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
return groupByStatement, true
}
found := nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
if nr.Err == nil && found {
Expand All @@ -541,12 +560,12 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
nr.Err = ErrIllegalReference.Gen("Reference '%s' not supported (reference to group function)", cn.Name.Name.O)
}
}
return found
return groupByStatement, found
}
// Resolve from table first, then from select list.
found := nr.resolveColumnInTableSources(cn, ctx.tables)
if nr.Err != nil {
return found
return groupByStatement, found
}
// We should copy the refer here.
// Because if the ByItem is an identifier, we should check if it
Expand All @@ -555,7 +574,7 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
r := cn.Refer
if nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList) {
if nr.Err != nil {
return true
return groupByStatement, true
}
if r != nil {
// It is not ambiguous and already resolved from table source.
Expand All @@ -564,45 +583,45 @@ func (nr *nameResolver) resolveColumnNameInContext(ctx *resolverContext, cn *ast
} else if _, ok := cn.Refer.Expr.(*ast.AggregateFuncExpr); ok {
nr.Err = ErrIllegalReference.Gen("Reference '%s' not supported (reference to group function)", cn.Name.Name.O)
}
return true
return groupByStatement, true
}
return found
return groupByStatement, found
}
if ctx.inHaving {
// First group by, then field list.
if nr.resolveColumnInResultFields(ctx, cn, ctx.groupBy) {
return true
return havingClause, true
}
if ctx.inHavingAgg {
// If cn is in an aggregate function in having clause, check tablesource first.
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
return havingClause, true
}
}
return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
return havingClause, nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
}
if ctx.inOrderBy {
if nr.resolveColumnInResultFields(ctx, cn, ctx.groupBy) {
return true
return orderByClause, true
}
if ctx.inByItemExpression {
// From table first, then field list.
if nr.resolveColumnInTableSources(cn, ctx.tables) {
return true
return orderByClause, true
}
return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
return orderByClause, nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
}
// Field list first, then from table.
if nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList) {
return true
return orderByClause, true
}
return nr.resolveColumnInTableSources(cn, ctx.tables)
return orderByClause, nr.resolveColumnInTableSources(cn, ctx.tables)
}
if ctx.inShow {
return nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
return showStatement, nr.resolveColumnInResultFields(ctx, cn, ctx.fieldList)
}
// In where clause.
return nr.resolveColumnInTableSources(cn, ctx.tables)
return whereClause, nr.resolveColumnInTableSources(cn, ctx.tables)
}

// resolveColumnNameInOnCondition resolves the column name in current join.
Expand All @@ -611,7 +630,12 @@ func (nr *nameResolver) resolveColumnNameInOnCondition(cn *ast.ColumnNameExpr) {
join := ctx.joinNodeStack[len(ctx.joinNodeStack)-1]
tableSources := appendTableSources(nil, join)
if !nr.resolveColumnInTableSources(cn, tableSources) {
nr.Err = errors.Errorf("unknown column name %s", cn.Name.Name.O)
fieldName := cn.Name.Name.String()
if len(cn.Name.Table.String()) != 0 {
fieldName = fmt.Sprintf("%s.%s", cn.Name.Table.String(), fieldName)

}
nr.Err = ErrUnknownColumn.GenByArgs(fieldName, onClause)
}
}

Expand Down Expand Up @@ -755,7 +779,7 @@ func (nr *nameResolver) createResultFields(field *ast.SelectField) (rfs []*ast.R
tableIdx, ok1 := ctx.tableMap[name]
derivedTableIdx, ok2 := ctx.derivedTableMap[name]
if !ok1 && !ok2 {
nr.Err = errors.Errorf("unknown table %s.", field.WildCard.Table.O)
nr.Err = ErrUnknownTable.GenByArgs(field.WildCard.Table.String())
}
if ok1 {
tableRfs = ctx.tables[tableIdx].GetResultFields()
Expand Down
38 changes: 22 additions & 16 deletions plan/resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,25 +55,30 @@ func (rv *resolverVerifier) Leave(in ast.Node) (out ast.Node, ok bool) {
type resolverTestCase struct {
src string
valid bool
err string
}

var resolverTests = []resolverTestCase{
{"select c1 from t1", true},
{"select c3 from t1", false},
{"select c1 from t4", false},
{"select * from t1", true},
{"select t1.* from t1", true},
{"select t2.* from t1", false},
{"select c1 as a, c1 as a from t1 group by a", true},
{"select 1 as a, c1 as a, c2 as a from t1 group by a", true},
{"select c1, c2 as c1 from t1 group by c1+1", true},
{"select c1, c2 as c1 from t1 order by c1+1", true},
{"select * from t1, t2 join t3 on t1.c1 = t2.c1", false},
{"select * from t1, t2 join t3 on t2.c1 = t3.c1", true},
{"select c1 from t1 group by c1 having c1 = 3", true},
{"select c1 from t1 group by c1 having c2 = 3", false},
{"select c1 from t1 where exists (select c2)", true},
{"select cnt from (select count(c2) as cnt from t1 group by c1) t2 group by cnt", true},
{"select c1 from t1", true, ""},
{"select c3 from t1", false, "[plan:1054]Unknown column 'c3' in 'field list'"},
{"select c1 from t4", false, "[schema:1146]Table 'test.t4' doesn't exist"},
{"select * from t1", true, ""},
{"select t1.* from t1", true, ""},
{"select t2.* from t1", false, "[plan:1054]Unknown table 't2'"},
{"select c1 as a, c1 as a from t1 group by a", true, ""},
{"select 1 as a, c1 as a, c2 as a from t1 group by a", true, ""},
{"select c1, c2 as c1 from t1 group by c1+1", true, ""},
{"select c1, c2 as c1 from t1 order by c1+1", true, ""},
{"select * from t1, t2 join t3 on t1.c1 = t2.c1", false, "[plan:1054]Unknown column 't1.c1' in 'on clause'"},
{"select * from t1, t2 join t3 on t2.c1 = t3.c1", true, ""},
{"select c1 from t1 group by c1 having c1 = 3", true, ""},
{"select c1 from t1 group by c1 having c2 = 3", false, "[plan:1054]Unknown column 'c2' in 'having clause'"},
{"select c1 from t1 where exists (select c2)", true, ""},
{"select cnt from (select count(c2) as cnt from t1 group by c1) t2 group by cnt", true, ""},
{"select c1 from t2 where t11.c1 < t2.c1", false, "[plan:1054]Unknown column 't11.c1' in 'where clause'"},
{"select c1 from t2 having t11.c1 < t2.c1", false, "[plan:1054]Unknown column 't11.c1' in 'having clause'"},
{"select c1 from t2 where t2.c1 < t2.c1 order by t11.c1", false, "[plan:1054]Unknown column 't11.c1' in 'order clause'"},
{"select c1 from t2 group by t11.c1", false, "[plan:1054]Unknown column 't11.c1' in 'group statement'"},
}

func (ts *testNameResolverSuite) TestNameResolver(c *C) {
Expand All @@ -98,6 +103,7 @@ func (ts *testNameResolverSuite) TestNameResolver(c *C) {
node.Accept(verifier)
} else {
c.Assert(resolveErr, NotNil, Commentf("%s", tt.src))
c.Assert(resolveErr.Error(), Equals, tt.err, Commentf("%s", resolveErr.Error()))
}
}
}
41 changes: 21 additions & 20 deletions terror/terror.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,26 +86,27 @@ const (
)

var errClz2Str = map[ErrClass]string{
ClassAutoid: "autoid",
ClassDDL: "ddl",
ClassDomain: "domain",
ClassExecutor: "executor",
ClassExpression: "expression",
ClassInspectkv: "inspectkv",
ClassMeta: "meta",
ClassKV: "kv",
ClassOptimizer: "optimizer",
ClassParser: "parser",
ClassPerfSchema: "perfschema",
ClassPrivilege: "privilege",
ClassSchema: "schema",
ClassServer: "server",
ClassStructure: "structure",
ClassVariable: "variable",
ClassTable: "table",
ClassTypes: "types",
ClassGlobal: "global",
ClassMockTikv: "mocktikv",
ClassAutoid: "autoid",
ClassDDL: "ddl",
ClassDomain: "domain",
ClassExecutor: "executor",
ClassExpression: "expression",
ClassInspectkv: "inspectkv",
ClassMeta: "meta",
ClassKV: "kv",
ClassOptimizer: "optimizer",
ClassOptimizerPlan: "plan",
ClassParser: "parser",
ClassPerfSchema: "perfschema",
ClassPrivilege: "privilege",
ClassSchema: "schema",
ClassServer: "server",
ClassStructure: "structure",
ClassVariable: "variable",
ClassTable: "table",
ClassTypes: "types",
ClassGlobal: "global",
ClassMockTikv: "mocktikv",
}

// String implements fmt.Stringer interface.
Expand Down

0 comments on commit 171aea2

Please sign in to comment.