Skip to content

feat: Use structs for CLI-validation errors returned by Cobra. #2266

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions args.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package cobra

import (
"fmt"
"strings"
)

Expand All @@ -33,15 +32,23 @@ func legacyArgs(cmd *Command, args []string) error {

// root command with subcommands, do subcommand checking.
if !cmd.HasParent() && len(args) > 0 {
return fmt.Errorf("unknown command %q for %q%s", args[0], cmd.CommandPath(), cmd.findSuggestions(args[0]))
return &UnknownSubcommandError{
cmd: cmd,
subcmd: args[0],
suggestions: cmd.findSuggestions(args[0]),
}
}
return nil
}

// NoArgs returns an error if any args are included.
func NoArgs(cmd *Command, args []string) error {
if len(args) > 0 {
return fmt.Errorf("unknown command %q for %q", args[0], cmd.CommandPath())
return &UnknownSubcommandError{
cmd: cmd,
subcmd: args[0],
suggestions: "",
}
}
return nil
}
Expand All @@ -58,7 +65,11 @@ func OnlyValidArgs(cmd *Command, args []string) error {
}
for _, v := range args {
if !stringInSlice(v, validArgs) {
return fmt.Errorf("invalid argument %q for %q%s", v, cmd.CommandPath(), cmd.findSuggestions(args[0]))
return &InvalidArgValueError{
cmd: cmd,
arg: v,
suggestions: cmd.findSuggestions(args[0]),
}
}
}
}
Expand All @@ -74,7 +85,12 @@ func ArbitraryArgs(cmd *Command, args []string) error {
func MinimumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < n {
return fmt.Errorf("requires at least %d arg(s), only received %d", n, len(args))
return &InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: n,
atMost: -1,
}
}
return nil
}
Expand All @@ -84,7 +100,12 @@ func MinimumNArgs(n int) PositionalArgs {
func MaximumNArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) > n {
return fmt.Errorf("accepts at most %d arg(s), received %d", n, len(args))
return &InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: -1,
atMost: n,
}
}
return nil
}
Expand All @@ -94,7 +115,12 @@ func MaximumNArgs(n int) PositionalArgs {
func ExactArgs(n int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) != n {
return fmt.Errorf("accepts %d arg(s), received %d", n, len(args))
return &InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: n,
atMost: n,
}
}
return nil
}
Expand All @@ -104,7 +130,12 @@ func ExactArgs(n int) PositionalArgs {
func RangeArgs(min int, max int) PositionalArgs {
return func(cmd *Command, args []string) error {
if len(args) < min || len(args) > max {
return fmt.Errorf("accepts between %d and %d arg(s), received %d", min, max, len(args))
return &InvalidArgCountError{
cmd: cmd,
args: args,
atLeast: min,
atMost: max,
}
}
return nil
}
Expand Down
77 changes: 77 additions & 0 deletions args_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
package cobra

import (
"errors"
"fmt"
"reflect"
"strings"
"testing"
)
Expand All @@ -32,6 +34,14 @@ func getCommand(args PositionalArgs, withValid bool) *Command {
return c
}

func getCommandName(c *Command) string {
if c == nil {
return "<nil>"
} else {
return c.Name()
}
}

func expectSuccess(output string, err error, t *testing.T) {
if output != "" {
t.Errorf("Unexpected output: %v", output)
Expand All @@ -41,6 +51,31 @@ func expectSuccess(output string, err error, t *testing.T) {
}
}

func expectErrorAs(err error, target error, t *testing.T) {
if err == nil {
t.Fatalf("Expected error, got nil")
}

targetType := reflect.TypeOf(target)
targetPtr := reflect.New(targetType).Interface() // *SomeError
if !errors.As(err, targetPtr) {
t.Fatalf("Expected error to be %T, got %T", target, err)
}
}

func expectErrorHasCommand(err error, cmd *Command, t *testing.T) {
getCommand, ok := err.(interface{ GetCommand() *Command })
if !ok {
t.Fatalf("Expected error to have GetCommand method, but did not")
}

got := getCommand.GetCommand()
if cmd != got {
t.Errorf("Expected err.GetCommand to return %v, got %v",
getCommandName(cmd), getCommandName(got))
}
}

func validOnlyWithInvalidArgs(err error, t *testing.T) {
if err == nil {
t.Fatal("Expected an error")
Expand Down Expand Up @@ -139,6 +174,13 @@ func TestNoArgs_WithValidOnly_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestNoArgs_ReturnsUnknownSubcommandError(t *testing.T) {
c := getCommand(NoArgs, false)
_, err := executeCommand(c, "a")
expectErrorAs(err, &UnknownSubcommandError{}, t)
expectErrorHasCommand(err, c, t)
}

// OnlyValidArgs

func TestOnlyValidArgs(t *testing.T) {
Expand All @@ -153,6 +195,13 @@ func TestOnlyValidArgs_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestOnlyValidArgs_ReturnsInvalidArgValueError(t *testing.T) {
c := getCommand(OnlyValidArgs, true)
_, err := executeCommand(c, "a")
expectErrorAs(err, &InvalidArgValueError{}, t)
expectErrorHasCommand(err, c, t)
}

// ArbitraryArgs

func TestArbitraryArgs(t *testing.T) {
Expand Down Expand Up @@ -229,6 +278,13 @@ func TestMinimumNArgs_WithLessArgs_WithValidOnly_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestMinimumNArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(MinimumNArgs(2), true)
_, err := executeCommand(c, "a")
expectErrorAs(err, &InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// MaximumNArgs

func TestMaximumNArgs(t *testing.T) {
Expand Down Expand Up @@ -279,6 +335,13 @@ func TestMaximumNArgs_WithMoreArgs_WithValidOnly_WithInvalidArgs(t *testing.T) {
validOnlyWithInvalidArgs(err, t)
}

func TestMaximumNArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(MaximumNArgs(2), true)
_, err := executeCommand(c, "a", "b", "c")
expectErrorAs(err, &InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// ExactArgs

func TestExactArgs(t *testing.T) {
Expand Down Expand Up @@ -329,6 +392,13 @@ func TestExactArgs_WithInvalidCount_WithValidOnly_WithInvalidArgs(t *testing.T)
validOnlyWithInvalidArgs(err, t)
}

func TestExactArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(ExactArgs(2), true)
_, err := executeCommand(c, "a")
expectErrorAs(err, &InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// RangeArgs

func TestRangeArgs(t *testing.T) {
Expand Down Expand Up @@ -379,6 +449,13 @@ func TestRangeArgs_WithInvalidCount_WithValidOnly_WithInvalidArgs(t *testing.T)
validOnlyWithInvalidArgs(err, t)
}

func TestRangeArgs_ReturnsInvalidArgCountError(t *testing.T) {
c := getCommand(RangeArgs(2, 4), true)
_, err := executeCommand(c, "a")
expectErrorAs(err, &InvalidArgCountError{}, t)
expectErrorHasCommand(err, c, t)
}

// Takes(No)Args

func TestRootTakesNoArgs(t *testing.T) {
Expand Down
5 changes: 4 additions & 1 deletion command.go
Original file line number Diff line number Diff line change
Expand Up @@ -1195,7 +1195,10 @@ func (c *Command) ValidateRequiredFlags() error {
})

if len(missingFlagNames) > 0 {
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
return &RequiredFlagError{
cmd: c,
missingFlagNames: missingFlagNames,
}
}
return nil
}
Expand Down
16 changes: 16 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package cobra
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"os"
Expand Down Expand Up @@ -866,6 +867,21 @@ func TestRequiredFlags(t *testing.T) {
if got != expected {
t.Errorf("Expected error: %q, got: %q", expected, got)
}

// Test it returns valid RequiredFlagError.
var requiredFlagErr *RequiredFlagError
if !errors.As(err, &requiredFlagErr) {
t.Fatalf("Expected error to be RequiredFlagError, got %T", err)
}

expectedMissingFlagNames := "foo1 foo2"
gotMissingFlagNames := strings.Join(requiredFlagErr.missingFlagNames, " ")
if expectedMissingFlagNames != gotMissingFlagNames {
t.Errorf("Expected error missingFlagNames to be %q, got %q",
expectedMissingFlagNames, gotMissingFlagNames)
}

expectErrorHasCommand(err, c, t)
}

func TestPersistentRequiredFlags(t *testing.T) {
Expand Down
Loading