diff --git a/README.md b/README.md index 639e7fa..2ccc8c6 100644 --- a/README.md +++ b/README.md @@ -30,7 +30,10 @@ type GreetResponse struct { Use the `oto` tool to generate a client and server: ```bash -oto -template ./otohttp/templates/server.go.plush -out ./api/oto.gen.go ./api/definitions +oto -template ./otohttp/templates/server.go.plush \ + -out ./api/oto.gen.go \ + -ignore Ignorer \ + ./api/definitions gofmt -w ./api/oto.gen.go ./api/oto.gen.go ``` diff --git a/main.go b/main.go index b275d78..b815f37 100644 --- a/main.go +++ b/main.go @@ -29,11 +29,12 @@ flags:`) flags.PrintDefaults() } var ( - template = flags.String("template", "", "plush template to render") - outfile = flags.String("out", "", "output file (default: stdout)") - pkg = flags.String("pkg", "", "explicit package name (default: inferred)") - v = flags.Bool("v", false, "verbose output") - paramsStr = flags.String("params", "", "list of parameters in the format: \"key:value,key:value\"") + template = flags.String("template", "", "plush template to render") + outfile = flags.String("out", "", "output file (default: stdout)") + pkg = flags.String("pkg", "", "explicit package name (default: inferred)") + v = flags.Bool("v", false, "verbose output") + paramsStr = flags.String("params", "", "list of parameters in the format: \"key:value,key:value\"") + ignoreList = flags.String("ignore", "", "comma separated list of interfaces to ignore") ) if err := flags.Parse(args[1:]); err != nil { return err @@ -46,6 +47,10 @@ flags:`) return errors.Wrap(err, "params") } parser := newParser(flags.Args()...) + ignoreItems := strings.Split(*ignoreList, ",") + if ignoreItems[0] != "" { + parser.ExcludeInterfaces = ignoreItems + } parser.Verbose = *v if parser.Verbose { fmt.Println("oto - github.com/pacedotdev/oto") diff --git a/parser.go b/parser.go index d0c4881..5a228fb 100644 --- a/parser.go +++ b/parser.go @@ -87,6 +87,8 @@ func (f fieldType) JSType() (string, error) { type parser struct { Verbose bool + ExcludeInterfaces []string + patterns []string def definition @@ -112,6 +114,7 @@ func (p *parser) parse() (definition, error) { } p.outputObjects = make(map[string]struct{}) p.objects = make(map[string]struct{}) + var excludedObjectsTypeIDs []string pkgs, err := packages.Load(cfg, p.patterns...) if err != nil { return p.def, err @@ -127,12 +130,35 @@ func (p *parser) parse() (definition, error) { if err != nil { return p.def, err } + if isInSlice(p.ExcludeInterfaces, name) { + for _, method := range s.Methods { + excludedObjectsTypeIDs = append(excludedObjectsTypeIDs, method.InputObject.TypeID) + excludedObjectsTypeIDs = append(excludedObjectsTypeIDs, method.OutputObject.TypeID) + } + continue + } p.def.Services = append(p.def.Services, s) case *types.Struct: p.parseObject(pkg, obj, item) } } } + // remove any excluded objects + nonExcludedObjects := make([]object, 0, len(p.def.Objects)) + for _, object := range p.def.Objects { + excluded := false + for _, excludedTypeID := range excludedObjectsTypeIDs { + if object.TypeID == excludedTypeID { + excluded = true + break + } + } + if excluded { + continue + } + nonExcludedObjects = append(nonExcludedObjects, object) + } + p.def.Objects = nonExcludedObjects sort.Slice(p.def.Services, func(i, j int) bool { return p.def.Services[i].Name < p.def.Services[j].Name }) @@ -275,7 +301,8 @@ func (p *parser) addOutputFields() error { for typeName := range p.outputObjects { obj, err := p.def.Object(typeName) if err != nil { - return errors.Wrapf(err, "missing output object: %s", typeName) + // skip if we can't find it - it must be excluded + continue } obj.Fields = append(obj.Fields, errorField) } @@ -286,3 +313,12 @@ func (p *parser) wrapErr(err error, pkg *packages.Package, pos token.Pos) error position := pkg.Fset.Position(pos) return errors.Wrap(err, position.String()) } + +func isInSlice(slice []string, s string) bool { + for i := range slice { + if slice[i] == s { + return true + } + } + return false +} diff --git a/parser_test.go b/parser_test.go index bd4d800..630fbe9 100644 --- a/parser_test.go +++ b/parser_test.go @@ -9,11 +9,13 @@ import ( func TestParse(t *testing.T) { is := is.New(t) patterns := []string{"./testdata/services/pleasantries"} - def, err := newParser(patterns...).parse() + parser := newParser(patterns...) + parser.ExcludeInterfaces = []string{"Ignorer"} + def, err := parser.parse() is.NoErr(err) is.Equal(def.PackageName, "pleasantries") - is.Equal(len(def.Services), 2) + is.Equal(len(def.Services), 2) // should be 2 services is.Equal(def.Services[0].Name, "GreeterService") is.Equal(len(def.Services[0].Methods), 2) is.Equal(def.Services[0].Methods[0].Name, "GetGreetings") diff --git a/testdata/services/pleasantries/ignorer.go b/testdata/services/pleasantries/ignorer.go new file mode 100644 index 0000000..19ba304 --- /dev/null +++ b/testdata/services/pleasantries/ignorer.go @@ -0,0 +1,9 @@ +package pleasantries + +type Ignorer interface { + Ignore(IgnoreRequest) IgnoreResponse +} + +type IgnoreRequest struct{} + +type IgnoreResponse struct{}