Skip to content

Commit

Permalink
proto: add MarshalState, UnmarshalState
Browse files Browse the repository at this point in the history
Add functions to the proto package which plumb through the fast-path state.

As a sample use case: A followup CL adds an Initialized field to
protoiface.UnmarshalOutput, permitting the unmarshaller to report back
when it can confirm that a message is fully initialized. We want to
preserve that information when an unmarshal operation threads through
the proto package (such as when unmarshaling extensions).

To allow these functions to be added as methods of MarshalOptions and
UnmarshalOptions rather than top-level functions, separate the options
from the input structs.

Also update options passed to fast-path methods to set AllowPartial and
Merge to reflect the expected behavior of those methods. (Always allow
partial, never merge.)

Change-Id: I482477b0c9340793be533e75a86d0bb88708716a
Reviewed-on: https://go-review.googlesource.com/c/protobuf/+/215877
Reviewed-by: Joe Tsai <[email protected]>
  • Loading branch information
neild committed Jan 22, 2020
1 parent f0831e8 commit d30e561
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 60 deletions.
4 changes: 2 additions & 2 deletions internal/impl/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ type unmarshalOutput struct {
}

// unmarshal is protoreflect.Methods.Unmarshal.
func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
func (mi *MessageInfo) unmarshal(m pref.Message, in piface.UnmarshalInput, opts piface.UnmarshalOptions) (piface.UnmarshalOutput, error) {
var p pointer
if ms, ok := m.(*messageState); ok {
p = ms.pointer()
} else {
p = m.(*messageReflectWrapper).pointer()
}
_, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(in.Options))
_, err := mi.unmarshalPointer(in.Buf, p, 0, newUnmarshalOptions(opts))
return piface.UnmarshalOutput{}, err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/impl/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,14 @@ func (mi *MessageInfo) sizePointerSlow(p pointer, opts marshalOptions) (size int
}

