diff --git a/cmd/protoc-gen-go-grpc/testdata/go.mod b/cmd/protoc-gen-go-grpc/testdata/go.mod index 67f6a1cf6..2b2c71693 100644 --- a/cmd/protoc-gen-go-grpc/testdata/go.mod +++ b/cmd/protoc-gen-go-grpc/testdata/go.mod @@ -1,7 +1,7 @@ module google.golang.org/protobuf/cmd/protoc-gen-go-grpc/testdata require ( - github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4 + github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0 google.golang.org/grpc v1.19.0 google.golang.org/protobuf v1.0.0 ) diff --git a/cmd/protoc-gen-go/testdata/go.mod b/cmd/protoc-gen-go/testdata/go.mod index fbf038f9a..2413632e3 100644 --- a/cmd/protoc-gen-go/testdata/go.mod +++ b/cmd/protoc-gen-go/testdata/go.mod @@ -1,7 +1,7 @@ module google.golang.org/protobuf/cmd/protoc-gen-go/testdata require ( - github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4 + github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0 google.golang.org/protobuf v1.0.0 ) diff --git a/encoding/protojson/decode.go b/encoding/protojson/decode.go index 7c0da2384..1d22d3a40 100644 --- a/encoding/protojson/decode.go +++ b/encoding/protojson/decode.go @@ -182,7 +182,9 @@ func (o UnmarshalOptions) unmarshalFields(m pref.Message, skipTypeURL bool) erro if err != nil && err != protoregistry.NotFound { return errors.New("unable to resolve [%v]: %v", extName, err) } - fd = extType + if extType != nil { + fd = extType.Descriptor() + } } else { // The name can either be the JSON name or the proto field name. fd = fieldDescs.ByJSONName(name) diff --git a/encoding/protojson/decode_test.go b/encoding/protojson/decode_test.go index 2385dc4c2..9c4f93e61 100644 --- a/encoding/protojson/decode_test.go +++ b/encoding/protojson/decode_test.go @@ -1200,10 +1200,10 @@ func TestUnmarshal(t *testing.T) { OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_OptExtBool, true) - setExtension(m, pb2.E_OptExtString, "extension field") - setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_OptExtBool, true) + proto.SetExtension(m, pb2.E_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -1225,9 +1225,9 @@ func TestUnmarshal(t *testing.T) { }`, wantMessage: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) - setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) + proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1250,10 +1250,10 @@ func TestUnmarshal(t *testing.T) { }`, wantMessage: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) - setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") - setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -1282,9 +1282,9 @@ func TestUnmarshal(t *testing.T) { OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1323,13 +1323,13 @@ func TestUnmarshal(t *testing.T) { }`, wantMessage: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("not a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ OptString: proto.String("just a regular extension"), }) return m @@ -1345,7 +1345,7 @@ func TestUnmarshal(t *testing.T) { }`, wantMessage: func() proto.Message { m := &pb2.FakeMessageSet{} - setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("not a messageset extension"), }) return m @@ -1371,7 +1371,7 @@ func TestUnmarshal(t *testing.T) { }`, wantMessage: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("another not a messageset extension"), }) return m @@ -2402,7 +2402,7 @@ func TestUnmarshal(t *testing.T) { }`, wantMessage: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_OptExtNested, &pb2.Nested{}) + proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{}) return m }(), }, { diff --git a/encoding/protojson/encode_test.go b/encoding/protojson/encode_test.go index 2aa61dd6f..5cacf9ecd 100644 --- a/encoding/protojson/encode_test.go +++ b/encoding/protojson/encode_test.go @@ -16,7 +16,6 @@ import ( pimpl "google.golang.org/protobuf/internal/impl" "google.golang.org/protobuf/proto" preg "google.golang.org/protobuf/reflect/protoregistry" - "google.golang.org/protobuf/runtime/protoiface" "google.golang.org/protobuf/encoding/testprotos/pb2" "google.golang.org/protobuf/encoding/testprotos/pb3" @@ -29,11 +28,6 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" ) -// TODO: Replace this with proto.SetExtension. -func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) { - m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val)) -} - func TestMarshal(t *testing.T) { tests := []struct { desc string @@ -886,10 +880,10 @@ func TestMarshal(t *testing.T) { OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_OptExtBool, true) - setExtension(m, pb2.E_OptExtString, "extension field") - setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_OptExtBool, true) + proto.SetExtension(m, pb2.E_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -915,9 +909,9 @@ func TestMarshal(t *testing.T) { desc: "extensions of repeated fields", input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) - setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) + proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -950,10 +944,10 @@ func TestMarshal(t *testing.T) { desc: "extensions of non-repeated fields in another message", input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) - setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") - setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -980,9 +974,9 @@ func TestMarshal(t *testing.T) { OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1018,13 +1012,13 @@ func TestMarshal(t *testing.T) { desc: "MessageSet", input: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("not a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ OptString: proto.String("just a regular extension"), }) return m @@ -1045,7 +1039,7 @@ func TestMarshal(t *testing.T) { desc: "not real MessageSet 1", input: func() proto.Message { m := &pb2.FakeMessageSet{} - setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("not a messageset extension"), }) return m @@ -1060,7 +1054,7 @@ func TestMarshal(t *testing.T) { desc: "not real MessageSet 2", input: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("another not a messageset extension"), }) return m diff --git a/encoding/protojson/well_known_types.go b/encoding/protojson/well_known_types.go index e744536d0..77833d84c 100644 --- a/encoding/protojson/well_known_types.go +++ b/encoding/protojson/well_known_types.go @@ -189,7 +189,7 @@ func (o MarshalOptions) marshalAny(m pref.Message) error { // If type of value has custom JSON encoding, marshal out a field "value" // with corresponding custom JSON encoding of the embedded message as a // field. - if isCustomType(emt.FullName()) { + if isCustomType(emt.Descriptor().FullName()) { o.encoder.WriteName("value") return o.marshalCustomType(em) } @@ -235,7 +235,7 @@ func (o UnmarshalOptions) unmarshalAny(m pref.Message) error { // Create new message for the embedded message type and unmarshal into it. em := emt.New() - if isCustomType(emt.FullName()) { + if isCustomType(emt.Descriptor().FullName()) { // If embedded message is a custom type, unmarshal the JSON "value" field // into it. if err := o.unmarshalAnyValue(em); err != nil { diff --git a/encoding/prototext/decode.go b/encoding/prototext/decode.go index 4f384a0ac..06388c81e 100644 --- a/encoding/prototext/decode.go +++ b/encoding/prototext/decode.go @@ -126,7 +126,9 @@ func (o UnmarshalOptions) unmarshalMessage(tmsg [][2]text.Value, m pref.Message) if err != nil && err != protoregistry.NotFound { return errors.New("unable to resolve [%v]: %v", extName, err) } - fd = xt + if xt != nil { + fd = xt.Descriptor() + } } if fd == nil { diff --git a/encoding/prototext/decode_test.go b/encoding/prototext/decode_test.go index 20ce1333f..588d9ee79 100644 --- a/encoding/prototext/decode_test.go +++ b/encoding/prototext/decode_test.go @@ -1171,10 +1171,10 @@ opt_int32: 42 OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_OptExtBool, true) - setExtension(m, pb2.E_OptExtString, "extension field") - setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_OptExtBool, true) + proto.SetExtension(m, pb2.E_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -1207,9 +1207,9 @@ opt_int32: 42 `, wantMessage: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) - setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) + proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1231,10 +1231,10 @@ opt_int32: 42 `, wantMessage: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) - setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") - setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -1269,9 +1269,9 @@ opt_int32: 42 OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1299,13 +1299,13 @@ opt_int32: 42 `, wantMessage: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("not a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ OptString: proto.String("just a regular extension"), }) return m @@ -1321,7 +1321,7 @@ opt_int32: 42 `, wantMessage: func() proto.Message { m := &pb2.FakeMessageSet{} - setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("not a messageset extension"), }) return m @@ -1346,7 +1346,7 @@ opt_int32: 42 }`, wantMessage: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("another not a messageset extension"), }) return m diff --git a/encoding/prototext/encode_test.go b/encoding/prototext/encode_test.go index c29169b40..493a22961 100644 --- a/encoding/prototext/encode_test.go +++ b/encoding/prototext/encode_test.go @@ -16,7 +16,6 @@ import ( pimpl "google.golang.org/protobuf/internal/impl" "google.golang.org/protobuf/proto" preg "google.golang.org/protobuf/reflect/protoregistry" - "google.golang.org/protobuf/runtime/protoiface" "google.golang.org/protobuf/encoding/testprotos/pb2" "google.golang.org/protobuf/encoding/testprotos/pb3" @@ -28,11 +27,6 @@ func init() { detrand.Disable() } -// TODO: Use proto.SetExtension when available. -func setExtension(m proto.Message, xd *protoiface.ExtensionDescV1, val interface{}) { - m.ProtoReflect().Set(xd.Type, xd.Type.ValueOf(val)) -} - func TestMarshal(t *testing.T) { tests := []struct { desc string @@ -905,10 +899,10 @@ req_nested: {} OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_OptExtBool, true) - setExtension(m, pb2.E_OptExtString, "extension field") - setExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_OptExtBool, true) + proto.SetExtension(m, pb2.E_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -933,7 +927,7 @@ opt_int32: 42 desc: "extension field contains invalid UTF-8", input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_OptExtString, "abc\xff") + proto.SetExtension(m, pb2.E_OptExtString, "abc\xff") return m }(), wantErr: true, @@ -941,10 +935,10 @@ opt_int32: 42 desc: "extension partial returns error", input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{ + proto.SetExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{ OptString: proto.String("partial1"), }) - setExtension(m, pb2.E_ExtensionsContainer_OptExtPartial, &pb2.PartialRequired{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtPartial, &pb2.PartialRequired{ OptString: proto.String("partial2"), }) return m @@ -962,7 +956,7 @@ opt_int32: 42 mo: prototext.MarshalOptions{AllowPartial: true}, input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{ + proto.SetExtension(m, pb2.E_OptExtPartial, &pb2.PartialRequired{ OptString: proto.String("partial1"), }) return m @@ -975,9 +969,9 @@ opt_int32: 42 desc: "extensions of repeated fields", input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) - setExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_RptExtFixed32, &[]uint32{42, 47}) + proto.SetExtension(m, pb2.E_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1003,10 +997,10 @@ opt_int32: 42 desc: "extensions of non-repeated fields in another message", input: func() proto.Message { m := &pb2.Extensions{} - setExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) - setExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") - setExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) - setExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtBool, true) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtString, "extension field") + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtEnum, pb2.Enum_TEN) + proto.SetExtension(m, pb2.E_ExtensionsContainer_OptExtNested, &pb2.Nested{ OptString: proto.String("nested in an extension"), OptNested: &pb2.Nested{ OptString: proto.String("another nested in an extension"), @@ -1032,9 +1026,9 @@ opt_int32: 42 OptBool: proto.Bool(true), OptInt32: proto.Int32(42), } - setExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) - setExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtEnum, &[]pb2.Enum{pb2.Enum_TEN, 101, pb2.Enum_ONE}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtString, &[]string{"hello", "world"}) + proto.SetExtension(m, pb2.E_ExtensionsContainer_RptExtNested, &[]*pb2.Nested{ &pb2.Nested{OptString: proto.String("one")}, &pb2.Nested{OptString: proto.String("two")}, &pb2.Nested{OptString: proto.String("three")}, @@ -1063,13 +1057,13 @@ opt_int32: 42 desc: "MessageSet", input: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_MessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension_NotMessageSetExtension, &pb2.MessageSetExtension{ OptString: proto.String("not a messageset extension"), }) - setExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ + proto.SetExtension(m, pb2.E_MessageSetExtension_ExtNested, &pb2.Nested{ OptString: proto.String("just a regular extension"), }) return m @@ -1089,7 +1083,7 @@ opt_int32: 42 desc: "not real MessageSet 1", input: func() proto.Message { m := &pb2.FakeMessageSet{} - setExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_FakeMessageSetExtension_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("not a messageset extension"), }) return m @@ -1103,7 +1097,7 @@ opt_int32: 42 desc: "not real MessageSet 2", input: func() proto.Message { m := &pb2.MessageSet{} - setExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ + proto.SetExtension(m, pb2.E_MessageSetExtension, &pb2.FakeMessageSetExtension{ OptString: proto.String("another not a messageset extension"), }) return m diff --git a/go.mod b/go.mod index d9c103435..c29b07470 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,6 @@ module google.golang.org/protobuf go 1.9 require ( - github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4 + github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0 github.com/google/go-cmp v0.3.0 ) diff --git a/go.sum b/go.sum index 274f450f9..de16d8151 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,8 @@ github.com/golang/protobuf v1.2.1-0.20190523175523-a1331f0b4ab4/go.mod h1:G+fNMo github.com/golang/protobuf v1.2.1-0.20190605195750-76c9e09470ba/go.mod h1:S1YIJXvYHGRCG2UmZsOcElkAYfvZLg2sDRr9+Xu8JXU= github.com/golang/protobuf v1.2.1-0.20190617175902-f94016f5239f/go.mod h1:G+HpKX7pYZAVkElkAWZkr08MToW6pTp/vs+E9osFfbg= github.com/golang/protobuf v1.2.1-0.20190620192300-1ee46dfd80dd/go.mod h1:+CMAsi9jpYf/wAltLUKlg++CWXqxCJyD8iLDbQONsJs= -github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4 h1:Hj8cGYPgLw3MR0AGL0GFObh4pq8i31QOWWMCE0KY9z4= -github.com/golang/protobuf v1.2.1-0.20190717234224-b9f5089fb9d4/go.mod h1:tDQPRlaHYu9yt1wPgdx85inRiLvUCuJZXsYjC0mwc1c= +github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0 h1:a3hJDGxxWRbPxfOMiV6aG8pb0I+8RdgICRdXjXjiKzY= +github.com/golang/protobuf v1.2.1-0.20190806214225-7037721e6de0/go.mod h1:tDQPRlaHYu9yt1wPgdx85inRiLvUCuJZXsYjC0mwc1c= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= google.golang.org/protobuf v0.0.0-20190514172829-e89e6244e0e8/go.mod h1:791zQGC15vDqjpmPRn1uGPu5oHy/Jzw/Q1n5JsgIIcY= diff --git a/internal/encoding/messageset/messageset.go b/internal/encoding/messageset/messageset.go index 7d05dee39..b1c6db59c 100644 --- a/internal/encoding/messageset/messageset.go +++ b/internal/encoding/messageset/messageset.go @@ -65,7 +65,7 @@ func FindMessageSetExtension(r preg.ExtensionTypeResolver, s pref.FullName) (pre if err != nil { return nil, err } - if !IsMessageSetExtension(xt) { + if !IsMessageSetExtension(xt.Descriptor()) { return nil, preg.NotFound } return xt, nil diff --git a/internal/filetype/build.go b/internal/filetype/build.go index f30f98c36..15c8f278e 100644 --- a/internal/filetype/build.go +++ b/internal/filetype/build.go @@ -7,11 +7,9 @@ package filetype import ( - "fmt" "reflect" "sync" - "google.golang.org/protobuf/internal/descfmt" "google.golang.org/protobuf/internal/descopts" fdesc "google.golang.org/protobuf/internal/filedesc" pimpl "google.golang.org/protobuf/internal/impl" @@ -358,8 +356,7 @@ func (t *Extension) GoType() reflect.Type { t.lazyInit() return t.goType } -func (t *Extension) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor } -func (t *Extension) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, t) } +func (t *Extension) Descriptor() pref.ExtensionTypeDescriptor { return (*extDesc)(t) } // ProtoLegacyExtensionDesc is a pseudo-internal API for allowing the v1 code // to be able to retrieve a v1 ExtensionDesc. @@ -379,3 +376,8 @@ func (t *Extension) lazyInit() pimpl.Converter { }) return t.conv } + +type extDesc Extension + +func (t *extDesc) Type() pref.ExtensionType { return (*Extension)(t) } +func (t *extDesc) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor } diff --git a/internal/impl/codec_extension.go b/internal/impl/codec_extension.go index 3a23bb0de..7e430ee29 100644 --- a/internal/impl/codec_extension.go +++ b/internal/impl/codec_extension.go @@ -29,25 +29,26 @@ func (mi *MessageInfo) extensionFieldInfo(xt pref.ExtensionType) *extensionField return e } + xd := xt.Descriptor() var wiretag uint64 - if !xt.IsPacked() { - wiretag = wire.EncodeTag(xt.Number(), wireTypes[xt.Kind()]) + if !xd.IsPacked() { + wiretag = wire.EncodeTag(xd.Number(), wireTypes[xd.Kind()]) } else { - wiretag = wire.EncodeTag(xt.Number(), wire.BytesType) + wiretag = wire.EncodeTag(xd.Number(), wire.BytesType) } e = &extensionFieldInfo{ wiretag: wiretag, tagsize: wire.SizeVarint(wiretag), - funcs: encoderFuncsForValue(xt, xt.GoType()), + funcs: encoderFuncsForValue(xd, xt.GoType()), } // Does the unmarshal function need a value passed to it? // This is true for composite types, where we pass in a message, list, or map to fill in, // and for enums, where we pass in a prototype value to specify the concrete enum type. - switch xt.Kind() { + switch xd.Kind() { case pref.MessageKind, pref.GroupKind, pref.EnumKind: e.unmarshalNeedsValue = true default: - if xt.Cardinality() == pref.Repeated { + if xd.Cardinality() == pref.Repeated { e.unmarshalNeedsValue = true } } diff --git a/internal/impl/codec_message.go b/internal/impl/codec_message.go index e43c81293..70148617e 100644 --- a/internal/impl/codec_message.go +++ b/internal/impl/codec_message.go @@ -44,8 +44,9 @@ func (mi *MessageInfo) makeMethods(t reflect.Type, si structInfo) { mi.extensionOffset = si.extensionOffset mi.coderFields = make(map[wire.Number]*coderFieldInfo) - for i := 0; i < mi.PBType.Fields().Len(); i++ { - fd := mi.PBType.Fields().Get(i) + fields := mi.PBType.Descriptor().Fields() + for i := 0; i < fields.Len(); i++ { + fd := fields.Get(i) fs := si.fieldsByNumber[fd.Number()] if fd.ContainingOneof() != nil { @@ -81,7 +82,7 @@ func (mi *MessageInfo) makeMethods(t reflect.Type, si structInfo) { } if messageset.IsMessageSet(mi.PBType.Descriptor()) { if !mi.extensionOffset.IsValid() { - panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.FullName())) + panic(fmt.Sprintf("%v: MessageSet with no extensions field", mi.PBType.Descriptor().FullName())) } cf := &coderFieldInfo{ num: messageset.FieldItem, @@ -113,7 +114,7 @@ func (mi *MessageInfo) makeMethods(t reflect.Type, si structInfo) { mi.denseCoderFields[cf.num] = cf } - mi.needsInitCheck = needsInitCheck(mi.PBType) + mi.needsInitCheck = needsInitCheck(mi.PBType.Descriptor()) mi.methods = piface.Methods{ Flags: piface.SupportMarshalDeterministic | piface.SupportUnmarshalDiscardUnknown, MarshalAppend: mi.marshalAppend, diff --git a/internal/impl/decode.go b/internal/impl/decode.go index 4e4c7f337..48218529c 100644 --- a/internal/impl/decode.go +++ b/internal/impl/decode.go @@ -138,7 +138,7 @@ func (mi *MessageInfo) unmarshalExtension(b []byte, num wire.Number, wtyp wire.T xt := x.GetType() if xt == nil { var err error - xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.FullName(), num) + xt, err = opts.Resolver().FindExtensionByNumber(mi.PBType.Descriptor().FullName(), num) if err != nil { if err == preg.NotFound { return 0, errUnknown diff --git a/internal/impl/isinit.go b/internal/impl/isinit.go index ca000122e..079afe075 100644 --- a/internal/impl/isinit.go +++ b/internal/impl/isinit.go @@ -29,7 +29,7 @@ func (mi *MessageInfo) isInitializedPointer(p pointer) error { if p.IsNil() { for _, f := range mi.orderedCoderFields { if f.isRequired { - return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName())) + return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName())) } } return nil @@ -47,7 +47,7 @@ func (mi *MessageInfo) isInitializedPointer(p pointer) error { fptr := p.Apply(f.offset) if f.isPointer && fptr.Elem().IsNil() { if f.isRequired { - return errors.RequiredNotSet(string(mi.PBType.Fields().ByNumber(f.num).FullName())) + return errors.RequiredNotSet(string(mi.PBType.Descriptor().Fields().ByNumber(f.num).FullName())) } continue } diff --git a/internal/impl/legacy_extension.go b/internal/impl/legacy_extension.go index aaf8fcce0..2da4d71b5 100644 --- a/internal/impl/legacy_extension.go +++ b/internal/impl/legacy_extension.go @@ -5,11 +5,9 @@ package impl import ( - "fmt" "reflect" "sync" - "google.golang.org/protobuf/internal/descfmt" ptag "google.golang.org/protobuf/internal/encoding/tag" "google.golang.org/protobuf/internal/filedesc" pref "google.golang.org/protobuf/reflect/protoreflect" @@ -62,8 +60,9 @@ func legacyExtensionDescFromType(xt pref.ExtensionType) *piface.ExtensionDescV1 } // Determine the parent type if possible. + xd := xt.Descriptor() var parent piface.MessageV1 - messageName := xt.ContainingMessage().FullName() + messageName := xd.ContainingMessage().FullName() if mt, _ := preg.GlobalTypes.FindMessageByName(messageName); mt != nil { // Create a new parent message and unwrap it if possible. mv := mt.New().Interface() @@ -94,7 +93,7 @@ func legacyExtensionDescFromType(xt pref.ExtensionType) *piface.ExtensionDescV1 // Reconstruct the legacy enum full name, which is an odd mixture of the // proto package name with the Go type name. var enumName string - if xt.Kind() == pref.EnumKind { + if xd.Kind() == pref.EnumKind { // Derive Go type name. t := extType if t.Kind() == reflect.Ptr || t.Kind() == reflect.Slice { @@ -105,7 +104,7 @@ func legacyExtensionDescFromType(xt pref.ExtensionType) *piface.ExtensionDescV1 // Derive the proto package name. // For legacy enums, obtain the proto package from the raw descriptor. var protoPkg string - if fd := xt.Enum().ParentFile(); fd != nil { + if fd := xd.Enum().ParentFile(); fd != nil { protoPkg = string(fd.Package()) } if ed, ok := reflect.Zero(t).Interface().(enumV1); ok && protoPkg == "" { @@ -120,7 +119,7 @@ func legacyExtensionDescFromType(xt pref.ExtensionType) *piface.ExtensionDescV1 // Derive the proto file that the extension was declared within. var filename string - if fd := xt.ParentFile(); fd != nil { + if fd := xd.ParentFile(); fd != nil { filename = fd.Path() } @@ -129,9 +128,9 @@ func legacyExtensionDescFromType(xt pref.ExtensionType) *piface.ExtensionDescV1 Type: xt, ExtendedType: parent, ExtensionType: reflect.Zero(extType).Interface(), - Field: int32(xt.Number()), - Name: string(xt.FullName()), - Tag: ptag.Marshal(xt, enumName), + Field: int32(xd.Number()), + Name: string(xd.FullName()), + Tag: ptag.Marshal(xd, enumName), Filename: filename, } if d, ok := legacyExtensionDescCache.LoadOrStore(xt, d); ok { @@ -217,15 +216,16 @@ func legacyExtensionTypeFromDesc(d *piface.ExtensionDescV1) pref.ExtensionType { // // This is exported for testing purposes. func LegacyExtensionTypeOf(xd pref.ExtensionDescriptor, t reflect.Type) pref.ExtensionType { - return &legacyExtensionType{ - ExtensionDescriptor: xd, - typ: t, - conv: NewConverter(t, xd), + xt := &legacyExtensionType{ + typ: t, + conv: NewConverter(t, xd), } + xt.desc = &extDesc{xd, xt} + return xt } type legacyExtensionType struct { - pref.ExtensionDescriptor + desc pref.ExtensionTypeDescriptor typ reflect.Type conv Converter } @@ -239,5 +239,12 @@ func (x *legacyExtensionType) ValueOf(v interface{}) pref.Value { func (x *legacyExtensionType) InterfaceOf(v pref.Value) interface{} { return x.conv.GoValueOf(v).Interface() } -func (x *legacyExtensionType) Descriptor() pref.ExtensionDescriptor { return x.ExtensionDescriptor } -func (x *legacyExtensionType) Format(s fmt.State, r rune) { descfmt.FormatDesc(s, r, x) } +func (x *legacyExtensionType) Descriptor() pref.ExtensionTypeDescriptor { return x.desc } + +type extDesc struct { + pref.ExtensionDescriptor + xt *legacyExtensionType +} + +func (t *extDesc) Type() pref.ExtensionType { return t.xt } +func (t *extDesc) Descriptor() pref.ExtensionDescriptor { return t.ExtensionDescriptor } diff --git a/internal/impl/legacy_test.go b/internal/impl/legacy_test.go index 70c56038c..89cd0bc7a 100644 --- a/internal/impl/legacy_test.go +++ b/internal/impl/legacy_test.go @@ -52,7 +52,7 @@ var legacyFD = func() []byte { func init() { mt := pimpl.Export{}.MessageTypeOf((*LegacyTestMessage)(nil)) - preg.GlobalFiles.Register(mt.ParentFile()) + preg.GlobalFiles.Register(mt.Descriptor().ParentFile()) preg.GlobalTypes.Register(mt) } @@ -357,19 +357,21 @@ func TestLegacyExtensions(t *testing.T) { } for i, xt := range extensionTypes { var got interface{} - if !(xt.IsList() || xt.IsMap() || xt.Message() != nil) { - got = xt.InterfaceOf(m.Get(xt)) + xd := xt.Descriptor() + if !(xd.IsList() || xd.IsMap() || xd.Message() != nil) { + got = xt.InterfaceOf(m.Get(xd)) } want := defaultValues[i] if diff := cmp.Diff(want, got, opts); diff != "" { - t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff) + t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff) } } // All fields should be unpopulated. for _, xt := range extensionTypes { - if m.Has(xt) { - t.Errorf("Message.Has(%d) = true, want false", xt.Number()) + xd := xt.Descriptor() + if m.Has(xd) { + t.Errorf("Message.Has(%d) = true, want false", xd.Number()) } } @@ -401,11 +403,11 @@ func TestLegacyExtensions(t *testing.T) { 19: &[]*EnumMessages{m2b}, } for i, xt := range extensionTypes { - m.Set(xt, xt.ValueOf(setValues[i])) + m.Set(xt.Descriptor(), xt.ValueOf(setValues[i])) } for i, xt := range extensionTypes[len(extensionTypes)/2:] { v := extensionTypes[i].ValueOf(setValues[i]) - m.Get(xt).List().Append(v) + m.Get(xt.Descriptor()).List().Append(v) } // Get the values and check for equality. @@ -432,24 +434,25 @@ func TestLegacyExtensions(t *testing.T) { 19: &[]*EnumMessages{m2b, m2a}, } for i, xt := range extensionTypes { - got := xt.InterfaceOf(m.Get(xt)) + xd := xt.Descriptor() + got := xt.InterfaceOf(m.Get(xd)) want := getValues[i] if diff := cmp.Diff(want, got, opts); diff != "" { - t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xt.Number(), diff) + t.Errorf("Message.Get(%d) mismatch (-want +got):\n%v", xd.Number(), diff) } } // Clear all singular fields and truncate all repeated fields. for _, xt := range extensionTypes[:len(extensionTypes)/2] { - m.Clear(xt) + m.Clear(xt.Descriptor()) } for _, xt := range extensionTypes[len(extensionTypes)/2:] { - m.Get(xt).List().Truncate(0) + m.Get(xt.Descriptor()).List().Truncate(0) } // Clear all repeated fields. for _, xt := range extensionTypes[len(extensionTypes)/2:] { - m.Clear(xt) + m.Clear(xt.Descriptor()) } } @@ -491,8 +494,6 @@ func TestExtensionConvert(t *testing.T) { switch name { case "ParentFile", "Parent": // Ignore parents to avoid recursive cycle. - case "New", "Zero": - // Ignore constructors. case "Options": // Ignore descriptor options since protos are not cmperable. case "ContainingOneof", "ContainingMessage", "Enum", "Message": @@ -504,6 +505,8 @@ func TestExtensionConvert(t *testing.T) { if !v.IsNil() { out[name] = v.Interface().(pref.Descriptor).FullName() } + case "Type": + // Ignore ExtensionTypeDescriptor.Type method to avoid cycle. default: out[name] = m.Call(nil)[0].Interface() } @@ -511,6 +514,12 @@ func TestExtensionConvert(t *testing.T) { } return out }), + cmp.Transformer("", func(xt pref.ExtensionType) map[string]interface{} { + return map[string]interface{}{ + "Descriptor": xt.Descriptor(), + "GoType": xt.GoType(), + } + }), cmp.Transformer("", func(v pref.Value) interface{} { return v.Interface() }), @@ -605,23 +614,23 @@ func TestConcurrentInit(t *testing.T) { var ( wantMTA = messageATypes[0] - wantMDA = messageATypes[0].Fields().ByNumber(1).Message() + wantMDA = messageATypes[0].Descriptor().Fields().ByNumber(1).Message() wantMTB = messageBTypes[0] - wantMDB = messageBTypes[0].Fields().ByNumber(2).Message() - wantED = messageATypes[0].Fields().ByNumber(3).Enum() + wantMDB = messageBTypes[0].Descriptor().Fields().ByNumber(2).Message() + wantED = messageATypes[0].Descriptor().Fields().ByNumber(3).Enum() ) for _, gotMT := range messageATypes[1:] { if gotMT != wantMTA { t.Error("MessageType(MessageA) mismatch") } - if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA { + if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA { t.Error("MessageDescriptor(MessageA) mismatch") } - if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB { + if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB { t.Error("MessageDescriptor(MessageB) mismatch") } - if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED { + if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED { t.Error("EnumDescriptor(Enum) mismatch") } } @@ -629,13 +638,13 @@ func TestConcurrentInit(t *testing.T) { if gotMT != wantMTB { t.Error("MessageType(MessageB) mismatch") } - if gotMDA := gotMT.Fields().ByNumber(1).Message(); gotMDA != wantMDA { + if gotMDA := gotMT.Descriptor().Fields().ByNumber(1).Message(); gotMDA != wantMDA { t.Error("MessageDescriptor(MessageA) mismatch") } - if gotMDB := gotMT.Fields().ByNumber(2).Message(); gotMDB != wantMDB { + if gotMDB := gotMT.Descriptor().Fields().ByNumber(2).Message(); gotMDB != wantMDB { t.Error("MessageDescriptor(MessageB) mismatch") } - if gotED := gotMT.Fields().ByNumber(3).Enum(); gotED != wantED { + if gotED := gotMT.Descriptor().Fields().ByNumber(3).Enum(); gotED != wantED { t.Error("EnumDescriptor(Enum) mismatch") } } diff --git a/internal/impl/message.go b/internal/impl/message.go index 305e17df5..6100663ec 100644 --- a/internal/impl/message.go +++ b/internal/impl/message.go @@ -222,8 +222,9 @@ fieldLoop: // any discrepancies. func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) { mi.fields = map[pref.FieldNumber]*fieldInfo{} - for i := 0; i < mi.PBType.Fields().Len(); i++ { - fd := mi.PBType.Fields().Get(i) + md := mi.PBType.Descriptor() + for i := 0; i < md.Fields().Len(); i++ { + fd := md.Fields().Get(i) fs := si.fieldsByNumber[fd.Number()] var fi fieldInfo switch { @@ -244,8 +245,8 @@ func (mi *MessageInfo) makeKnownFieldsFunc(si structInfo) { } mi.oneofs = map[pref.Name]*oneofInfo{} - for i := 0; i < mi.PBType.Oneofs().Len(); i++ { - od := mi.PBType.Oneofs().Get(i) + for i := 0; i < md.Oneofs().Len(); i++ { + od := md.Oneofs().Get(i) mi.oneofs[od.Name()] = makeOneofInfo(od, si.oneofsByName[od.Name()], mi.Exporter, si.oneofWrappersByType) } } diff --git a/internal/impl/message_reflect.go b/internal/impl/message_reflect.go index 699ac2c65..b0f1778b9 100644 --- a/internal/impl/message_reflect.go +++ b/internal/impl/message_reflect.go @@ -121,7 +121,7 @@ func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) { if m != nil { for _, x := range *m { xt := x.GetType() - if !f(xt, xt.ValueOf(x.GetValue())) { + if !f(xt.Descriptor(), xt.ValueOf(x.GetValue())) { return } } @@ -129,16 +129,17 @@ func (m *extensionMap) Range(f func(pref.FieldDescriptor, pref.Value) bool) { } func (m *extensionMap) Has(xt pref.ExtensionType) (ok bool) { if m != nil { - _, ok = (*m)[int32(xt.Number())] + _, ok = (*m)[int32(xt.Descriptor().Number())] } return ok } func (m *extensionMap) Clear(xt pref.ExtensionType) { - delete(*m, int32(xt.Number())) + delete(*m, int32(xt.Descriptor().Number())) } func (m *extensionMap) Get(xt pref.ExtensionType) pref.Value { + xd := xt.Descriptor() if m != nil { - if x, ok := (*m)[int32(xt.Number())]; ok { + if x, ok := (*m)[int32(xd.Number())]; ok { return xt.ValueOf(x.GetValue()) } } @@ -151,13 +152,14 @@ func (m *extensionMap) Set(xt pref.ExtensionType, v pref.Value) { var x ExtensionField x.SetType(xt) x.SetEagerValue(xt.InterfaceOf(v)) - (*m)[int32(xt.Number())] = x + (*m)[int32(xt.Descriptor().Number())] = x } func (m *extensionMap) Mutable(xt pref.ExtensionType) pref.Value { - if !isComposite(xt) { + xd := xt.Descriptor() + if !isComposite(xd) { panic("invalid Mutable on field with non-composite type") } - if x, ok := (*m)[int32(xt.Number())]; ok { + if x, ok := (*m)[int32(xd.Number())]; ok { return xt.ValueOf(x.GetValue()) } v := xt.New() @@ -179,14 +181,18 @@ func (mi *MessageInfo) checkField(fd pref.FieldDescriptor) (*fieldInfo, pref.Ext return fi, nil } if fd.IsExtension() { - if fd.ContainingMessage().FullName() != mi.PBType.FullName() { + if fd.ContainingMessage().FullName() != mi.PBType.Descriptor().FullName() { // TODO: Should this be exact containing message descriptor match? panic("mismatching containing message") } - if !mi.PBType.ExtensionRanges().Has(fd.Number()) { + if !mi.PBType.Descriptor().ExtensionRanges().Has(fd.Number()) { panic("invalid extension field") } - return nil, fd.(pref.ExtensionType) + xtd, ok := fd.(pref.ExtensionTypeDescriptor) + if !ok { + panic("extension descriptor does not implement ExtensionTypeDescriptor") + } + return nil, xtd.Type() } panic("invalid field descriptor") } diff --git a/proto/decode.go b/proto/decode.go index f147e68f6..e39424370 100644 --- a/proto/decode.go +++ b/proto/decode.go @@ -88,7 +88,9 @@ func (o UnmarshalOptions) unmarshalMessageSlow(b []byte, m protoreflect.Message) if err != nil && err != protoregistry.NotFound { return err } - fd = extType + if extType != nil { + fd = extType.Descriptor() + } } var err error var valLen int diff --git a/proto/decode_test.go b/proto/decode_test.go index 6088eb558..89fc30320 100644 --- a/proto/decode_test.go +++ b/proto/decode_test.go @@ -1680,10 +1680,8 @@ func extend(desc *protoiface.ExtensionDescV1, value interface{}) buildOpt { v.Elem().Set(reflect.ValueOf(value)) value = v.Interface() } - return func(m proto.Message) { - xt := desc.Type - m.ProtoReflect().Set(xt, xt.ValueOf(value)) + proto.SetExtension(m, desc, value) } } diff --git a/proto/extension.go b/proto/extension.go new file mode 100644 index 000000000..2e1c78f02 --- /dev/null +++ b/proto/extension.go @@ -0,0 +1,33 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style. +// license that can be found in the LICENSE file. + +package proto + +import ( + "google.golang.org/protobuf/reflect/protoreflect" +) + +// HasExtension reports whether an extension field is populated. +func HasExtension(m Message, ext protoreflect.ExtensionType) bool { + return m.ProtoReflect().Has(ext.Descriptor()) +} + +// ClearExtension clears an extension field such that subsequent +// HasExtension calls return false. +func ClearExtension(m Message, ext protoreflect.ExtensionType) { + m.ProtoReflect().Clear(ext.Descriptor()) +} + +// GetExtension retrieves the value for an extension field. +// +// If the field is unpopulated, it returns the default value for +// scalars and an immutable, empty value for lists, maps, or messages. +func GetExtension(m Message, ext protoreflect.ExtensionType) interface{} { + return ext.InterfaceOf(m.ProtoReflect().Get(ext.Descriptor())) +} + +// SetExtension stores the value of an extension field. +func SetExtension(m Message, ext protoreflect.ExtensionType, value interface{}) { + m.ProtoReflect().Set(ext.Descriptor(), ext.ValueOf(value)) +} diff --git a/proto/extension_test.go b/proto/extension_test.go new file mode 100644 index 000000000..ce3a14266 --- /dev/null +++ b/proto/extension_test.go @@ -0,0 +1,70 @@ +// Copyright 2019 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style. +// license that can be found in the LICENSE file. + +package proto_test + +import ( + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "google.golang.org/protobuf/proto" + pref "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/runtime/protoimpl" + + legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5" + testpb "google.golang.org/protobuf/internal/testprotos/test" +) + +func TestExtensionFuncs(t *testing.T) { + for _, test := range []struct { + message proto.Message + ext pref.ExtensionType + wantDefault interface{} + value interface{} + }{ + { + message: &testpb.TestAllExtensions{}, + ext: testpb.E_OptionalInt32Extension, + wantDefault: int32(0), + value: int32(1), + }, + { + message: &testpb.TestAllExtensions{}, + ext: testpb.E_RepeatedStringExtension, + // TODO: Represent repeated extension fields as []T. + // https://github.com/golang/protobuf/issues/901 + wantDefault: (*[]string)(nil), + value: &[]string{"a", "b", "c"}, + }, + { + message: protoimpl.X.MessageOf(&legacy1pb.Message{}).Interface(), + ext: legacy1pb.E_Message_ExtensionOptionalBool, + wantDefault: false, + value: true, + }, + } { + desc := fmt.Sprintf("Extension %v, value %v", test.ext.Descriptor().FullName(), test.value) + if proto.HasExtension(test.message, test.ext) { + t.Errorf("%v:\nbefore setting extension HasExtension(...) = true, want false", desc) + } + got := proto.GetExtension(test.message, test.ext) + if d := cmp.Diff(test.wantDefault, got); d != "" { + t.Errorf("%v:\nbefore setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d) + } + proto.SetExtension(test.message, test.ext, test.value) + if !proto.HasExtension(test.message, test.ext) { + t.Errorf("%v:\nafter setting extension HasExtension(...) = false, want true", desc) + } + got = proto.GetExtension(test.message, test.ext) + if d := cmp.Diff(test.value, got); d != "" { + t.Errorf("%v:\nafter setting extension GetExtension(...) returns unexpected value (-want,+got):\n%v", desc, d) + } + proto.ClearExtension(test.message, test.ext) + if proto.HasExtension(test.message, test.ext) { + t.Errorf("%v:\nafter clearing extension HasExtension(...) = true, want false", desc) + } + + } +} diff --git a/proto/merge_test.go b/proto/merge_test.go index 4fa161ef2..a0ec57127 100644 --- a/proto/merge_test.go +++ b/proto/merge_test.go @@ -274,65 +274,41 @@ func TestMerge(t *testing.T) { desc: "merge extension fields", dst: func() proto.Message { m := new(testpb.TestAllExtensions) - m.ProtoReflect().Set( - testpb.E_OptionalInt32Extension.Type, - testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)), - ) - m.ProtoReflect().Set( - testpb.E_OptionalNestedMessageExtension.Type, - testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{ + proto.SetExtension(m, testpb.E_OptionalInt32Extension.Type, int32(32)) + proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension.Type, + &testpb.TestAllTypes_NestedMessage{ A: proto.Int32(50), - }), - ) - m.ProtoReflect().Set( - testpb.E_RepeatedFixed32Extension.Type, - testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3}), + }, ) + proto.SetExtension(m, testpb.E_RepeatedFixed32Extension.Type, &[]uint32{1, 2, 3}) return m }(), src: func() proto.Message { m := new(testpb.TestAllExtensions) - m.ProtoReflect().Set( - testpb.E_OptionalInt64Extension.Type, - testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)), - ) - m.ProtoReflect().Set( - testpb.E_OptionalNestedMessageExtension.Type, - testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{ + proto.SetExtension(m, testpb.E_OptionalInt64Extension.Type, int64(64)) + proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension.Type, + &testpb.TestAllTypes_NestedMessage{ Corecursive: &testpb.TestAllTypes{ OptionalInt64: proto.Int64(1000), }, - }), - ) - m.ProtoReflect().Set( - testpb.E_RepeatedFixed32Extension.Type, - testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{4, 5, 6}), + }, ) + proto.SetExtension(m, testpb.E_RepeatedFixed32Extension.Type, &[]uint32{4, 5, 6}) return m }(), want: func() proto.Message { m := new(testpb.TestAllExtensions) - m.ProtoReflect().Set( - testpb.E_OptionalInt32Extension.Type, - testpb.E_OptionalInt32Extension.Type.ValueOf(int32(32)), - ) - m.ProtoReflect().Set( - testpb.E_OptionalInt64Extension.Type, - testpb.E_OptionalInt64Extension.Type.ValueOf(int64(64)), - ) - m.ProtoReflect().Set( - testpb.E_OptionalNestedMessageExtension.Type, - testpb.E_OptionalNestedMessageExtension.Type.ValueOf(&testpb.TestAllTypes_NestedMessage{ + proto.SetExtension(m, testpb.E_OptionalInt32Extension.Type, int32(32)) + proto.SetExtension(m, testpb.E_OptionalInt64Extension.Type, int64(64)) + proto.SetExtension(m, testpb.E_OptionalNestedMessageExtension.Type, + &testpb.TestAllTypes_NestedMessage{ A: proto.Int32(50), Corecursive: &testpb.TestAllTypes{ OptionalInt64: proto.Int64(1000), }, - }), - ) - m.ProtoReflect().Set( - testpb.E_RepeatedFixed32Extension.Type, - testpb.E_RepeatedFixed32Extension.Type.ValueOf(&[]uint32{1, 2, 3, 4, 5, 6}), + }, ) + proto.SetExtension(m, testpb.E_RepeatedFixed32Extension.Type, &[]uint32{1, 2, 3, 4, 5, 6}) return m }(), }, { diff --git a/proto/messageset.go b/proto/messageset.go index 1c6ac299b..e5d4bd5d1 100644 --- a/proto/messageset.go +++ b/proto/messageset.go @@ -71,14 +71,15 @@ func unmarshalMessageSet(b []byte, m protoreflect.Message, o UnmarshalOptions) e if !md.ExtensionRanges().Has(num) { return errUnknown } - fd, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) + xt, err := o.Resolver.FindExtensionByNumber(md.FullName(), num) if err == protoregistry.NotFound { return errUnknown } if err != nil { return err } - if err := o.unmarshalMessage(v, m.Mutable(fd).Message()); err != nil { + xd := xt.Descriptor() + if err := o.unmarshalMessage(v, m.Mutable(xd).Message()); err != nil { // Contents cannot be unmarshaled. return err } diff --git a/reflect/protoreflect/type.go b/reflect/protoreflect/type.go index afd4cffce..92b5750a3 100644 --- a/reflect/protoreflect/type.go +++ b/reflect/protoreflect/type.go @@ -229,8 +229,6 @@ type isMessageDescriptor interface{ ProtoType(MessageDescriptor) } // MessageType encapsulates a MessageDescriptor with a concrete Go implementation. type MessageType interface { - MessageDescriptor - // New returns a newly allocated empty message. New() Message @@ -401,6 +399,18 @@ type OneofDescriptors interface { // ExtensionDescriptor is an alias of FieldDescriptor for documentation. type ExtensionDescriptor = FieldDescriptor +// ExtensionTypeDescriptor is an ExtensionDescriptor with an associated ExtensionType. +type ExtensionTypeDescriptor interface { + ExtensionDescriptor + + // Type returns the associated ExtensionType. + Type() ExtensionType + + // Descriptor returns the plain ExtensionDescriptor without the + // associated ExtensionType. + Descriptor() ExtensionDescriptor +} + // ExtensionDescriptors is a list of field declarations. type ExtensionDescriptors interface { // Len reports the number of fields. @@ -436,8 +446,6 @@ type ExtensionDescriptors interface { // Field "bar_field" is an extension of FooMessage, but its full name is // "example.BarMessage.bar_field" instead of "example.FooMessage.bar_field". type ExtensionType interface { - ExtensionDescriptor - // New returns a new value for the field. // For scalars, this returns the default value in native Go form. New() Value @@ -454,7 +462,7 @@ type ExtensionType interface { GoType() reflect.Type // Descriptor returns the extension descriptor. - Descriptor() ExtensionDescriptor + Descriptor() ExtensionTypeDescriptor // TODO: What to do with nil? // Should ValueOf(nil) return Value{}? @@ -500,8 +508,6 @@ type isEnumDescriptor interface{ ProtoType(EnumDescriptor) } // EnumType encapsulates an EnumDescriptor with a concrete Go implementation. type EnumType interface { - EnumDescriptor - // New returns an instance of this enum type with its value set to n. New(n EnumNumber) Enum diff --git a/reflect/protoreflect/value.go b/reflect/protoreflect/value.go index 3c9229bc2..ec1009948 100644 --- a/reflect/protoreflect/value.go +++ b/reflect/protoreflect/value.go @@ -30,7 +30,7 @@ type Enum interface { // Accessor/mutators for individual fields are keyed by FieldDescriptor. // For non-extension fields, the descriptor must exactly match the // field known by the parent message. -// For extension fields, the descriptor must implement ExtensionType, +// For extension fields, the descriptor must implement ExtensionTypeDescriptor, // extend the parent message (i.e., have the same message FullName), and // be within the parent's extension range. // diff --git a/reflect/protoregistry/registry.go b/reflect/protoregistry/registry.go index 0b60c64f6..22b5d8c84 100644 --- a/reflect/protoregistry/registry.go +++ b/reflect/protoregistry/registry.go @@ -317,7 +317,6 @@ func rangeTopLevelDescriptors(fd protoreflect.FileDescriptor, f func(protoreflec // Type is an interface satisfied by protoreflect.EnumType, // protoreflect.MessageType, or protoreflect.ExtensionType. type Type interface { - protoreflect.Descriptor GoType() reflect.Type } @@ -428,21 +427,22 @@ typeLoop: switch typ.(type) { case protoreflect.EnumType, protoreflect.MessageType, protoreflect.ExtensionType: // Check for conflicts in typesByName. - var name protoreflect.FullName + var desc protoreflect.Descriptor switch t := typ.(type) { case protoreflect.EnumType: - name = t.FullName() + desc = t.Descriptor() case protoreflect.MessageType: - name = t.FullName() + desc = t.Descriptor() case protoreflect.ExtensionType: - name = t.FullName() + desc = t.Descriptor() default: panic(fmt.Sprintf("invalid type: %T", t)) } + name := desc.FullName() if prev := r.typesByName[name]; prev != nil { err := errors.New("%v %v is already registered", typeName(typ), name) err = amendErrorWithCaller(err, prev, typ) - if r == GlobalTypes && ignoreConflict(typ, err) { + if r == GlobalTypes && ignoreConflict(desc, err) { err = nil } if firstErr == nil { @@ -453,12 +453,13 @@ typeLoop: // Check for conflicts in extensionsByMessage. if xt, _ := typ.(protoreflect.ExtensionType); xt != nil { - field := xt.Number() - message := xt.ContainingMessage().FullName() + xd := xt.Descriptor() + field := xd.Number() + message := xd.ContainingMessage().FullName() if prev := r.extensionsByMessage[message][field]; prev != nil { err := errors.New("extension number %d is already registered on message %v", field, message) err = amendErrorWithCaller(err, prev, typ) - if r == GlobalTypes && ignoreConflict(typ, err) { + if r == GlobalTypes && ignoreConflict(xd, err) { err = nil } if firstErr == nil { diff --git a/reflect/protoregistry/registry_test.go b/reflect/protoregistry/registry_test.go index 63dd1baab..1572b6310 100644 --- a/reflect/protoregistry/registry_test.go +++ b/reflect/protoregistry/registry_test.go @@ -536,11 +536,11 @@ func TestTypes(t *testing.T) { fullName := func(t preg.Type) pref.FullName { switch t := t.(type) { case pref.EnumType: - return t.FullName() + return t.Descriptor().FullName() case pref.MessageType: - return t.FullName() + return t.Descriptor().FullName() case pref.ExtensionType: - return t.FullName() + return t.Descriptor().FullName() default: panic("invalid type") } diff --git a/runtime/protoiface/legacy.go b/runtime/protoiface/legacy.go index 4f8d71f12..d7acb337e 100644 --- a/runtime/protoiface/legacy.go +++ b/runtime/protoiface/legacy.go @@ -5,6 +5,8 @@ package protoiface import ( + "reflect" + "google.golang.org/protobuf/reflect/protoreflect" ) @@ -64,3 +66,31 @@ type ExtensionDescV1 struct { // protoreflect.FileDescriptor.Path. Filename string } + +func (e ExtensionDescV1) getType() protoreflect.ExtensionType { + if e.Type != nil { + return e.Type + } + // All ExtensionDescV1 instances in generated code should have + // an Type field initialized at init time, so this case only + // occurs for non-standard generated code and hand-written + // ExtensionDescs. + panic(`proto: ExtensionDesc.Type is not set. + +This error probably indicates that you are trying to use a non-standard +"github.com/golang/protobuf/proto".ExtensionDesc with the +"google.golang.org/golang/protobuf" API. Use a protoreflect.ExtensionType +instead. +`) +} + +func (e ExtensionDescV1) New() protoreflect.Value { return e.getType().New() } +func (e ExtensionDescV1) Zero() protoreflect.Value { return e.getType().Zero() } +func (e ExtensionDescV1) GoType() reflect.Type { return e.getType().GoType() } +func (e ExtensionDescV1) Descriptor() protoreflect.ExtensionTypeDescriptor { + return e.getType().Descriptor() +} +func (e ExtensionDescV1) ValueOf(x interface{}) protoreflect.Value { return e.getType().ValueOf(x) } +func (e ExtensionDescV1) InterfaceOf(x protoreflect.Value) interface{} { + return e.getType().InterfaceOf(x) +} diff --git a/testing/prototest/prototest.go b/testing/prototest/prototest.go index fded0ff69..9ff58f990 100644 --- a/testing/prototest/prototest.go +++ b/testing/prototest/prototest.go @@ -47,7 +47,7 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) { }) } for _, xt := range opts.ExtensionTypes { - testField(t, m1, xt) + testField(t, m1, xt.Descriptor()) } for i := 0; i < md.Oneofs().Len(); i++ { testOneof(t, m1, md.Oneofs().Get(i)) @@ -57,12 +57,12 @@ func TestMessage(t testing.TB, m proto.Message, opts MessageOptions) { // Test round-trip marshal/unmarshal. m2 := m.ProtoReflect().New().Interface() populateMessage(m2.ProtoReflect(), 1, nil) - b, err := proto.Marshal(m2) + b, err := (proto.MarshalOptions{AllowPartial: true}).Marshal(m2) if err != nil { t.Errorf("Marshal() = %v, want nil\n%v", err, marshalText(m2)) } m3 := m.ProtoReflect().New().Interface() - if err := proto.Unmarshal(b, m3); err != nil { + if err := (proto.UnmarshalOptions{AllowPartial: true}).Unmarshal(b, m3); err != nil { t.Errorf("Unmarshal() = %v, want nil\n%v", err, marshalText(m2)) } if !proto.Equal(m2, m3) { diff --git a/testing/prototest/prototest_test.go b/testing/prototest/prototest_test.go index a95ac0b7e..00df3a61c 100644 --- a/testing/prototest/prototest_test.go +++ b/testing/prototest/prototest_test.go @@ -10,9 +10,12 @@ import ( "google.golang.org/protobuf/internal/flags" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/runtime/protoimpl" "google.golang.org/protobuf/testing/prototest" irregularpb "google.golang.org/protobuf/internal/testprotos/irregular" + legacypb "google.golang.org/protobuf/internal/testprotos/legacy" + legacy1pb "google.golang.org/protobuf/internal/testprotos/legacy/proto2.v0.0.0-20160225-2fc053c5" testpb "google.golang.org/protobuf/internal/testprotos/test" _ "google.golang.org/protobuf/internal/testprotos/test/weak1" _ "google.golang.org/protobuf/internal/testprotos/test/weak2" @@ -26,6 +29,8 @@ func Test(t *testing.T) { (*testpb.TestRequired)(nil), (*irregularpb.Message)(nil), (*testpb.TestAllExtensions)(nil), + (*legacypb.Legacy)(nil), + protoimpl.X.MessageOf((*legacy1pb.Message)(nil)).Interface(), } if flags.Proto1Legacy { ms = append(ms, (*testpb.TestWeak)(nil)) diff --git a/types/dynamicpb/dynamic.go b/types/dynamicpb/dynamic.go index 8794167c3..fa9155b2a 100644 --- a/types/dynamicpb/dynamic.go +++ b/types/dynamicpb/dynamic.go @@ -170,7 +170,7 @@ func (m *Message) Set(fd pref.FieldDescriptor, v pref.Value) { switch { case fd.IsExtension(): // Call InterfaceOf just to let the extension typecheck the value. - _ = fd.(pref.ExtensionType).InterfaceOf(v) + _ = fd.(pref.ExtensionTypeDescriptor).Type().InterfaceOf(v) m.ext[fd.Number()] = fd case fd.IsMap(): if mapv, ok := v.Interface().(*dynamicMap); !ok || mapv.desc != fd { @@ -217,7 +217,7 @@ func (m *Message) NewField(fd pref.FieldDescriptor) pref.Value { m.checkField(fd) switch { case fd.IsExtension(): - return fd.(pref.ExtensionType).New() + return fd.(pref.ExtensionTypeDescriptor).Type().New() case fd.IsMap(): return pref.ValueOf(&dynamicMap{ desc: fd, @@ -258,8 +258,8 @@ func (m *Message) SetUnknown(r pref.RawFields) { func (m *Message) checkField(fd pref.FieldDescriptor) { if fd.IsExtension() && fd.ContainingMessage().FullName() == m.Descriptor().FullName() { - if _, ok := fd.(pref.ExtensionType); !ok { - panic(errors.New("%v: extension field descriptor does not implement ExtensionType", fd.FullName())) + if _, ok := fd.(pref.ExtensionTypeDescriptor); !ok { + panic(errors.New("%v: extension field descriptor does not implement ExtensionTypeDescriptor", fd.FullName())) } return }