Skip to content

Commit

Permalink
Refactor TakesArgs to use an interface for arg validation.
Browse files Browse the repository at this point in the history
Fix some typos in README and comments.
Move arg validation to after flag validation so that the help flag is run first.
Pass the same args to ValidateArgs as the Run methods receive.
Update README.

Signed-off-by: Daniel Nephin <[email protected]>
  • Loading branch information
dnephin authored and n10v committed Jul 23, 2017
1 parent d89c499 commit f20b4e9
Show file tree
Hide file tree
Showing 4 changed files with 160 additions and 84 deletions.
47 changes: 23 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -467,36 +467,34 @@ A flag can also be assigned locally which will only apply to that specific comma
RootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from")
```

### Specify if you command takes arguments
## Positional and Custom Arguments

There are multiple options for how a command can handle unknown arguments which can be set in `TakesArgs`
- `Legacy`
- `None`
- `Arbitrary`
- `ValidOnly`
Validation of positional arguments can be specified using the `Args` field.

`Legacy` (or default) the rules are as follows:
- root commands with no subcommands can take arbitrary arguments
- root commands with subcommands will do subcommand validity checking
- subcommands will always accept arbitrary arguments and do no subsubcommand validity checking
The follow validators are built in:

`None` the command will be rejected if there are any left over arguments after parsing flags.
- `NoArgs` - the command will report an error if there are any positional args.
- `ArbitraryArgs` - the command will accept any args.
- `OnlyValidArgs` - the command will report an error if there are any positional args that are not in the ValidArgs list.
- `MinimumNArgs(int)` - the command will report an error if there are not at least N positional args.
- `MaximumNArgs(int)` - the command will report an error if there are more than N positional args.
- `ExactArgs(int)` - the command will report an error if there are not exactly N positional args.
- `RangeArgs(min, max)` - the command will report an error if the number of args is not between the minimum and maximum number of expected args.

`Arbitrary` any additional values left after parsing flags will be passed in to your `Run` function.

`ValidOnly` you must define all valid (non-subcommand) arguments to your command. These are defined in a slice name ValidArgs. For example a command which only takes the argument "one" or "two" would be defined as:
A custom validator can be provided like this:

```go
var HugoCmd = &cobra.Command{
Use: "hugo",
Short: "Hugo is a very fast static site generator",
ValidArgs: []string{"one", "two", "three", "four"}
TakesArgs: cobra.ValidOnly
Run: func(cmd *cobra.Command, args []string) {
// args will only have the values one, two, three, four
// or the cmd.Execute() will fail.
},
}

Args: func validColorArgs(cmd *cobra.Command, args []string) error {
if err := cli.RequiresMinArgs(1)(cmd, args); err != nil {
return err
}
if myapp.IsValidColor(args[0]) {
return nil
}
return fmt.Errorf("Invalid color specified: %s", args[0])
}

