Skip to content

Commit

Permalink
fix(compiler): Add validation for GROUP BY clause column references (s…
Browse files Browse the repository at this point in the history
…qlc-dev#1285)

* fix(compiler): Add validation for GROUP BY clause column references

* Add MySQL test
  • Loading branch information
timstudd authored Nov 18, 2021
1 parent 466c3e1 commit 710cc21
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 3 deletions.
54 changes: 54 additions & 0 deletions internal/compiler/output_columns.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,20 @@ func outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, error) {
targets = n.ReturningList
case *ast.SelectStmt:
targets = n.TargetList

if n.GroupClause != nil {
for _, item := range n.GroupClause.Items {
ref, ok := item.(*ast.ColumnRef)
if !ok {
continue
}

if err := findColumnForRef(ref, tables); err != nil {
return nil, err
}
}
}

// For UNION queries, targets is empty and we need to look for the
// columns in Largs.
if len(targets.Items) == 0 && n.Larg != nil {
Expand Down Expand Up @@ -470,3 +484,43 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef)
}
return cols, nil
}

func findColumnForRef(ref *ast.ColumnRef, tables []*Table) error {
parts := stringSlice(ref.Fields)
var alias, name string
if len(parts) == 1 {
name = parts[0]
} else if len(parts) == 2 {
alias = parts[0]
name = parts[1]
}

var found int
for _, t := range tables {
if alias != "" && t.Rel.Name != alias {
continue
}
for _, c := range t.Columns {
if c.Name == name {
found++
}
}
}

if found == 0 {
return &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column reference \"%s\" not found", name),
Location: ref.Location,
}
}
if found > 1 {
return &sqlerr.Error{
Code: "42703",
Message: fmt.Sprintf("column reference \"%s\" is ambiguous", name),
Location: ref.Location,
}
}

return nil
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
CREATE TABLE authors (
id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY,
name text NOT NULL,
bio text,
UNIQUE(name)
);

-- name: ListAuthors :many
SELECT *
FROM authors
GROUP BY invalid_reference;
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "1",
"packages": [
{
"path": "go",
"engine": "mysql",
"name": "querytest",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:9:1: column reference "invalid_reference" not found
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CREATE TABLE authors (
id BIGSERIAL PRIMARY KEY,
name text NOT NULL,
bio text
);

-- name: ListAuthors :many
SELECT *
FROM authors
GROUP BY invalid_reference;
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"version": "1",
"packages": [
{
"path": "go",
"engine": "postgresql",
"name": "querytest",
"schema": "query.sql",
"queries": "query.sql"
}
]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# package querytest
query.sql:10:10: column reference "invalid_reference" not found
25 changes: 22 additions & 3 deletions internal/engine/dolphin/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,7 @@ func (c *cc) convertSelectStmt(n *pcast.SelectStmt) *ast.SelectStmt {
stmt := &ast.SelectStmt{
TargetList: c.convertFieldList(n.Fields),
FromClause: c.convertTableRefsClause(n.From),
GroupClause: c.convertGroupByClause(n.GroupBy),
WhereClause: c.convert(n.Where),
WithClause: c.convertWithClause(n.With),
WindowClause: windowClause,
Expand Down Expand Up @@ -677,7 +678,14 @@ func (c *cc) convertBinlogStmt(n *pcast.BinlogStmt) ast.Node {
}

func (c *cc) convertByItem(n *pcast.ByItem) ast.Node {
return todo(n)
switch n.Expr.(type) {
case *pcast.PositionExpr:
return c.convertPositionExpr(n.Expr.(*pcast.PositionExpr))
case *pcast.ColumnNameExpr:
return c.convertColumnNameExpr(n.Expr.(*pcast.ColumnNameExpr))
default:
return todo(n)
}
}

func (c *cc) convertCaseExpr(n *pcast.CaseExpr) ast.Node {
Expand Down Expand Up @@ -858,8 +866,19 @@ func (c *cc) convertGrantStmt(n *pcast.GrantStmt) ast.Node {
return todo(n)
}

func (c *cc) convertGroupByClause(n *pcast.GroupByClause) ast.Node {
return todo(n)
func (c *cc) convertGroupByClause(n *pcast.GroupByClause) *ast.List {
if n == nil {
return &ast.List{}
}

var items []ast.Node
for _, item := range n.Items {
items = append(items, c.convertByItem(item))
}

return &ast.List{
Items: items,
}
}

func (c *cc) convertHavingClause(n *pcast.HavingClause) ast.Node {
Expand Down

0 comments on commit 710cc21

Please sign in to comment.