Skip to content

Commit

Permalink
generate plugin clients via template
Browse files Browse the repository at this point in the history
Signed-off-by: Brian Goff <[email protected]>
  • Loading branch information
cpuguy83 committed Jun 10, 2015
1 parent 55bdb51 commit 4c81c9d
Show file tree
Hide file tree
Showing 5 changed files with 553 additions and 0 deletions.
35 changes: 35 additions & 0 deletions pkg/plugins/pluginrpc-gen/fixtures/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package foo

type wobble struct {
Some string
Val string
Inception *wobble
}

type Fooer interface{}

type Fooer2 interface {
Foo()
}

type Fooer3 interface {
Foo()
Bar(a string)
Baz(a string) (err error)
Qux(a, b string) (val string, err error)
Wobble() (w *wobble)
Wiggle() (w wobble)
}

type Fooer4 interface {
Foo() error
}

type Bar interface {
Boo(a string, b string) (s string, err error)
}

type Fooer5 interface {
Foo()
Bar
}
91 changes: 91 additions & 0 deletions pkg/plugins/pluginrpc-gen/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package main

import (
"bytes"
"flag"
"fmt"
"go/format"
"io/ioutil"
"os"
"unicode"
"unicode/utf8"
)

type stringSet struct {
values map[string]struct{}
}

func (s stringSet) String() string {
return ""
}

func (s stringSet) Set(value string) error {
s.values[value] = struct{}{}
return nil
}
func (s stringSet) GetValues() map[string]struct{} {
return s.values
}

var (
typeName = flag.String("type", "", "interface type to generate plugin rpc proxy for")
rpcName = flag.String("name", *typeName, "RPC name, set if different from type")
inputFile = flag.String("i", "", "input file path")
outputFile = flag.String("o", *inputFile+"_proxy.go", "output file path")

skipFuncs map[string]struct{}
flSkipFuncs = stringSet{make(map[string]struct{})}

flBuildTags = stringSet{make(map[string]struct{})}
)

func errorOut(msg string, err error) {
if err == nil {
return
}
fmt.Fprintf(os.Stderr, "%s: %v\n", msg, err)
os.Exit(1)
}

func checkFlags() error {
if *outputFile == "" {
return fmt.Errorf("missing required flag `-o`")
}
if *inputFile == "" {
return fmt.Errorf("missing required flag `-i`")
}
return nil
}

func main() {
flag.Var(flSkipFuncs, "skip", "skip parsing for function")
flag.Var(flBuildTags, "tag", "build tags to add to generated files")
flag.Parse()
skipFuncs = flSkipFuncs.GetValues()

errorOut("error", checkFlags())

pkg, err := Parse(*inputFile, *typeName)
errorOut(fmt.Sprintf("error parsing requested type %s", *typeName), err)

var analysis = struct {
InterfaceType string
RPCName string
BuildTags map[string]struct{}
*parsedPkg
}{toLower(*typeName), *rpcName, flBuildTags.GetValues(), pkg}
var buf bytes.Buffer

errorOut("parser error", generatedTempl.Execute(&buf, analysis))
src, err := format.Source(buf.Bytes())
errorOut("error formating generated source", err)
errorOut("error writing file", ioutil.WriteFile(*outputFile, src, 0644))
}

func toLower(s string) string {
if s == "" {
return ""
}
r, n := utf8.DecodeRuneInString(s)
return string(unicode.ToLower(r)) + s[n:]
}
162 changes: 162 additions & 0 deletions pkg/plugins/pluginrpc-gen/parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package main

import (
"errors"
"fmt"
"go/ast"
"go/parser"
"go/token"
"reflect"
"strings"
)

var ErrBadReturn = errors.New("found return arg with no name: all args must be named")

type ErrUnexpectedType struct {
expected string
actual interface{}
}

func (e ErrUnexpectedType) Error() string {
return fmt.Sprintf("got wrong type expecting %s, got: %v", e.expected, reflect.TypeOf(e.actual))
}

type parsedPkg struct {
Name string
Functions []function
}

type function struct {
Name string
Args []arg
Returns []arg
Doc string
}

type arg struct {
Name string
ArgType string
}

func (a *arg) String() string {
return strings.ToLower(a.Name) + " " + strings.ToLower(a.ArgType)
}

// Parses the given file for an interface definition with the given name
func Parse(filePath string, objName string) (*parsedPkg, error) {
fs := token.NewFileSet()
pkg, err := parser.ParseFile(fs, filePath, nil, parser.AllErrors)
if err != nil {
return nil, err
}
p := &parsedPkg{}
p.Name = pkg.Name.Name
obj, exists := pkg.Scope.Objects[objName]
if !exists {
return nil, fmt.Errorf("could not find object %s in %s", objName, filePath)
}
if obj.Kind != ast.Typ {
return nil, fmt.Errorf("exected type, got %s", obj.Kind)
}
spec, ok := obj.Decl.(*ast.TypeSpec)
if !ok {
return nil, ErrUnexpectedType{"*ast.TypeSpec", obj.Decl}
}
iface, ok := spec.Type.(*ast.InterfaceType)
if !ok {
return nil, ErrUnexpectedType{"*ast.InterfaceType", spec.Type}
}

p.Functions, err = parseInterface(iface)
if err != nil {
return nil, err
}

return p, nil
}

func parseInterface(iface *ast.InterfaceType) ([]function, error) {
var functions []function
for _, field := range iface.Methods.List {
switch f := field.Type.(type) {
case *ast.FuncType:
method, err := parseFunc(field)
if err != nil {
return nil, err
}
if method == nil {
continue
}
functions = append(functions, *method)
case *ast.Ident:
spec, ok := f.Obj.Decl.(*ast.TypeSpec)
if !ok {
return nil, ErrUnexpectedType{"*ast.TypeSpec", f.Obj.Decl}
}
iface, ok := spec.Type.(*ast.InterfaceType)
if !ok {
return nil, ErrUnexpectedType{"*ast.TypeSpec", spec.Type}
}
funcs, err := parseInterface(iface)
if err != nil {
fmt.Println(err)
continue
}
functions = append(functions, funcs...)
default:
return nil, ErrUnexpectedType{"*astFuncType or *ast.Ident", f}
}
}
return functions, nil
}

func parseFunc(field *ast.Field) (*function, error) {
f := field.Type.(*ast.FuncType)
method := &function{Name: field.Names[0].Name}
if _, exists := skipFuncs[method.Name]; exists {
fmt.Println("skipping:", method.Name)
return nil, nil
}
if f.Params != nil {
args, err := parseArgs(f.Params.List)
if err != nil {
return nil, err
}
method.Args = args
}
if f.Results != nil {
returns, err := parseArgs(f.Results.List)
if err != nil {
return nil, fmt.Errorf("error parsing function returns for %q: %v", method.Name, err)
}
method.Returns = returns
}
return method, nil
}

func parseArgs(fields []*ast.Field) ([]arg, error) {
var args []arg
for _, f := range fields {
if len(f.Names) == 0 {
return nil, ErrBadReturn
}
for _, name := range f.Names {
var typeName string
switch argType := f.Type.(type) {
case *ast.Ident:
typeName = argType.Name
case *ast.StarExpr:
i, ok := argType.X.(*ast.Ident)
if !ok {
return nil, ErrUnexpectedType{"*ast.Ident", f.Type}
}
typeName = "*" + i.Name
default:
return nil, ErrUnexpectedType{"*ast.Ident or *ast.StarExpr", f.Type}
}

args = append(args, arg{name.Name, typeName})
}
}
return args, nil
}
Loading

0 comments on commit 4c81c9d

Please sign in to comment.