Skip to content

Commit

Permalink
use Expression in comparison builder methods
Browse files Browse the repository at this point in the history
  • Loading branch information
stephenafamo committed Mar 12, 2023
1 parent b50b313 commit 8560322
Show file tree
Hide file tree
Showing 54 changed files with 372 additions and 442 deletions.
4 changes: 2 additions & 2 deletions clause/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ type Values struct {
Vals []value
}

type value []any
type value []bob.Expression

func (v value) WriteSQL(w io.Writer, d bob.Dialect, start int) ([]any, error) {
return bob.ExpressSlice(w, d, start, v, "(", ", ", ")")
}

func (v *Values) AppendValues(vals ...any) {
func (v *Values) AppendValues(vals ...bob.Expression) {
if len(vals) == 0 {
return
}
Expand Down
2 changes: 1 addition & 1 deletion dialect/mssql/raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ import (
"github.com/stephenafamo/bob/expr"
)

func RawQuery(q string, args ...any) bob.BaseQuery[expr.Raw] {
func RawQuery(q string, args ...any) bob.BaseQuery[expr.Clause] {
return expr.RawQuery(dialect, q, args...)
}
20 changes: 9 additions & 11 deletions dialect/mysql/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,40 +13,38 @@ func TestDelete(t *testing.T) {
"simple": {
Query: mysql.Delete(
dm.From("films"),
dm.Where(mysql.X("kind").EQ(mysql.Arg("Drama"))),
dm.Where(mysql.Quote("kind").EQ(mysql.Arg("Drama"))),
),
ExpectedSQL: `DELETE FROM films WHERE (kind = ?)`,
ExpectedSQL: "DELETE FROM films WHERE (`kind` = ?)",
ExpectedArgs: []any{"Drama"},
},
"multiple tables": {
Query: mysql.Delete(
dm.From("films"),
dm.From("actors"),
dm.Where(mysql.X("kind").EQ(mysql.Arg("Drama"))),
dm.Where(mysql.Quote("kind").EQ(mysql.Arg("Drama"))),
),
ExpectedSQL: `DELETE FROM films, actors WHERE (kind = ?)`,
ExpectedSQL: "DELETE FROM films, actors WHERE (`kind` = ?)",
ExpectedArgs: []any{"Drama"},
},
"with limit and offest": {
Query: mysql.Delete(
dm.From("films"),
dm.Where(mysql.X("kind").EQ(mysql.Arg("Drama"))),
dm.Where(mysql.Quote("kind").EQ(mysql.Arg("Drama"))),
dm.Limit(10),
dm.OrderBy("producer").Desc(),
),
ExpectedSQL: `DELETE FROM films WHERE (kind = ?) ORDER BY producer DESC LIMIT 10`,
ExpectedSQL: "DELETE FROM films WHERE (`kind` = ?) ORDER BY producer DESC LIMIT 10",
ExpectedArgs: []any{"Drama"},
},
"with using": {
Query: mysql.Delete(
dm.From("employees"),
dm.Using("accounts"),
dm.Where(mysql.X("accounts.name").EQ(mysql.Arg("Acme Corporation"))),
dm.Where(mysql.X("employees.id").EQ("accounts.sales_person")),
dm.Where(mysql.Quote("accounts", "name").EQ(mysql.Arg("Acme Corporation"))),
dm.Where(mysql.Quote("employees", "id").EQ(mysql.Quote("accounts", "sales_person"))),
),
ExpectedSQL: `DELETE FROM employees USING accounts
WHERE (accounts.name = ?)
AND (employees.id = accounts.sales_person)`,
ExpectedSQL: "DELETE FROM employees USING accounts WHERE (`accounts`.`name` = ?) AND (`employees`.`id` = `accounts`.`sales_person`)",
ExpectedArgs: []any{"Acme Corporation"},
},
}
Expand Down
3 changes: 2 additions & 1 deletion dialect/mysql/dialect/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@ package dialect
import (
"strings"

"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/expr"
)

type Expression struct {
expr.Chain[Expression, Expression]
}

func (Expression) New(exp any) Expression {
func (Expression) New(exp bob.Expression) Expression {
var b Expression
b.Base = exp
return b
Expand Down
7 changes: 2 additions & 5 deletions dialect/mysql/dialect/mods.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ import (
"github.com/stephenafamo/bob/mods"
)

//nolint:gochecknoglobals
var bmod = expr.Builder[Expression, Expression]{}

func With[Q interface{ AppendWith(clause.CTE) }](name string, columns ...string) CTEChain[Q] {
return CTEChain[Q](func() clause.CTE {
return clause.CTE{
Expand Down Expand Up @@ -198,9 +195,9 @@ func (j JoinChain[Q]) On(on ...any) bob.Mod[Q] {
return mods.Join[Q](jo)
}

func (j JoinChain[Q]) OnEQ(a, b any) bob.Mod[Q] {
func (j JoinChain[Q]) OnEQ(a, b bob.Expression) bob.Mod[Q] {
jo := j()
jo.On = append(jo.On, bmod.X(a).EQ(b))
jo.On = append(jo.On, expr.X[Expression, Expression](a).EQ(b))

return mods.Join[Q](jo)
}
Expand Down
4 changes: 2 additions & 2 deletions dialect/mysql/im/qm.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ func Partition(partitions ...string) bob.Mod[*dialect.InsertQuery] {
return dialect.Partition[*dialect.InsertQuery](partitions...)
}

func Values(clauses ...any) bob.Mod[*dialect.InsertQuery] {
func Values(clauses ...bob.Expression) bob.Mod[*dialect.InsertQuery] {
return mods.Values[*dialect.InsertQuery](clauses)
}

func Rows(rows ...[]any) bob.Mod[*dialect.InsertQuery] {
func Rows(rows ...[]bob.Expression) bob.Mod[*dialect.InsertQuery] {
return mods.Rows[*dialect.InsertQuery](rows)
}

Expand Down
5 changes: 3 additions & 2 deletions dialect/mysql/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,16 @@ func TestInsert(t *testing.T) {
im.OnDuplicateKeyUpdate().
Set("new", "did").
SetCol("dbname", mysql.Concat(
"new.dname", mysql.S(" (formerly "), "d.dname", mysql.S(")"),
mysql.Quote("new", "dname"), mysql.S(" (formerly "),
mysql.Quote("d", "dname"), mysql.S(")"),
)),
),
ExpectedSQL: `INSERT INTO distributors (` + "`did`" + `, ` + "`dname`" + `)
VALUES (?, ?), (?, ?)
AS new
ON DUPLICATE KEY UPDATE
` + "`did` = `new`.`did`," + `
` + "`dbname`" + ` = (new.dname || ' (formerly ' || d.dname || ')')`,
` + "`dbname` = (`new`.`dname` || ' (formerly ' || `d`.`dname` || ')')",
ExpectedArgs: []any{8, "Anvil Distribution", 9, "Sentry Distribution"},
},
}
Expand Down
18 changes: 3 additions & 15 deletions dialect/mysql/load.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,25 +146,13 @@ func Preload[T any, Ts ~[]T](rel orm.Relationship, cols []string, opts ...Preloa
on := make([]any, 0, len(side.FromColumns)+len(side.FromWhere)+len(side.ToWhere))
for i, fromCol := range side.FromColumns {
toCol := side.ToColumns[i]
on = append(on, X(
Quote(parent, fromCol),
"=",
Quote(alias, toCol),
))
on = append(on, Quote(parent, fromCol).EQ(Quote(alias, toCol)))
}
for _, from := range side.FromWhere {
on = append(on, X(
Quote(parent, from.Column),
"=",
from.Value,
))
on = append(on, Quote(parent, from.Column).EQ(Raw(from.Value)))
}
for _, to := range side.ToWhere {
on = append(on, X(
Quote(alias, to.Column),
"=",
to.Value,
))
on = append(on, Quote(alias, to.Column).EQ(Raw(to.Value)))
}

queryMods = append(queryMods, sm.
Expand Down
2 changes: 1 addition & 1 deletion dialect/mysql/raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,6 @@ import (
"github.com/stephenafamo/bob/expr"
)

func RawQuery(q string, args ...any) bob.BaseQuery[expr.Raw] {
func RawQuery(q string, args ...any) bob.BaseQuery[expr.Clause] {
return expr.RawQuery(dialect.Dialect, q, args...)
}
24 changes: 8 additions & 16 deletions dialect/mysql/select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@ package mysql_test
import (
"testing"

// "github.com/pingcap/tidb/parser"
// "github.com/pingcap/tidb/parser/format"
// _ "github.com/pingcap/tidb/types/parser_driver"
"github.com/stephenafamo/bob/dialect/mysql"
"github.com/stephenafamo/bob/dialect/mysql/sm"
testutils "github.com/stephenafamo/bob/test_utils"
Expand All @@ -14,22 +11,22 @@ import (
func TestSelect(t *testing.T) {
examples := testutils.Testcases{
"simple select": {
ExpectedSQL: "SELECT id, name FROM users WHERE (id IN (?, ?, ?))",
ExpectedSQL: "SELECT id, name FROM users WHERE (`id` IN (?, ?, ?))",
ExpectedArgs: []any{100, 200, 300},
Query: mysql.Select(
sm.Columns("id", "name"),
sm.From("users"),
sm.Where(mysql.X("id").In(mysql.Arg(100, 200, 300))),
sm.Where(mysql.Quote("id").In(mysql.Arg(100, 200, 300))),
),
},
"select distinct": {
ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (id IN (?, ?, ?))",
ExpectedSQL: "SELECT DISTINCT id, name FROM users WHERE (`id` IN (?, ?, ?))",
ExpectedArgs: []any{100, 200, 300},
Query: mysql.Select(
sm.Columns("id", "name"),
sm.Distinct(),
sm.From("users"),
sm.Where(mysql.X("id").In(mysql.Arg(100, 200, 300))),
sm.Where(mysql.Quote("id").In(mysql.Arg(100, 200, 300))),
),
},
"with sub-select": {
Expand All @@ -42,7 +39,7 @@ func TestSelect(t *testing.T) {
- created_date) AS ` + "`difference`" + `
FROM presales_presalestatus
) AS ` + "`differnce_by_status`" + `
WHERE (status IN ('A', 'B', 'C'))
WHERE (` + "`status`" + ` IN ('A', 'B', 'C'))
GROUP BY status`,
Query: mysql.Select(
sm.Columns("status", mysql.F("avg", "difference")),
Expand All @@ -57,17 +54,17 @@ func TestSelect(t *testing.T) {
As("difference")),
sm.From("presales_presalestatus")),
).As("differnce_by_status"),
sm.Where(mysql.X("status").In(mysql.S("A"), mysql.S("B"), mysql.S("C"))),
sm.Where(mysql.Quote("status").In(mysql.S("A"), mysql.S("B"), mysql.S("C"))),
sm.GroupBy("status"),
),
},
"select with grouped IN": {
Query: mysql.Select(
sm.Columns("id", "name"),
sm.From("users"),
sm.Where(mysql.Group("id", "employee_id").In(mysql.ArgGroup(100, 200), mysql.ArgGroup(300, 400))),
sm.Where(mysql.Group(mysql.Quote("id"), mysql.Quote("employee_id")).In(mysql.ArgGroup(100, 200), mysql.ArgGroup(300, 400))),
),
ExpectedSQL: "SELECT id, name FROM users WHERE ((id, employee_id) IN ((?, ?), (?, ?)))",
ExpectedSQL: "SELECT id, name FROM users WHERE ((`id`, `employee_id`) IN ((?, ?), (?, ?)))",
ExpectedArgs: []any{100, 200, 300, 400},
},
}
Expand All @@ -80,11 +77,6 @@ func TestSelect(t *testing.T) {
// 2. Does not understand aliases in upsert
// In general, TIDB's parser is not updated for MySQL 8.0

// require (
// github.com/pingcap/tidb v1.1.0-beta.0.20221227032819-706c3fa3c526
// github.com/pingcap/tidb/parser v0.0.0-20221227032819-706c3fa3c526
// )

// var p = parser.New()

// func formatter(s string) (string, error) {
Expand Down
24 changes: 7 additions & 17 deletions dialect/mysql/starters.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mysql

import (
"github.com/stephenafamo/bob"
"github.com/stephenafamo/bob/dialect/mysql/dialect"
"github.com/stephenafamo/bob/expr"
)
Expand All @@ -10,11 +11,6 @@ type Expression = dialect.Expression
//nolint:gochecknoglobals
var bmod = expr.Builder[Expression, Expression]{}

// X is a flexible starter that joins the given expressions with a space
func X(exp any, others ...any) Expression {
return bmod.X(exp, others...)
}

// F creates a function expression with the given name and args
//
// SQL: generate_series(1, 3)
Expand All @@ -38,26 +34,26 @@ func S(s string) Expression {

// SQL: NOT true
// Go: psql.Not("true")
func Not(exp any) Expression {
func Not(exp bob.Expression) Expression {
return bmod.Not(exp)
}

// SQL: a OR b OR c
// Go: psql.Or("a", "b", "c")
func Or(args ...any) Expression {
func Or(args ...bob.Expression) Expression {
return bmod.Or(args...)
}

// SQL: a AND b AND c
// Go: psql.And("a", "b", "c")
func And(args ...any) Expression {
func And(args ...bob.Expression) Expression {
return bmod.And(args...)
}

// SQL: a || b || c
// Go: psql.Concat("a", "b", "c")
func Concat(args ...any) Expression {
return bmod.X(expr.Join{Exprs: args, Sep: " || "})
func Concat(args ...bob.Expression) Expression {
return expr.X[Expression, Expression](expr.Join{Exprs: args, Sep: " || "})
}

// SQL: $1, $2, $3
Expand All @@ -78,15 +74,9 @@ func Placeholder(n uint) Expression {
return bmod.Placeholder(n)
}

// SQL: (a and b)
// Go: psql.P("a and b")
func P(exp any) Expression {
return bmod.P(exp)
}

// SQL: (a, b)
// Go: psql.Group("a", "b")
func Group(exps ...any) Expression {
func Group(exps ...bob.Expression) Expression {
return bmod.Group(exps...)
}

Expand Down
Loading

0 comments on commit 8560322

Please sign in to comment.