From 671fa2aab90e8f63575716bfb1da47d95489707a Mon Sep 17 00:00:00 2001 From: Torin Sandall Date: Wed, 22 Sep 2021 14:33:04 -0700 Subject: [PATCH] types: Add initial support for varargs This commit adds initial support for varargs on function declarations. Varargs are only supported on void functions for now. In the future, varargs can be supported on any function, but care will have to be take in the compiler to handle calls that make use of the legacy calling convention for functions where the output is captured in the argument list (e.g., plus(1,2,3)). Signed-off-by: Torin Sandall --- types/decode.go | 28 ++++++++++---- types/types.go | 93 ++++++++++++++++++++++++++++++++++++--------- types/types_test.go | 24 ++++++++++++ 3 files changed, 119 insertions(+), 26 deletions(-) diff --git a/types/decode.go b/types/decode.go index c05fc56bef..4e123384a0 100644 --- a/types/decode.go +++ b/types/decode.go @@ -87,12 +87,25 @@ func Unmarshal(bs []byte) (result Type, err error) { case typeFunction: var decl rawdecl if err = util.UnmarshalJSON(bs, &decl); err == nil { - var args []Type - if args, err = unmarshalSlice(decl.Args); err == nil { - var ret Type - if ret, err = Unmarshal(decl.Result); err == nil { - result = NewFunction(args, ret) + args, err := unmarshalSlice(decl.Args) + if err != nil { + return nil, err + } + var ret Type + if len(decl.Result) > 0 { + ret, err = Unmarshal(decl.Result) + if err != nil { + return nil, err + } + } + if len(decl.Variadic) > 0 { + varargs, err := Unmarshal(decl.Variadic) + if err != nil { + return nil, err } + result = NewVariadicFunction(args, varargs, ret) + } else { + result = NewFunction(args, ret) } } default: @@ -136,8 +149,9 @@ type rawunion struct { } type rawdecl struct { - Args []json.RawMessage `json:"args"` - Result json.RawMessage `json:"result"` + Args []json.RawMessage `json:"args"` + Result json.RawMessage `json:"result"` + Variadic json.RawMessage `json:"variadic"` } func unmarshalSlice(elems []json.RawMessage) (result []Type, err error) { diff --git a/types/types.go b/types/types.go index cf102040d1..fdfc1f879c 100644 --- a/types/types.go +++ b/types/types.go @@ -457,8 +457,9 @@ func (t Any) String() string { // Function represents a function type. type Function struct { - args []Type - result Type + args []Type + result Type + variadic Type } // Args returns an argument list. @@ -475,9 +476,31 @@ func NewFunction(args []Type, result Type) *Function { } } -// Args returns the function's argument types. +// NewVariadicFunction returns a new Function object. This function sets the +// variadic bit on the signature. Non-void variadic functions are not currently +// supported. +func NewVariadicFunction(args []Type, varargs Type, result Type) *Function { + if result != nil { + panic("illegal value: non-void variadic functions not supported") + } + return &Function{ + args: args, + variadic: varargs, + result: nil, + } +} + +// FuncArgs returns the function's arguments. +func (t *Function) FuncArgs() FuncArgs { + return FuncArgs{Args: t.Args(), Variadic: t.variadic} +} + +// Args returns the function's arguments as a slice, ignoring variadic arguments. +// Deprecated: Use FuncArgs instead. func (t *Function) Args() []Type { - return t.args + cpy := make([]Type, len(t.args)) + copy(cpy, t.args) + return cpy } // Result returns the function's result type. @@ -486,19 +509,7 @@ func (t *Function) Result() Type { } func (t *Function) String() string { - var args string - if len(t.args) != 1 { - args = "(" - } - buf := []string{} - for _, a := range t.Args() { - buf = append(buf, Sprint(a)) - } - args += strings.Join(buf, ", ") - if len(t.args) != 1 { - args += ")" - } - return fmt.Sprintf("%v => %v", args, Sprint(t.Result())) + return fmt.Sprintf("%v => %v", t.FuncArgs(), Sprint(t.Result())) } // MarshalJSON returns the JSON encoding of t. @@ -512,6 +523,9 @@ func (t *Function) MarshalJSON() ([]byte, error) { if t.result != nil { repr["result"] = t.result } + if t.variadic != nil { + repr["variadic"] = t.variadic + } return json.Marshal(repr) } @@ -540,17 +554,55 @@ func (t *Function) Union(other *Function) *Function { } else if t == nil { return other } + a := t.Args() b := other.Args() if len(a) != len(b) { return nil } + + aIsVariadic := t.FuncArgs().Variadic != nil + bIsVariadic := other.FuncArgs().Variadic != nil + + if aIsVariadic && !bIsVariadic { + return nil + } else if bIsVariadic && !aIsVariadic { + return nil + } + args := make([]Type, len(a)) for i := range a { args[i] = Or(a[i], b[i]) } - return NewFunction(args, Or(t.Result(), other.Result())) + result := NewFunction(args, Or(t.Result(), other.Result())) + result.variadic = Or(t.FuncArgs().Variadic, other.FuncArgs().Variadic) + + return result +} + +// FuncArgs represents the arguments that can be passed to a function. +type FuncArgs struct { + Args []Type `json:"args,omitempty"` + Variadic Type `json:"variadic,omitempty"` +} + +func (a FuncArgs) String() string { + var buf []string + for i := range a.Args { + buf = append(buf, Sprint(a.Args[i])) + } + if a.Variadic != nil { + buf = append(buf, Sprint(a.Variadic)+"...") + } + return "(" + strings.Join(buf, ", ") + ")" +} + +func (a FuncArgs) Arg(x int) Type { + if x < len(a.Args) { + return a.Args[x] + } + return a.Variadic } // Compare returns -1, 0, 1 based on comparison between a and b. @@ -648,7 +700,10 @@ func Compare(a, b Type) int { return cmp } } - return Compare(fA.result, fB.result) + if cmp := Compare(fA.result, fB.result); cmp != 0 { + return cmp + } + return Compare(fA.variadic, fB.variadic) default: panic("unreachable") } diff --git a/types/types_test.go b/types/types_test.go index e2201a34fe..9d0f4827e2 100644 --- a/types/types_test.go +++ b/types/types_test.go @@ -81,6 +81,13 @@ func TestStrings(t *testing.T) { if ftpe.String() != expected { t.Fatalf("Expected %v but got: %v", expected, ftpe) } + + ftpe = NewVariadicFunction([]Type{N}, S, nil) + expected = "(number, string...) => ???" + + if ftpe.String() != expected { + t.Fatal("expected", expected, "but got:", ftpe) + } } func TestCompare(t *testing.T) { @@ -466,3 +473,20 @@ func TestRoundtripJSON(t *testing.T) { t.Fatalf("Got: %v\n\nExpected: %v", result, tpe) } } + +func TestRoundtripJSONVariadicFunction(t *testing.T) { + tpe := NewVariadicFunction([]Type{S}, N, nil) + bs, err := json.Marshal(tpe) + if err != nil { + t.Fatal(err) + } + + result, err := Unmarshal(bs) + if err != nil { + t.Fatal(err) + } + + if Compare(result, tpe) != 0 { + t.Fatalf("Got: %v\n\nExpected: %v", result, tpe) + } +}