// marshal is protoreflect.Methods.Marshal.
func (mi *MessageInfo) marshal(m pref.Message, in piface.MarshalInput) (piface.MarshalOutput, error) {
func (mi *MessageInfo) marshal(m pref.Message, in piface.MarshalInput, opts piface.MarshalOptions) (piface.MarshalOutput, error) {
var p pointer
if ms, ok := m.(*messageState); ok {
p = ms.pointer()
} else {
p = m.(*messageReflectWrapper).pointer()
}
b, err := mi.marshalAppendPointer(in.Buf, p, newMarshalOptions(in.Options))
b, err := mi.marshalAppendPointer(in.Buf, p, newMarshalOptions(opts))
return piface.MarshalOutput{Buf: b}, err
}

Expand Down
4 changes: 2 additions & 2 deletions internal/impl/legacy_message.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ var legacyProtoMethods = &piface.Methods{
Flags: piface.SupportMarshalDeterministic,
}

func legacyMarshal(m protoreflect.Message, in piface.MarshalInput) (piface.MarshalOutput, error) {
func legacyMarshal(m protoreflect.Message, in piface.MarshalInput, opts piface.MarshalOptions) (piface.MarshalOutput, error) {
v := m.(unwrapper).protoUnwrap()
marshaler, ok := v.(legacyMarshaler)
if !ok {
Expand All @@ -388,7 +388,7 @@ func legacyMarshal(m protoreflect.Message, in piface.MarshalInput) (piface.Marsh
}, err
}

func legacyUnmarshal(m protoreflect.Message, in piface.UnmarshalInput) (piface.UnmarshalOutput, error) {
func legacyUnmarshal(m protoreflect.Message, in piface.UnmarshalInput, opts piface.UnmarshalOptions) (piface.UnmarshalOutput, error) {
v := m.(unwrapper).protoUnwrap()
unmarshaler, ok := v.(legacyUnmarshaler)
if !ok {
Expand Down
52 changes: 35 additions & 17 deletions proto/decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,38 +47,56 @@ var _ = protoiface.UnmarshalOptions(UnmarshalOptions{})

// Unmarshal parses the wire-format message in b and places the result in m.
func Unmarshal(b []byte, m Message) error {
return UnmarshalOptions{}.Unmarshal(b, m)
_, err := UnmarshalOptions{}.unmarshal(b, m)
return err
}

// Unmarshal parses the wire-format message in b and places the result in m.
func (o UnmarshalOptions) Unmarshal(b []byte, m Message) error {
_, err := o.unmarshal(b, m)
return err
}

// UnmarshalState parses a wire-format message and places the result in m.
//
// This method permits fine-grained control over the unmarshaler.
// Most users should use Unmarshal instead.
func (o UnmarshalOptions) UnmarshalState(m Message, in protoiface.UnmarshalInput) (protoiface.UnmarshalOutput, error) {
return o.unmarshal(in.Buf, m)
}

func (o UnmarshalOptions) unmarshal(b []byte, message Message) (out protoiface.UnmarshalOutput, err error) {
if o.Resolver == nil {
o.Resolver = protoregistry.GlobalTypes
}

if !o.Merge {
Reset(m)
Reset(message)
}
allowPartial := o.AllowPartial
o.Merge = true
o.AllowPartial = true
m := message.ProtoReflect()
methods := protoMethods(m)
if methods != nil && methods.Unmarshal != nil &&
!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
out, err = methods.Unmarshal(m, protoiface.UnmarshalInput{
Buf: b,
}, protoiface.UnmarshalOptions(o))
} else {
err = o.unmarshalMessageSlow(b, m)
}
err := o.unmarshalMessage(b, m.ProtoReflect())
if err != nil {
return err
return out, err
}
if o.AllowPartial {
return nil
if allowPartial {
return out, nil
}
return IsInitialized(m)
return out, isInitialized(m)
}

func (o UnmarshalOptions) unmarshalMessage(b []byte, m protoreflect.Message) error {
if methods := protoMethods(m); methods != nil && methods.Unmarshal != nil &&
!(o.DiscardUnknown && methods.Flags&protoiface.SupportUnmarshalDiscardUnknown == 0) {
_, err := methods.Unmarshal(m, protoiface.UnmarshalInput{
Buf: b,
Options: protoiface.UnmarshalOptions(o),
})
return err
}
return o.unmarshalMessageSlow(b, m)
_, err := o.unmarshal(b, m.Interface())
return err
}

func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) error {
Expand Down
52 changes: 35 additions & 17 deletions proto/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,28 +76,35 @@ var _ = protoiface.MarshalOptions(MarshalOptions{})

// Marshal returns the wire-format encoding of m.
func Marshal(m Message) ([]byte, error) {
return MarshalOptions{}.MarshalAppend(nil, m)
out, err := MarshalOptions{}.marshal(nil, m)
return out.Buf, err
}

// Marshal returns the wire-format encoding of m.
func (o MarshalOptions) Marshal(m Message) ([]byte, error) {
return o.MarshalAppend(nil, m)
out, err := o.marshal(nil, m)
return out.Buf, err
}

// MarshalAppend appends the wire-format encoding of m to b,
// returning the result.
func (o MarshalOptions) MarshalAppend(b []byte, m Message) ([]byte, error) {
out, err := o.marshalMessage(b, m.ProtoReflect())
if err != nil {
return out, err
}
if o.AllowPartial {
return out, nil
}
return out, IsInitialized(m)
out, err := o.marshal(b, m)
return out.Buf, err
}

func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
// MarshalState returns the wire-format encoding of m.
//
// This method permits fine-grained control over the marshaler.
// Most users should use Marshal instead.
func (o MarshalOptions) MarshalState(m Message, in protoiface.MarshalInput) (protoiface.MarshalOutput, error) {
return o.marshal(in.Buf, m)
}

func (o MarshalOptions) marshal(b []byte, message Message) (out protoiface.MarshalOutput, err error) {
allowPartial := o.AllowPartial
o.AllowPartial = true
m := message.ProtoReflect()
if methods := protoMethods(m); methods != nil && methods.Marshal != nil &&
!(o.Deterministic && methods.Flags&protoiface.SupportMarshalDeterministic == 0) {
if methods.Size != nil {
Expand All @@ -109,13 +116,24 @@ func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte
}
o.UseCachedSize = true
}
out, err := methods.Marshal(m, protoiface.MarshalInput{
Buf: b,
Options: protoiface.MarshalOptions(o),
})
return out.Buf, err
out, err = methods.Marshal(m, protoiface.MarshalInput{
Buf: b,
}, protoiface.MarshalOptions(o))
} else {
out.Buf, err = o.marshalMessageSlow(b, m)
}
return o.marshalMessageSlow(b, m)
if err != nil {
return out, err
}
if allowPartial {
return out, nil
}
return out, isInitialized(m)
}

func (o MarshalOptions) marshalMessage(b []byte, m protoreflect.Message) ([]byte, error) {
out, err := o.marshal(b, m.Interface())
return out.Buf, err
}

// growcap scales up the capacity of a slice.
Expand Down
2 changes: 1 addition & 1 deletion proto/size.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func sizeMessage(m protoreflect.Message) (size int) {
if methods != nil && methods.Marshal != nil {
// This is not efficient, but we don't have any choice.
// This case is mainly used for legacy types with a Marshal method.
out, _ := methods.Marshal(m, protoiface.MarshalInput{})
out, _ := methods.Marshal(m, protoiface.MarshalInput{}, protoiface.MarshalOptions{})
return len(out.Buf)
}
return sizeMessageSlow(m)
Expand Down
14 changes: 6 additions & 8 deletions reflect/protoreflect/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,15 @@ type (
methods = struct {
pragma.NoUnkeyedLiterals
Flags supportFlags
Size func(m Message, opts marshalOptions) int
Marshal func(m Message, in marshalInput) (marshalOutput, error)
Unmarshal func(m Message, in unmarshalInput) (unmarshalOutput, error)
IsInitialized func(m Message) error
Size func(Message, marshalOptions) int
Marshal func(Message, marshalInput, marshalOptions) (marshalOutput, error)
Unmarshal func(Message, unmarshalInput, unmarshalOptions) (unmarshalOutput, error)
IsInitialized func(Message) error
}
supportFlags = uint64
marshalInput = struct {
pragma.NoUnkeyedLiterals
Buf []byte
Options marshalOptions
Buf []byte
}
marshalOutput = struct {
pragma.NoUnkeyedLiterals
Expand All @@ -41,8 +40,7 @@ type (
}
unmarshalInput = struct {
pragma.NoUnkeyedLiterals
Buf []byte
Options unmarshalOptions
Buf []byte
}
unmarshalOutput = struct {
pragma.NoUnkeyedLiterals
Expand Down
18 changes: 7 additions & 11 deletions runtime/protoiface/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,10 @@ type Methods = struct {

// Marshal writes the wire-format encoding of m to the provided buffer.
// Size should be provided if a custom MarshalAppend is provided.
// It must not perform required field checks.
Marshal func(m protoreflect.Message, in MarshalInput) (MarshalOutput, error)
Marshal func(m protoreflect.Message, in MarshalInput, opts MarshalOptions) (MarshalOutput, error)

// Unmarshal parses the wire-format encoding of a message and merges the result to m.
// It must not reset m or perform required field checks.
Unmarshal func(m protoreflect.Message, in UnmarshalInput) (UnmarshalOutput, error)
Unmarshal func(m protoreflect.Message, in UnmarshalInput, opts UnmarshalOptions) (UnmarshalOutput, error)

// IsInitialized returns an error if any required fields in m are not set.
IsInitialized func(m protoreflect.Message) error
Expand All @@ -52,8 +50,7 @@ const (
type MarshalInput = struct {
pragma.NoUnkeyedLiterals

Buf []byte // output is appended to this buffer
Options MarshalOptions
Buf []byte // output is appended to this buffer
}

// MarshalOutput is output from the marshaler.
Expand All @@ -69,7 +66,7 @@ type MarshalOutput = struct {
type MarshalOptions = struct {
pragma.NoUnkeyedLiterals

AllowPartial bool // must be treated as true by method implementations
AllowPartial bool // may be treated as true by method implementations
Deterministic bool
UseCachedSize bool
}
Expand All @@ -78,8 +75,7 @@ type MarshalOptions = struct {
type UnmarshalInput = struct {
pragma.NoUnkeyedLiterals

Buf []byte // input buffer
Options UnmarshalOptions
Buf []byte // input buffer
}

// UnmarshalOutput is output from the unmarshaler.
Expand All @@ -95,8 +91,8 @@ type UnmarshalOutput = struct {
type UnmarshalOptions = struct {
pragma.NoUnkeyedLiterals

Merge bool // must be treated as true by method implementations
AllowPartial bool // must be treated as true by method implementations
Merge bool // may be treated as true by method implementations
AllowPartial bool // may be treated as true by method implementations
DiscardUnknown bool
Resolver interface {
FindExtensionByName(field protoreflect.FullName) (protoreflect.ExtensionType, error)
Expand Down

0 comments on commit d30e561

Please sign in to comment.