Skip to content

Commit

Permalink
Initial JSON extensions for bson.
Browse files Browse the repository at this point in the history
  • Loading branch information
niemeyer committed May 6, 2016
1 parent 0308d06 commit 82635f8
Show file tree
Hide file tree
Showing 6 changed files with 366 additions and 3 deletions.
17 changes: 17 additions & 0 deletions bson/bson.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import (
"crypto/rand"
"encoding/binary"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -276,6 +277,22 @@ var nullBytes = []byte("null")

// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller.
func (id *ObjectId) UnmarshalJSON(data []byte) error {
if len(data) > 0 && (data[0] == '{' || data[0] == 'O') {
var v struct {
Id json.RawMessage `json:"$oid"`
Func struct {
Id json.RawMessage
} `json:"$oidFunc"`
}
err := jdec(data, &v)
if err == nil {
if len(v.Id) > 0 {
data = []byte(v.Id)
} else {
data = []byte(v.Func.Id)
}
}
}
if len(data) == 2 && data[0] == '"' && data[1] == '"' || bytes.Equal(data, nullBytes) {
*id = ""
return nil
Expand Down
4 changes: 2 additions & 2 deletions bson/bson_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ func makeZeroDoc(value interface{}) (zero interface{}) {
case reflect.Ptr:
pv := reflect.New(v.Type().Elem())
zero = pv.Interface()
case reflect.Slice, reflect.Int:
case reflect.Slice, reflect.Int, reflect.Int64, reflect.Struct:
zero = reflect.New(t).Interface()
default:
panic("unsupported doc type")
panic("unsupported doc type: " + t.Name())
}
return zero
}
Expand Down
205 changes: 205 additions & 0 deletions bson/json.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
package bson

import (
"bytes"
"fmt"
"gopkg.in/mgo.v2-unstable/internal/json"
)

func UnmarshalJSON(data []byte, value interface{}) error {
d := json.NewDecoder(bytes.NewBuffer(data))
d.Extend(&jsonExt)
return d.Decode(value)
}

func MarshalJSON(value interface{}) ([]byte, error) {
var buf bytes.Buffer
e := json.NewEncoder(&buf)
e.Extend(&jsonExt)
err := e.Encode(value)
if err != nil {
return nil, err
}
return buf.Bytes(), nil
}

// jdec is used internally by the JSON decoding functions
// so they may unmarshal functions without getting into endless
// recursion due to keyed objects.
func jdec(data []byte, value interface{}) error {
d := json.NewDecoder(bytes.NewBuffer(data))
d.Extend(&funcExt)
return d.Decode(value)
}

var jsonExt json.Extension
var funcExt json.Extension

func init() {
funcExt.DecodeFunc("ObjectId", "$oidFunc", "Id")
jsonExt.DecodeKeyed("$oid", jdecObjectId)
jsonExt.DecodeKeyed("$oidFunc", jdecObjectId)
jsonExt.EncodeType(ObjectId(""), jencObjectId)

funcExt.DecodeFunc("DBRef", "$dbrefFunc", "$ref", "$id")
jsonExt.DecodeKeyed("$dbrefFunc", jdecDBRef)

funcExt.DecodeFunc("NumberLong", "$numberLongFunc", "N")
jsonExt.DecodeKeyed("$numberLong", jdecNumberLong)
jsonExt.DecodeKeyed("$numberLongFunc", jdecNumberLong)
jsonExt.EncodeType(int64(0), jencNumberLong)
jsonExt.EncodeType(int(0), jencInt)

jsonExt.DecodeKeyed("$minKey", jdecMinKey)
jsonExt.DecodeKeyed("$maxKey", jdecMaxKey)
jsonExt.EncodeType(orderKey(0), jencMinMaxKey)

jsonExt.DecodeKeyed("$undefined", jdecUndefined)
jsonExt.EncodeType(Undefined, jencUndefined)

jsonExt.Extend(&funcExt)
}

func fbytes(format string, args ...interface{}) []byte {
var buf bytes.Buffer
fmt.Fprintf(&buf, format, args...)
return buf.Bytes()
}

func jdecObjectId(data []byte) (interface{}, error) {
println("Here!")
var v struct {
Id string `json:"$oid"`
Func struct {
Id string
} `json:"$oidFunc"`
}
err := jdec(data, &v)
if err != nil {
return nil, err
}
if v.Id == "" {
v.Id = v.Func.Id
}
return ObjectIdHex(v.Id), nil
}

func jencObjectId(v interface{}) ([]byte, error) {
return fbytes(`{"$oid":"%s"}`, v.(ObjectId).Hex()), nil
}

func jdecDBRef(data []byte) (interface{}, error) {
// TODO Support unmarshaling $ref and $id into the input value.
var v struct {
Obj map[string]interface{} `json:"$dbrefFunc"`
}
// TODO Fix this. Must not be required.
v.Obj = make(map[string]interface{})
err := jdec(data, &v)
if err != nil {
return nil, err
}
return v.Obj, nil
}

func jdecNumberLong(data []byte) (interface{}, error) {
var v struct {
N int64 `json:"$numberLong,string"`
Func struct {
N int64 `json:",string"`
} `json:"$numberLongFunc"`
}
var vn struct {
N int64 `json:"$numberLong"`
Func struct {
N int64
} `json:"$numberLongFunc"`
}
err := jdec(data, &v)
if err != nil {
err = jdec(data, &vn)
v.N = vn.N
v.Func.N = vn.Func.N
}
if err != nil {
return nil, err
}
if v.N != 0 {
return v.N, nil
}
return v.Func.N, nil
}

func jencNumberLong(v interface{}) ([]byte, error) {
n := v.(int64)
f := `{"$numberLong":"%d"}`
if n <= 1<<53 {
f = `{"$numberLong":%d}`
}
return fbytes(f, n), nil
}

func jencInt(v interface{}) ([]byte, error) {
n := v.(int)
f := `{"$numberLong":"%d"}`
if n <= 1<<53 {
f = `%d`
}
return fbytes(f, n), nil
}

func jdecMinKey(data []byte) (interface{}, error) {
var v struct {
N int64 `json:"$minKey"`
}
err := jdec(data, &v)
if err != nil {
return nil, err
}
if v.N != 1 {
return nil, fmt.Errorf("invalid $minKey object: %s", data)
}
return MinKey, nil
}

func jdecMaxKey(data []byte) (interface{}, error) {
var v struct {
N int64 `json:"$maxKey"`
}
err := jdec(data, &v)
if err != nil {
return nil, err
}
if v.N != 1 {
return nil, fmt.Errorf("invalid $maxKey object: %s", data)
}
return MaxKey, nil
}

func jencMinMaxKey(v interface{}) ([]byte, error) {
switch v.(orderKey) {
case MinKey:
return []byte(`{"$minKey":1}`), nil
case MaxKey:
return []byte(`{"$maxKey":1}`), nil
}
panic(fmt.Sprintf("invalid $minKey/$maxKey value: %d", v))
}

func jdecUndefined(data []byte) (interface{}, error) {
var v struct {
B bool `json:"$undefined"`
}
err := jdec(data, &v)
if err != nil {
return nil, err
}
if !v.B {
return nil, fmt.Errorf("invalid $undefined object: %s", data)
}
return Undefined, nil
}

func jencUndefined(v interface{}) ([]byte, error) {
return []byte(`{"$undefined":true}`), nil
}
121 changes: 121 additions & 0 deletions bson/json_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
package bson_test

import (
"gopkg.in/mgo.v2-unstable/bson"

. "gopkg.in/check.v1"
"reflect"
"strings"
)

type jsonTest struct {
a interface{}
b string
c interface{}
e string
}

var jsonTests = []jsonTest{
// $oid
{
a: bson.ObjectIdHex("0123456789abcdef01234567"),
b: `{"$oid":"0123456789abcdef01234567"}`,
}, {
b: `ObjectId("0123456789abcdef01234567")`,
c: bson.ObjectIdHex("0123456789abcdef01234567"),
},

// $ref (no special type)
{
b: `DBRef("name", "id")`,
c: map[string]interface{}{"$ref": "name", "$id": "id"},
},

// $numberLong
{
a: 123,
b: `123`,
}, {
a: int64(9007199254740992),
b: `{"$numberLong":9007199254740992}`,
}, {
a: int64(1<<53 + 1),
b: `{"$numberLong":"9007199254740993"}`,
}, {
a: 1<<53 + 1,
b: `{"$numberLong":"9007199254740993"}`,
c: int64(9007199254740993),
}, {
b: `NumberLong(9007199254740992)`,
c: int64(1 << 53),
}, {
b: `NumberLong("9007199254740993")`,
c: int64(1<<53 + 1),
},

// $minKey, $maxKey
{
a: bson.MinKey,
b: `{"$minKey":1}`,
}, {
a: bson.MaxKey,
b: `{"$maxKey":1}`,
}, {
b: `{"$minKey":0}`,
e: `invalid $minKey object: {"$minKey":0}`,
}, {
b: `{"$maxKey":0}`,
e: `invalid $maxKey object: {"$maxKey":0}`,
},

// $undefined
{
a: bson.Undefined,
b: `{"$undefined":true}`,
},
}

func (s *S) TestJSON(c *C) {
for _, item := range jsonTests {
c.Logf("------------")
c.Logf("A: %#v", item.a)
c.Logf("B: %#v", item.b)

if item.c == nil {
item.c = item.a
} else {
c.Logf("C: %#v", item.c)
}
if item.e != "" {
c.Logf("E: %s", item.e)
}

if item.a != nil {
data, err := bson.MarshalJSON(item.a)
c.Assert(err, IsNil)
c.Logf("Dumped: %#v", string(data))
c.Assert(strings.TrimSuffix(string(data), "\n"), Equals, item.b)
}

var zero interface{}
if item.c == nil {
zero = &struct{}{}
} else {
zero = reflect.New(reflect.TypeOf(item.c)).Interface()
}
err := bson.UnmarshalJSON([]byte(item.b), zero)
if item.e != "" {
c.Assert(err, NotNil)
c.Assert(err.Error(), Equals, item.e)
continue
}
c.Assert(err, IsNil)
zerov := reflect.ValueOf(zero)
value := zerov.Interface()
if zerov.Kind() == reflect.Ptr {
value = zerov.Elem().Interface()
}
c.Logf("Loaded: %#v", value)
c.Assert(value, DeepEquals, item.c)
}
}
6 changes: 5 additions & 1 deletion internal/json/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -1013,8 +1013,12 @@ func (d *decodeState) storeKeyed(v reflect.Value) bool {
return false
}
keyedv := reflect.ValueOf(keyed)
if keyedv.Type().AssignableTo(v.Type()) {
keyedt := keyedv.Type()
vt := v.Type()
if keyedt.AssignableTo(vt) {
v.Set(keyedv)
} else if keyedt.ConvertibleTo(vt) {
v.Set(keyedv.Convert(vt))
} else {
d.saveError(&UnmarshalTypeError{"object", v.Type(), int64(d.off)})
}
Expand Down
Loading

0 comments on commit 82635f8

Please sign in to comment.