```

### Bind Flags with Config
Expand All @@ -517,6 +515,7 @@ when the `--author` flag is not provided by user.

More in [viper documentation](https://github.com/spf13/viper#working-with-flags).


## Example

In the example below, we have defined three commands. Two are at the top level
Expand Down
98 changes: 98 additions & 0 deletions args.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package cobra

import (
"fmt"
)

type PositionalArgs func(cmd *Command, args []string) error

// Legacy arg validation has the following behaviour:
// - root commands with no subcommands can take arbitrary arguments
// - root commands with subcommands will do subcommand validity checking
// - subcommands will always accept arbitrary arguments
func legacyArgs(cmd *Command, args []string) error {
// no subcommand, always take args
if !cmd.HasSubCommands() {
return nil
}

// root command with subcommands, do subcommand checking
if !cmd.HasParent() && len(args) > 0 {
return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0]))
}
return nil
}

// NoArgs returns an error if any args are included
func NoArgs(cmd *Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath())
}
return nil
}

// OnlyValidArgs returns an error if any args are not in the list of ValidArgs
func OnlyValidArgs(cmd *Command, args []string) error {
if len(cmd.ValidArgs) > 0 {
for _, v := range args {
if !stringInSlice(v, cmd.ValidArgs) {
return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0]))
}
}
}
return nil
}

func stringInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}

// ArbitraryArgs never returns an error
func ArbitraryArgs(cmd *Command, args []string) error {
return nil
}

// MinimumNArgs returns an error if there is not at least N args
func MinimumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < n {
return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args))
}
return nil
}
}

// MaximumNArgs returns an error if there are more than N args
func MaximumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) > n {
return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args))
}
return nil
}
}

// ExactArgs returns an error if there are not exactly n args
func ExactArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) != n {
return fmt.Errorf("accepts %d arg(s), received %d", n, len(args))
}
return nil
}
}

// RangeArgs returns an error if the number of args is not within the expected range
func RangeArgs(min int, max int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < min || len(args) > max {
return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args))
}
return nil
}
}
32 changes: 22 additions & 10 deletions cobra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ var cmdHidden = &Command{

var cmdPrint = &Command{
Use: "print [string to print]",
Args: MinimumNArgs(1),
Short: "Print anything to the screen",
Long: `an absolutely utterly useless command for testing.`,
Run: func(cmd *Command, args []string) {
Expand Down Expand Up @@ -75,7 +76,7 @@ var cmdDeprecated = &Command{
Deprecated: "Please use echo instead",
Run: func(cmd *Command, args []string) {
},
TakesArgs: None,
Args: NoArgs,
}

var cmdTimes = &Command{
Expand All @@ -89,7 +90,7 @@ var cmdTimes = &Command{
Run: func(cmd *Command, args []string) {
tt = args
},
TakesArgs: ValidOnly,
Args: OnlyValidArgs,
ValidArgs: []string{"one", "two", "three", "four"},
}

Expand All @@ -103,10 +104,9 @@ var cmdRootNoRun = &Command{
}

var cmdRootSameName = &Command{
Use: "print",
Short: "Root with the same name as a subcommand",
Long: "The root description for help",
TakesArgs: None,
Use: "print",
Short: "Root with the same name as a subcommand",
Long: "The root description for help",
}

var cmdRootTakesArgs = &Command{
Expand All @@ -116,7 +116,7 @@ var cmdRootTakesArgs = &Command{
Run: func(cmd *Command, args []string) {
tr = args
},
TakesArgs: Arbitrary,
Args: ArbitraryArgs,
}

var cmdRootWithRun = &Command{
Expand Down Expand Up @@ -477,6 +477,10 @@ func TestRootTakesNoArgs(t *testing.T) {
c.AddCommand(cmdPrint, cmdEcho)
result := simpleTester(c, "illegal")

if result.Error == nil {
t.Fatal("Expected an error")
}

expectedError := `unknown command "illegal" for "print"`
if !strings.Contains(result.Error.Error(), expectedError) {
t.Errorf("exptected %v, got %v", expectedError, result.Error.Error())
Expand All @@ -493,7 +497,11 @@ func TestRootTakesArgs(t *testing.T) {
}

func TestSubCmdTakesNoArgs(t *testing.T) {
result := fullSetupTest("deprecated illegal")
result := fullSetupTest("deprecated", "illegal")

if result.Error == nil {
t.Fatal("Expected an error")
}

expectedError := `unknown command "illegal" for "cobra-test deprecated"`
if !strings.Contains(result.Error.Error(), expectedError) {
Expand All @@ -502,14 +510,18 @@ func TestSubCmdTakesNoArgs(t *testing.T) {
}

func TestSubCmdTakesArgs(t *testing.T) {
noRRSetupTest("echo times one two")
noRRSetupTest("echo", "times", "one", "two")
if strings.Join(tt, " ") != "one two" {
t.Error("Command didn't parse correctly")
}
}

func TestCmdOnlyValidArgs(t *testing.T) {
result := noRRSetupTest("echo times one two five")
result := noRRSetupTest("echo", "times", "one", "two", "five")

if result.Error == nil {
t.Fatal("Expected an error")
}

expectedError := `invalid argument "five"`
if !strings.Contains(result.Error.Error(), expectedError) {
Expand Down
67 changes: 17 additions & 50 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,6 @@ import (
flag "github.com/spf13/pflag"
)

type Args int

const (
Legacy Args = iota
Arbitrary
ValidOnly
None
)

// Command is just that, a command for your application.
// E.g. 'go run ...' - 'run' is the command. Cobra requires
// you to define the usage and description as part of your command
Expand Down Expand Up @@ -68,8 +59,8 @@ type Command struct {
// but accepted if entered manually.
ArgAliases []string

// Does this command take arbitrary arguments
TakesArgs Args
// Expected arguments
Args PositionalArgs
// BashCompletionFunction is custom functions used by the bash autocompletion generator.
BashCompletionFunction string

Expand Down Expand Up @@ -483,15 +474,6 @@ func argsMinusFirstX(args []string, x string) []string {
return args
}

func stringInSlice(a string, list []string) bool {
for _, b := range list {
if b == a {
return true
}
}
return false
}

// Find the target command given the args and command tree
// Meant to be run on the highest node. Only searches down.
func (c *Command) Find(args []string) (*Command, []string, error) {
Expand Down Expand Up @@ -533,47 +515,21 @@ func (c *Command) Find(args []string) (*Command, []string, error) {
}

commandFound, a := innerfind(c, args)
argsWOflags := stripFlags(a, commandFound)

// "Legacy" has some 'odd' characteristics.
// - root commands with no subcommands can take arbitrary arguments
// - root commands with subcommands will do subcommand validity checking
// - subcommands will always accept arbitrary arguments
if commandFound.TakesArgs == Legacy {
// no subcommand, always take args
if !commandFound.HasSubCommands() {
return commandFound, a, nil
}
// root command with subcommands, do subcommand checking
if commandFound == c && len(argsWOflags) > 0 {
return commandFound, a, fmt.Errorf("unknown command %q for %q%s", argsWOflags[0], commandFound.CommandPath(), c.findSuggestions(argsWOflags))
}
return commandFound, a, nil
}

if commandFound.TakesArgs == None && len(argsWOflags) > 0 {
return commandFound, a, fmt.Errorf("unknown command %q for %q", argsWOflags[0], commandFound.CommandPath())
}

if commandFound.TakesArgs == ValidOnly && len(commandFound.ValidArgs) > 0 {
for _, v := range argsWOflags {
if !stringInSlice(v, commandFound.ValidArgs) {
return commandFound, a, fmt.Errorf("invalid argument %q for %q%s", v, commandFound.CommandPath(), c.findSuggestions(argsWOflags))
}
}
if commandFound.Args == nil {
return commandFound, a, legacyArgs(commandFound, stripFlags(a, commandFound))
}
return commandFound, a, nil
}

func (c *Command) findSuggestions(argsWOflags []string) string {
func (c *Command) findSuggestions(arg string) string {
if c.DisableSuggestions {
return ""
}
if c.SuggestionsMinimumDistance <= 0 {
c.SuggestionsMinimumDistance = 2
}
suggestionsString := ""
if suggestions := c.SuggestionsFor(argsWOflags[0]); len(suggestions) > 0 {
if suggestions := c.SuggestionsFor(arg); len(suggestions) > 0 {
suggestionsString += "\n\nDid you mean this?\n"
for _, s := range suggestions {
suggestionsString += fmt.Sprintf("\t%v\n", s)
Expand Down Expand Up @@ -666,6 +622,10 @@ func (c *Command) execute(a []string) (err error) {
argWoFlags = a
}

if err := c.ValidateArgs(argWoFlags); err != nil {
return err
}

for p := c; p != nil; p = p.Parent() {
if p.PersistentPreRunE != nil {
if err := p.PersistentPreRunE(c, argWoFlags); err != nil {
Expand Down Expand Up @@ -789,6 +749,13 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {
return cmd, err
}

func (c *Command) ValidateArgs(args []string) error {
if c.Args == nil {
return nil
}
return c.Args(c, args)
}

// InitDefaultHelpFlag adds default help flag to c.
// It is called automatically by executing the c or by calling help and usage.
// If c already has help flag, it will do nothing.
Expand Down

0 comments on commit f20b4e9

Please sign in to comment.