Skip to content

Commit

Permalink
Merge hooks and tweak API
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed May 26, 2021
1 parent e104273 commit beff015
Show file tree
Hide file tree
Showing 17 changed files with 122 additions and 118 deletions.
49 changes: 24 additions & 25 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,6 @@ func (db *DB) ExecContext(
ctx, event := db.beforeQuery(ctx, nil, query, args)
res, err := db.DB.ExecContext(ctx, db.format(query, args))
db.afterQuery(ctx, event, res, err)

return res, err
}

Expand Down Expand Up @@ -276,51 +275,51 @@ func (c Conn) QueryRowContext(ctx context.Context, query string, args ...interfa
}

func (c Conn) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(c.db, model).DB(c)
return NewValuesQuery(c.db, model).Conn(c)
}

func (c Conn) NewSelect() *SelectQuery {
return NewSelectQuery(c.db).DB(c)
return NewSelectQuery(c.db).Conn(c)
}

func (c Conn) NewInsert() *InsertQuery {
return NewInsertQuery(c.db).DB(c)
return NewInsertQuery(c.db).Conn(c)
}

func (c Conn) NewUpdate() *UpdateQuery {
return NewUpdateQuery(c.db).DB(c)
return NewUpdateQuery(c.db).Conn(c)
}

func (c Conn) NewDelete() *DeleteQuery {
return NewDeleteQuery(c.db).DB(c)
return NewDeleteQuery(c.db).Conn(c)
}

func (c Conn) NewCreateTable() *CreateTableQuery {
return NewCreateTableQuery(c.db).DB(c)
return NewCreateTableQuery(c.db).Conn(c)
}

func (c Conn) NewDropTable() *DropTableQuery {
return NewDropTableQuery(c.db).DB(c)
return NewDropTableQuery(c.db).Conn(c)
}

func (c Conn) NewCreateIndex() *CreateIndexQuery {
return NewCreateIndexQuery(c.db).DB(c)
return NewCreateIndexQuery(c.db).Conn(c)
}

func (c Conn) NewDropIndex() *DropIndexQuery {
return NewDropIndexQuery(c.db).DB(c)
return NewDropIndexQuery(c.db).Conn(c)
}

func (c Conn) NewTruncateTable() *TruncateTableQuery {
return NewTruncateTableQuery(c.db).DB(c)
return NewTruncateTableQuery(c.db).Conn(c)
}

func (c Conn) NewAddColumn() *AddColumnQuery {
return NewAddColumnQuery(c.db).DB(c)
return NewAddColumnQuery(c.db).Conn(c)
}

func (c Conn) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(c.db).DB(c)
return NewDropColumnQuery(c.db).Conn(c)
}

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -399,51 +398,51 @@ func (tx Tx) QueryRowContext(ctx context.Context, query string, args ...interfac
}

func (tx Tx) NewValues(model interface{}) *ValuesQuery {
return NewValuesQuery(tx.db, model).DB(tx)
return NewValuesQuery(tx.db, model).Conn(tx)
}

func (tx Tx) NewSelect() *SelectQuery {
return NewSelectQuery(tx.db).DB(tx)
return NewSelectQuery(tx.db).Conn(tx)
}

func (tx Tx) NewInsert() *InsertQuery {
return NewInsertQuery(tx.db).DB(tx)
return NewInsertQuery(tx.db).Conn(tx)
}

func (tx Tx) NewUpdate() *UpdateQuery {
return NewUpdateQuery(tx.db).DB(tx)
return NewUpdateQuery(tx.db).Conn(tx)
}

func (tx Tx) NewDelete() *DeleteQuery {
return NewDeleteQuery(tx.db).DB(tx)
return NewDeleteQuery(tx.db).Conn(tx)
}

func (tx Tx) NewCreateTable() *CreateTableQuery {
return NewCreateTableQuery(tx.db).DB(tx)
return NewCreateTableQuery(tx.db).Conn(tx)
}

func (tx Tx) NewDropTable() *DropTableQuery {
return NewDropTableQuery(tx.db).DB(tx)
return NewDropTableQuery(tx.db).Conn(tx)
}

func (tx Tx) NewCreateIndex() *CreateIndexQuery {
return NewCreateIndexQuery(tx.db).DB(tx)
return NewCreateIndexQuery(tx.db).Conn(tx)
}

func (tx Tx) NewDropIndex() *DropIndexQuery {
return NewDropIndexQuery(tx.db).DB(tx)
return NewDropIndexQuery(tx.db).Conn(tx)
}

func (tx Tx) NewTruncateTable() *TruncateTableQuery {
return NewTruncateTableQuery(tx.db).DB(tx)
return NewTruncateTableQuery(tx.db).Conn(tx)
}

func (tx Tx) NewAddColumn() *AddColumnQuery {
return NewAddColumnQuery(tx.db).DB(tx)
return NewAddColumnQuery(tx.db).Conn(tx)
}

func (tx Tx) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(tx.db).DB(tx)
return NewDropColumnQuery(tx.db).Conn(tx)
}

//------------------------------------------------------------------------------
Expand Down
83 changes: 83 additions & 0 deletions model_hook.go → hook.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,94 @@ package bun

