Skip to content

Commit

Permalink
Set Event.Result for queries built by Bun
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Aug 30, 2021
1 parent 61a5918 commit f6610a2
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 39 deletions.
24 changes: 1 addition & 23 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package bun
import (
"context"
"database/sql"
"errors"
"fmt"
"reflect"
"strings"
Expand Down Expand Up @@ -473,30 +472,9 @@ func (tx Tx) NewDropColumn() *DropColumnQuery {
return NewDropColumnQuery(tx.db).Conn(tx)
}

//------------------------------------------------------------------------------0
//------------------------------------------------------------------------------

func (db *DB) makeQueryBytes() []byte {
// TODO: make this configurable?
return make([]byte, 0, 4096)
}

//------------------------------------------------------------------------------

type result struct {
r sql.Result
n int
}

func (r result) RowsAffected() (int64, error) {
if r.r != nil {
return r.r.RowsAffected()
}
return int64(r.n), nil
}

func (r result) LastInsertId() (int64, error) {
if r.r != nil {
return r.r.LastInsertId()
}
return 0, errors.New("LastInsertId is not available")
}
13 changes: 10 additions & 3 deletions extra/bunotel/otel.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,21 @@ func (h *QueryHook) AfterQuery(ctx context.Context, event *bun.QueryEvent) {

attrs := make([]attribute.KeyValue, 0, 10)
attrs = append(attrs,
attribute.String("db.system", dbSystem(event.DB)),
attribute.String("db.statement", query),

attribute.String("code.function", fn),
attribute.String("code.filepath", file),
attribute.Int("code.lineno", line),
)

if s := dbSystem(event.DB); s != "" {
attrs = append(attrs, attribute.String("db.system", s))
}
if event.Result != nil {
if n, _ := event.Result.RowsAffected(); n > 0 {
attrs = append(attrs, attribute.Int64("db.rows_affected", n))
}
}

if event.Err != nil {
switch event.Err {
case sql.ErrNoRows:
Expand Down Expand Up @@ -163,6 +170,6 @@ func dbSystem(db *bun.DB) string {
case dialect.SQLite:
return "sqlite"
default:
return "unknown"
return ""
}
}
23 changes: 11 additions & 12 deletions query_base.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package bun
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"

Expand Down Expand Up @@ -430,28 +431,28 @@ func (q *baseQuery) scan(
query string,
model model,
hasDest bool,
) (res result, _ error) {
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)

rows, err := q.conn.QueryContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
return nil, err
}
defer rows.Close()

n, err := model.ScanRows(ctx, rows)
numRow, err := model.ScanRows(ctx, rows)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
return nil, err
}

res.n = n
if n == 0 && hasDest && isSingleRowModel(model) {
if numRow == 0 && hasDest && isSingleRowModel(model) {
err = sql.ErrNoRows
}

q.db.afterQuery(ctx, event, nil, err)
res := driver.RowsAffected(numRow)
q.db.afterQuery(ctx, event, res, err)

return res, err
}
Expand All @@ -460,18 +461,16 @@ func (q *baseQuery) exec(
ctx context.Context,
queryApp schema.QueryAppender,
query string,
) (res result, _ error) {
) (sql.Result, error) {
ctx, event := q.db.beforeQuery(ctx, queryApp, query, nil)

r, err := q.conn.ExecContext(ctx, query)
res, err := q.conn.ExecContext(ctx, query)
if err != nil {
q.db.afterQuery(ctx, event, nil, err)
return res, err
}

res.r = r

q.db.afterQuery(ctx, event, nil, err)
q.db.afterQuery(ctx, event, res, err)
return res, nil
}

Expand Down
2 changes: 1 addition & 1 deletion query_select.go
Original file line number Diff line number Diff line change
Expand Up @@ -695,7 +695,7 @@ func (q *SelectQuery) Scan(ctx context.Context, dest ...interface{}) error {
return err
}

if res.n > 0 {
if n, _ := res.RowsAffected(); n > 0 {
if tableModel, ok := model.(tableModel); ok {
if err := q.selectJoins(ctx, tableModel.GetJoins()); err != nil {
return err
Expand Down

0 comments on commit f6610a2

Please sign in to comment.