Skip to content
This repository has been archived by the owner on Mar 23, 2023. It is now read-only.

Commit

Permalink
Enable passing Python functions to Go for invocation.
Browse files Browse the repository at this point in the history
  • Loading branch information
nairb774 committed Jan 24, 2017
1 parent c0a76bc commit 627e648
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 5 deletions.
5 changes: 5 additions & 0 deletions runtime/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ func functionGet(_ *Frame, desc, instance *Object, owner *Type) (*Object, *BaseE
return NewMethod(toFunctionUnsafe(desc), instance, owner).ToObject(), nil
}

func functionNative(f *Frame, o *Object) (reflect.Value, *BaseException) {
return reflect.ValueOf(o.Call), nil
}

func functionRepr(_ *Frame, o *Object) (*Object, *BaseException) {
fun := toFunctionUnsafe(o)
return NewStr(fmt.Sprintf("<%s %s at %p>", fun.typ.Name(), fun.Name(), fun)).ToObject(), nil
Expand All @@ -134,6 +138,7 @@ func initFunctionType(map[string]*Object) {
FunctionType.flags &= ^(typeFlagInstantiable | typeFlagBasetype)
FunctionType.slots.Call = &callSlot{functionCall}
FunctionType.slots.Get = &getSlot{functionGet}
FunctionType.slots.Native = &nativeSlot{functionNative}
FunctionType.slots.Repr = &unaryOpSlot{functionRepr}
}

Expand Down
103 changes: 98 additions & 5 deletions runtime/native.go
Original file line number Diff line number Diff line change
Expand Up @@ -479,22 +479,115 @@ func maybeConvertValue(f *Frame, o *Object, expectedRType reflect.Type) (reflect
if raised != nil {
return reflect.Value{}, raised
}
rtype := val.Type()
for {
rtype := val.Type()
if rtype == expectedRType {
return val, nil
}
if rtype.ConvertibleTo(expectedRType) {
return val.Convert(expectedRType), nil
}
if rtype.Kind() == reflect.Ptr {
switch rtype.Kind() {
case reflect.Ptr:
val = val.Elem()
rtype = val.Type()
continue

case reflect.Func:
if fn, ok := val.Interface().(func(*Frame, Args, KWArgs) (*Object, *BaseException)); ok {
val = nativeToPyFuncBridge(fn, expectedRType)
continue
}
}
break
return val, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
}
}

var baseExceptionReflectType = reflect.TypeOf((*BaseException)(nil))

// pyToNativeRaised supports pushing a `raised` exception from python code to
// native calling code. If the raised exception can't be returned to native
// code, then the raised exception is panic-ed.
func pyToNativeRaised(outs []reflect.Type, raised *BaseException) []reflect.Value {
last := len(outs) - 1
if len(outs) == 0 || outs[last] != baseExceptionReflectType {
panic(raised)
}
ret := make([]reflect.Value, len(outs))
for i, out := range outs[:last] {
ret[i] = reflect.Zero(out)
}
return reflect.Value{}, f.RaiseType(TypeErrorType, fmt.Sprintf("cannot convert %s to %s", rtype, expectedRType))
ret[last] = reflect.ValueOf(raised)
return ret
}

var frameReflectType = reflect.TypeOf((*Frame)(nil))

func nativeToPyFuncBridge(fn func(*Frame, Args, KWArgs) (*Object, *BaseException), target reflect.Type) reflect.Value {
firstInIsFrame := target.NumIn() > 0 && target.In(0) == frameReflectType

outs := make([]reflect.Type, target.NumOut())
for i := range outs {
outs[i] = target.Out(i)
}

return reflect.MakeFunc(target, func(args []reflect.Value) []reflect.Value {
var f *Frame
if firstInIsFrame {
f, args = args[0].Interface().(*Frame), args[1:]
} else {
f = NewRootFrame()
}

pyArgs := f.MakeArgs(len(args))
for i, arg := range args {
var raised *BaseException
pyArgs[i], raised = WrapNative(f, arg)
if raised != nil {
return pyToNativeRaised(outs, raised)
}
}

ret, raised := fn(f, pyArgs, nil)
f.FreeArgs(pyArgs)
if raised != nil {
return pyToNativeRaised(outs, raised)
}

switch len(outs) {
case 0:
if ret != nil && ret != None {
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("unexpected return of %v when None expected", ret)))
}
return nil

case 1:
v, raised := maybeConvertValue(f, ret, outs[0])
if raised != nil {
return pyToNativeRaised(outs, raised)
}
return []reflect.Value{v}

default:
converted := make([]reflect.Value, 0, len(outs))
if raised := seqForEach(f, ret, func(o *Object) *BaseException {
i := len(converted)
if i >= len(outs) {
return f.RaiseType(TypeErrorType, fmt.Sprintf("return value too long, want %d items", len(outs)))
}
v, raised := maybeConvertValue(f, o, outs[i])
converted = append(converted, v)
return raised
}); raised != nil {
return pyToNativeRaised(outs, raised)
}

if len(converted) != len(outs) {
return pyToNativeRaised(outs, f.RaiseType(TypeErrorType, fmt.Sprintf("return value wrong size %d, want %d", len(converted), len(outs))))
}

return converted
}
})
}

func nativeFuncTypeName(rtype reflect.Type) string {
Expand Down

0 comments on commit 627e648

Please sign in to comment.