Skip to content

Commit

Permalink
Merge pull request segmentio#3 from segmentio/marshal
Browse files Browse the repository at this point in the history
add marshaling methods
  • Loading branch information
achille-roussel authored May 16, 2017
2 parents 763d7d3 + 6539a48 commit 260b89c
Show file tree
Hide file tree
Showing 2 changed files with 188 additions and 0 deletions.
75 changes: 75 additions & 0 deletions ksuid.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package ksuid
import (
"bytes"
"crypto/rand"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
Expand Down Expand Up @@ -90,6 +91,80 @@ func (i KSUID) IsNil() bool {
return i == Nil
}

// Get satisfies the flag.Getter interface, making it possible to use KSUIDs as
// part of of the command line options of a program.
func (i KSUID) Get() interface{} {
return i
}

// Set satisfies the flag.Value interface, making it possible to use KSUIDs as
// part of of the command line options of a program.
func (i *KSUID) Set(s string) error {
return i.UnmarshalText([]byte(s))
}

func (i KSUID) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}

func (i KSUID) MarshalBinary() ([]byte, error) {
return i.Bytes(), nil
}

func (i *KSUID) UnmarshalText(b []byte) error {
id, err := Parse(string(b))
if err != nil {
return err
}
*i = id
return nil
}

func (i *KSUID) UnmarshalBinary(b []byte) error {
id, err := FromBytes(b)
if err != nil {
return err
}
*i = id
return nil
}

// Value converts the KSUID into a SQL driver value which can be used to
// directly use the KSUID as parameter to a SQL query.
func (i KSUID) Value() (driver.Value, error) {
return i.String(), nil
}

// Scan implements the sql.Scanner interface. It supports converting from
// string, []byte, or nil into a KSUID value. Attempting to convert from
// another type will return an error.
func (i *KSUID) Scan(src interface{}) error {
switch v := src.(type) {
case nil:
return i.scan(nil)
case []byte:
return i.scan(v)
case string:
return i.scan([]byte(v))
default:
return fmt.Errorf("Scan: unable to scan type %T into KSUID", v)
}
}

func (i *KSUID) scan(b []byte) error {
switch len(b) {
case 0:
*i = Nil
return nil
case byteLength:
return i.UnmarshalBinary(b)
case stringEncodedLength:
return i.UnmarshalText(b)
default:
return errSize
}
}

// Decodes a string-encoded representation of a KSUID object
func Parse(s string) (KSUID, error) {
if len(s) != stringEncodedLength {
Expand Down
113 changes: 113 additions & 0 deletions ksuid_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package ksuid

import (
"bytes"
"encoding/json"
"flag"
"fmt"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -103,3 +107,112 @@ func TestEncodeAndDecode(t *testing.T) {
t.Fatal("Parse(X).String() != X")
}
}

func TestMarshalText(t *testing.T) {
var id1 = New()
var id2 KSUID

if err := id2.UnmarshalText([]byte(id1.String())); err != nil {
t.Fatal(err)
}

if id1 != id2 {
t.Fatal(id1, "!=", id2)
}

if b, err := id2.MarshalText(); err != nil {
t.Fatal(err)
} else if s := string(b); s != id1.String() {
t.Fatal(s)
}
}

func TestMarshalBinary(t *testing.T) {
var id1 = New()
var id2 KSUID

if err := id2.UnmarshalBinary(id1.Bytes()); err != nil {
t.Fatal(err)
}

if id1 != id2 {
t.Fatal(id1, "!=", id2)
}

if b, err := id2.MarshalBinary(); err != nil {
t.Fatal(err)
} else if bytes.Compare(b, id1.Bytes()) != 0 {
t.Fatal("bad binary form:", id2)
}
}

func TestMashalJSON(t *testing.T) {
var id1 = New()
var id2 KSUID

if b, err := json.Marshal(id1); err != nil {
t.Fatal(err)
} else if err := json.Unmarshal(b, &id2); err != nil {
t.Fatal(err)
} else if id1 != id2 {
t.Error(id1, "!=", id2)
}
}

func TestFlag(t *testing.T) {
var id1 = New()
var id2 KSUID

fset := flag.NewFlagSet("test", flag.ContinueOnError)
fset.Var(&id2, "id", "the KSUID")

if err := fset.Parse([]string{"-id", id1.String()}); err != nil {
t.Fatal(err)
}

if id1 != id2 {
t.Error(id1, "!=", id2)
}
}

func TestSqlValuer(t *testing.T) {
id, _ := Parse(maxStringEncoded)

if v, err := id.Value(); err != nil {
t.Error(err)
} else if s, ok := v.(string); !ok {
t.Error("not a string value")
} else if s != maxStringEncoded {
t.Error("bad string value::", s)
}
}

func TestSqlScanner(t *testing.T) {
id1 := New()
id2 := New()

tests := []struct {
ksuid KSUID
value interface{}
}{
{Nil, nil},
{id1, id1.String()},
{id2, id2.Bytes()},
}

for _, test := range tests {
t.Run(fmt.Sprintf("%T", test.value), func(t *testing.T) {
var id KSUID

if err := id.Scan(test.value); err != nil {
t.Error(err)
}

if id != test.ksuid {
t.Error("bad KSUID:")
t.Logf("expected %v", test.ksuid)
t.Logf("found %v", id)
}
})
}
}

0 comments on commit 260b89c

Please sign in to comment.