import (
"context"
"database/sql"
"reflect"
"sync/atomic"
"time"

"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/schema"
)

type QueryEvent struct {
DB *DB

QueryAppender schema.QueryAppender
Query []byte
QueryArgs []interface{}

StartTime time.Time
Result sql.Result
Err error

Stash map[interface{}]interface{}
}

type QueryHook interface {
BeforeQuery(context.Context, *QueryEvent) context.Context
AfterQuery(context.Context, *QueryEvent)
}

func (db *DB) beforeQuery(
ctx context.Context,
queryApp schema.QueryAppender,
query string,
queryArgs []interface{},
) (context.Context, *QueryEvent) {
atomic.AddUint64(&db.stats.Queries, 1)

if len(db.queryHooks) == 0 {
return ctx, nil
}

event := &QueryEvent{
DB: db,

QueryAppender: queryApp,
Query: internal.Bytes(query),
QueryArgs: queryArgs,

StartTime: time.Now(),
}

for _, hook := range db.queryHooks {
ctx = hook.BeforeQuery(ctx, event)
}

return ctx, event
}

func (db *DB) afterQuery(
ctx context.Context,
event *QueryEvent,
res sql.Result,
err error,
) {
switch err {
case nil, sql.ErrNoRows:
// nothing
default:
atomic.AddUint64(&db.stats.Errors, 1)
}

if event == nil {
return
}

event.Result = res
event.Err = err

db.afterQueryFromIndex(ctx, event, len(db.queryHooks)-1)
}

func (db *DB) afterQueryFromIndex(ctx context.Context, event *QueryEvent, hookIndex int) {
for ; hookIndex >= 0; hookIndex-- {
db.queryHooks[hookIndex].AfterQuery(ctx, event)
}
}

//------------------------------------------------------------------------------

type hookStubs struct{}

var (
Expand Down
4 changes: 2 additions & 2 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ func TestPGTransaction(t *testing.T) {
tx, err := db.BeginTx(ctx, nil)
require.NoError(t, err)

_, err = db.NewCreateTable().DB(tx).Model((*Model)(nil)).Exec(ctx)
_, err = db.NewCreateTable().Conn(tx).Model((*Model)(nil)).Exec(ctx)
require.NoError(t, err)

n, err := db.NewSelect().DB(tx).Model((*Model)(nil)).Count(ctx)
n, err := db.NewSelect().Conn(tx).Model((*Model)(nil)).Count(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)

Expand Down
2 changes: 1 addition & 1 deletion query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ type baseQuery struct {
flags internal.Flag
}

func (q *baseQuery) GetDB() *DB {
func (q *baseQuery) DB() *DB {
return q.db
}

Expand Down
2 changes: 1 addition & 1 deletion query_column_add.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func NewAddColumnQuery(db *DB) *AddColumnQuery {
return q
}

func (q *AddColumnQuery) DB(db DBI) *AddColumnQuery {
func (q *AddColumnQuery) Conn(db DBI) *AddColumnQuery {
q.setDBI(db)
return q
}
Expand Down
2 changes: 1 addition & 1 deletion query_column_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ func NewDropColumnQuery(db *DB) *DropColumnQuery {
return q
}

func (q *DropColumnQuery) DB(db DBI) *DropColumnQuery {
func (q *DropColumnQuery) Conn(db DBI) *DropColumnQuery {
q.setDBI(db)
return q
}
Expand Down
2 changes: 1 addition & 1 deletion query_delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ func NewDeleteQuery(db *DB) *DeleteQuery {
return q
}

func (q *DeleteQuery) DB(db DBI) *DeleteQuery {
func (q *DeleteQuery) Conn(db DBI) *DeleteQuery {
q.setDBI(db)
return q
}
Expand Down
78 changes: 0 additions & 78 deletions query_hook.go

This file was deleted.

2 changes: 1 addition & 1 deletion query_index_create.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func NewCreateIndexQuery(db *DB) *CreateIndexQuery {
return q
}

func (q *CreateIndexQuery) DB(db DBI) *CreateIndexQuery {
func (q *CreateIndexQuery) Conn(db DBI) *CreateIndexQuery {
q.setDBI(db)
return q
}
Expand Down
2 changes: 1 addition & 1 deletion query_index_drop.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func NewDropIndexQuery(db *DB) *DropIndexQuery {
return q
}

func (q *DropIndexQuery) DB(db DBI) *DropIndexQuery {
func (q *DropIndexQuery) Conn(db DBI) *DropIndexQuery {
q.setDBI(db)
return q
}
Expand Down
2 changes: 1 addition & 1 deletion query_insert.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ func NewInsertQuery(db *DB) *InsertQuery {
return q
}

func (q *InsertQuery) DB(db DBI) *InsertQuery {
func (q *InsertQuery) Conn(db DBI) *InsertQuery {
q.setDBI(db)
return q
}
Expand Down
2 changes: 1 addition & 1 deletion query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewSelectQuery(db *DB) *SelectQuery {
}
}

func (q *SelectQuery) DB(db DBI) *SelectQuery {
func (q *SelectQuery) Conn(db DBI) *SelectQuery {
q.setDBI(db)
return q
}
Expand Down
Loading

0 comments on commit beff015

Please sign in to comment.