Skip to content

Commit

Permalink
chore: cleanup dialects and complete binary type serialization
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Oct 22, 2021
1 parent e899f50 commit 5fabe44
Show file tree
Hide file tree
Showing 25 changed files with 244 additions and 488 deletions.
37 changes: 2 additions & 35 deletions dialect/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"unicode/utf8"

"github.com/uptrace/bun/internal"
"github.com/uptrace/bun/internal/parser"
)

func AppendError(b []byte, err error) []byte {
Expand Down Expand Up @@ -77,12 +76,12 @@ func AppendString(b []byte, s string) []byte {
return b
}

func AppendBytes(b []byte, bs []byte) []byte {
func AppendBytes(b, bs []byte) []byte {
if bs == nil {
return AppendNull(b)
}

b = append(b, `X'`...)
b = append(b, `'\x`...)

s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
Expand All @@ -93,38 +92,6 @@ func AppendBytes(b []byte, bs []byte) []byte {
return b
}

func AppendJSON(b, jsonb []byte) []byte {
b = append(b, '\'')

p := parser.New(jsonb)
for p.Valid() {
c := p.Read()
switch c {
case '"':
b = append(b, '"')
case '\'':
b = append(b, "''"...)
case '\000':
continue
case '\\':
if p.SkipBytes([]byte("u0000")) {
b = append(b, `\\u0000`...)
} else {
b = append(b, '\\')
if p.Valid() {
b = append(b, p.Read())
}
}
default:
b = append(b, c)
}
}

b = append(b, '\'')

return b
}

//------------------------------------------------------------------------------

func AppendIdent(b []byte, field string, quote byte) []byte {
Expand Down
63 changes: 0 additions & 63 deletions dialect/mysqldialect/append.go

This file was deleted.

98 changes: 29 additions & 69 deletions dialect/mysqldialect/dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,9 @@ package mysqldialect

import (
"database/sql"
"encoding/hex"
"log"
"reflect"
"strconv"
"strings"
"sync"
"time"

"golang.org/x/mod/semver"
Expand All @@ -20,11 +18,10 @@ import (
const datetimeType = "DATETIME"

type Dialect struct {
schema.BaseDialect

tables *schema.Tables
features feature.Feature

appenderMap sync.Map
scannerMap sync.Map
}

func New() *Dialect {
Expand Down Expand Up @@ -86,82 +83,45 @@ func (d *Dialect) IdentQuote() byte {
}

func (d *Dialect) AppendTime(b []byte, tm time.Time) []byte {
return appendTime(b, tm)
b = append(b, '\'')
b = tm.AppendFormat(b, "2006-01-02 15:04:05.999999")
b = append(b, '\'')
return b
}

func (d *Dialect) Append(fmter schema.Formatter, b []byte, v interface{}) []byte {
switch v := v.(type) {
case nil:
func (d *Dialect) AppendBytes(b []byte, bs []byte) []byte {
if bs == nil {
return dialect.AppendNull(b)
case bool:
return dialect.AppendBool(b, v)
case int:
return strconv.AppendInt(b, int64(v), 10)
case int32:
return strconv.AppendInt(b, int64(v), 10)
case int64:
return strconv.AppendInt(b, v, 10)
case uint:
return strconv.AppendUint(b, uint64(v), 10)
case uint32:
return strconv.AppendUint(b, uint64(v), 10)
case uint64:
return strconv.AppendUint(b, v, 10)
case float32:
return dialect.AppendFloat32(b, v)
case float64:
return dialect.AppendFloat64(b, v)
case string:
return dialect.AppendString(b, v)
case time.Time:
return appendTime(b, v)
case []byte:
return dialect.AppendBytes(b, v)
case schema.QueryAppender:
return schema.AppendQueryAppender(fmter, b, v)
default:
vv := reflect.ValueOf(v)
if vv.Kind() == reflect.Ptr && vv.IsNil() {
return dialect.AppendNull(b)
}
appender := d.Appender(vv.Type())
return appender(fmter, b, vv)
}
}

func (d *Dialect) Appender(typ reflect.Type) schema.AppenderFunc {
if v, ok := d.appenderMap.Load(typ); ok {
return v.(schema.AppenderFunc)
}
b = append(b, `X'`...)

fn := schema.Appender(typ, customAppender)
s := len(b)
b = append(b, make([]byte, hex.EncodedLen(len(bs)))...)
hex.Encode(b[s:], bs)

if v, ok := d.appenderMap.LoadOrStore(typ, fn); ok {
return v.(schema.AppenderFunc)
}
return fn
}

func (d *Dialect) FieldAppender(field *schema.Field) schema.AppenderFunc {
switch strings.ToUpper(field.UserSQLType) {
case sqltype.JSON:
return appendJSONValue
}
b = append(b, '\'')

return schema.FieldAppender(d, field)
return b
}

func (d *Dialect) Scanner(typ reflect.Type) schema.ScannerFunc {
if v, ok := d.scannerMap.Load(typ); ok {
return v.(schema.ScannerFunc)
func (d *Dialect) AppendJSON(b, jsonb []byte) []byte {
b = append(b, '\'')

for _, c := range jsonb {
switch c {
case '\'':
b = append(b, "''"...)
case '\\':
b = append(b, `\\`...)
default:
b = append(b, c)
}
}

fn := scanner(typ)
b = append(b, '\'')

if v, ok := d.scannerMap.LoadOrStore(typ, fn); ok {
return v.(schema.ScannerFunc)
}
return fn
return b
}

func sqlType(field *schema.Field) string {
Expand Down
65 changes: 19 additions & 46 deletions dialect/pgdialect/append.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,33 +29,6 @@ var (
sliceFloat64Type = reflect.TypeOf([]float64(nil))
)

func customAppender(typ reflect.Type) schema.AppenderFunc {
switch typ.Kind() {
case reflect.Uint32:
return appendUint32ValueAsInt
case reflect.Uint, reflect.Uint64:
return appendUint64ValueAsInt
}
return nil
}

func appendTime(b []byte, tm time.Time) []byte {
b = append(b, '\'')
b = tm.UTC().AppendFormat(b, "2006-01-02 15:04:05.999999-07:00")
b = append(b, '\'')
return b
}

func appendUint32ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, int64(int32(v.Uint())), 10)
}

func appendUint64ValueAsInt(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return strconv.AppendInt(b, int64(v.Uint()), 10)
}

//------------------------------------------------------------------------------

func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
switch v := v.(type) {
case int64:
Expand All @@ -69,28 +42,13 @@ func arrayAppend(fmter schema.Formatter, b []byte, v interface{}) []byte {
case string:
return arrayAppendString(b, v)
case time.Time:
return appendTime(b, v)
return fmter.Dialect().AppendTime(b, v)
default:
err := fmt.Errorf("pgdialect: can't append %T", v)
return dialect.AppendError(b, err)
}
}

func arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
switch typ.Kind() {
case reflect.String:
return arrayAppendStringValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return arrayAppendBytesValue
}
}
return schema.Appender(typ, customAppender)
}

func arrayAppendStringValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
return arrayAppendString(b, v.String())
}
Expand All @@ -109,12 +67,12 @@ func arrayAppendDriverValue(fmter schema.Formatter, b []byte, v reflect.Value) [

//------------------------------------------------------------------------------

func arrayAppender(typ reflect.Type) schema.AppenderFunc {
func (d *Dialect) arrayAppender(typ reflect.Type) schema.AppenderFunc {
kind := typ.Kind()

switch kind {
case reflect.Ptr:
if fn := arrayAppender(typ.Elem()); fn != nil {
if fn := d.arrayAppender(typ.Elem()); fn != nil {
return schema.PtrAppender(fn)
}
case reflect.Slice, reflect.Array:
Expand All @@ -138,7 +96,7 @@ func arrayAppender(typ reflect.Type) schema.AppenderFunc {
}
}

appendElem := arrayElemAppender(elemType)
appendElem := d.arrayElemAppender(elemType)
if appendElem == nil {
panic(fmt.Errorf("pgdialect: %s is not supported", typ))
}
Expand Down Expand Up @@ -176,6 +134,21 @@ func arrayAppender(typ reflect.Type) schema.AppenderFunc {
}
}

func (d *Dialect) arrayElemAppender(typ reflect.Type) schema.AppenderFunc {
if typ.Implements(driverValuerType) {
return arrayAppendDriverValue
}
switch typ.Kind() {
case reflect.String:
return arrayAppendStringValue
case reflect.Slice:
if typ.Elem().Kind() == reflect.Uint8 {
return arrayAppendBytesValue
}
}
return schema.Appender(d, typ)
}

func appendStringSliceValue(fmter schema.Formatter, b []byte, v reflect.Value) []byte {
ss := v.Convert(sliceStringType).Interface().([]string)
return appendStringSlice(b, ss)
Expand Down
2 changes: 1 addition & 1 deletion dialect/pgdialect/array.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func Array(vi interface{}) *ArrayValue {
return &ArrayValue{
v: v,

append: arrayAppender(v.Type()),
append: pgDialect.arrayAppender(v.Type()),
scan: arrayScanner(v.Type()),
}
}
Expand Down
Loading

0 comments on commit 5fabe44

Please sign in to comment.