diff --git a/common/types.go b/common/types.go index 70b7e7aae86c..8a456e965e94 100644 --- a/common/types.go +++ b/common/types.go @@ -17,14 +17,12 @@ package common import ( - "encoding/hex" - "encoding/json" - "errors" "fmt" "math/big" "math/rand" "reflect" - "strings" + + "github.com/ethereum/go-ethereum/common/hexutil" ) const ( @@ -32,8 +30,6 @@ const ( AddressLength = 20 ) -var hashJsonLengthErr = errors.New("common: unmarshalJSON failed: hash must be exactly 32 bytes") - type ( // Hash represents the 32 byte Keccak256 hash of arbitrary data. Hash [HashLength]byte @@ -57,30 +53,16 @@ func HexToHash(s string) Hash { return BytesToHash(FromHex(s)) } func (h Hash) Str() string { return string(h[:]) } func (h Hash) Bytes() []byte { return h[:] } func (h Hash) Big() *big.Int { return Bytes2Big(h[:]) } -func (h Hash) Hex() string { return "0x" + Bytes2Hex(h[:]) } +func (h Hash) Hex() string { return hexutil.Encode(h[:]) } // UnmarshalJSON parses a hash in its hex from to a hash. func (h *Hash) UnmarshalJSON(input []byte) error { - length := len(input) - if length >= 2 && input[0] == '"' && input[length-1] == '"' { - input = input[1 : length-1] - } - // strip "0x" for length check - if len(input) > 1 && strings.ToLower(string(input[:2])) == "0x" { - input = input[2:] - } - - // validate the length of the input hash - if len(input) != HashLength*2 { - return hashJsonLengthErr - } - h.SetBytes(FromHex(string(input))) - return nil + return hexutil.UnmarshalJSON("Hash", input, h[:]) } // Serialize given hash to JSON func (h Hash) MarshalJSON() ([]byte, error) { - return json.Marshal(h.Hex()) + return hexutil.Bytes(h[:]).MarshalJSON() } // Sets the hash to the value of b. If b is larger than len(h) it will panic @@ -142,7 +124,7 @@ func (a Address) Str() string { return string(a[:]) } func (a Address) Bytes() []byte { return a[:] } func (a Address) Big() *big.Int { return Bytes2Big(a[:]) } func (a Address) Hash() Hash { return BytesToHash(a[:]) } -func (a Address) Hex() string { return "0x" + Bytes2Hex(a[:]) } +func (a Address) Hex() string { return hexutil.Encode(a[:]) } // Sets the address to the value of b. If b is larger than len(a) it will panic func (a *Address) SetBytes(b []byte) { @@ -164,34 +146,12 @@ func (a *Address) Set(other Address) { // Serialize given address to JSON func (a Address) MarshalJSON() ([]byte, error) { - return json.Marshal(a.Hex()) + return hexutil.Bytes(a[:]).MarshalJSON() } // Parse address from raw json data -func (a *Address) UnmarshalJSON(data []byte) error { - if len(data) > 2 && data[0] == '"' && data[len(data)-1] == '"' { - data = data[1 : len(data)-1] - } - - if len(data) > 2 && data[0] == '0' && data[1] == 'x' { - data = data[2:] - } - - if len(data) != 2*AddressLength { - return fmt.Errorf("Invalid address length, expected %d got %d bytes", 2*AddressLength, len(data)) - } - - n, err := hex.Decode(a[:], data) - if err != nil { - return err - } - - if n != AddressLength { - return fmt.Errorf("Invalid address") - } - - a.Set(HexToAddress(string(data))) - return nil +func (a *Address) UnmarshalJSON(input []byte) error { + return hexutil.UnmarshalJSON("Address", input, a[:]) } // PP Pretty Prints a byte slice in the following format: diff --git a/common/types_test.go b/common/types_test.go index de67cfcb5f7f..e84780f43dee 100644 --- a/common/types_test.go +++ b/common/types_test.go @@ -18,7 +18,10 @@ package common import ( "math/big" + "strings" "testing" + + "github.com/ethereum/go-ethereum/common/hexutil" ) func TestBytesConversion(t *testing.T) { @@ -38,19 +41,26 @@ func TestHashJsonValidation(t *testing.T) { var tests = []struct { Prefix string Size int - Error error + Error string }{ - {"", 2, hashJsonLengthErr}, - {"", 62, hashJsonLengthErr}, - {"", 66, hashJsonLengthErr}, - {"", 65, hashJsonLengthErr}, - {"0X", 64, nil}, - {"0x", 64, nil}, - {"0x", 62, hashJsonLengthErr}, + {"", 62, hexutil.ErrMissingPrefix.Error()}, + {"0x", 66, "hex string has length 66, want 64 for Hash"}, + {"0x", 63, hexutil.ErrOddLength.Error()}, + {"0x", 0, "hex string has length 0, want 64 for Hash"}, + {"0x", 64, ""}, + {"0X", 64, ""}, } - for i, test := range tests { - if err := h.UnmarshalJSON(append([]byte(test.Prefix), make([]byte, test.Size)...)); err != test.Error { - t.Errorf("test #%d: error mismatch: have %v, want %v", i, err, test.Error) + for _, test := range tests { + input := `"` + test.Prefix + strings.Repeat("0", test.Size) + `"` + err := h.UnmarshalJSON([]byte(input)) + if err == nil { + if test.Error != "" { + t.Errorf("%s: error mismatch: have nil, want %q", input, test.Error) + } + } else { + if err.Error() != test.Error { + t.Errorf("%s: error mismatch: have %q, want %q", input, err, test.Error) + } } } }