Skip to content

Commit 0b93018

Browse files
committed
feat: add support for listen address and listen address slices flags
1 parent 11b01e2 commit 0b93018

File tree

4 files changed

+427
-0
lines changed

4 files changed

+427
-0
lines changed

listen_addr.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package dynflags
2+
3+
import (
4+
"fmt"
5+
"net"
6+
)
7+
8+
type ListenAddrValue struct {
9+
Bound *string
10+
}
11+
12+
func (l *ListenAddrValue) GetBound() any {
13+
if l.Bound == nil {
14+
return nil
15+
}
16+
return *l.Bound
17+
}
18+
19+
func (l *ListenAddrValue) Parse(value string) (any, error) {
20+
_, err := net.ResolveTCPAddr("tcp", value)
21+
if err != nil {
22+
return nil, fmt.Errorf("invalid listen address: %w", err)
23+
}
24+
return &value, nil
25+
}
26+
27+
func (l *ListenAddrValue) Set(value any) error {
28+
if str, ok := value.(*string); ok {
29+
*l.Bound = *str
30+
return nil
31+
}
32+
return fmt.Errorf("invalid value type: expected string pointer for listen address")
33+
}
34+
35+
// ListenAddr defines a flag that validates a TCP listen address (host:port or :port).
36+
func (g *ConfigGroup) ListenAddr(name, defaultValue, usage string) *Flag {
37+
bound := new(string)
38+
if defaultValue != "" {
39+
if _, err := net.ResolveTCPAddr("tcp", defaultValue); err != nil {
40+
panic(fmt.Sprintf("%s has an invalid default listen address '%s': %v", name, defaultValue, err))
41+
}
42+
*bound = defaultValue // Copy the parsed ListenAddr into bound
43+
}
44+
flag := &Flag{
45+
Type: FlagTypeString,
46+
Default: defaultValue,
47+
Usage: usage,
48+
value: &ListenAddrValue{Bound: bound},
49+
}
50+
g.Flags[name] = flag
51+
g.flagOrder = append(g.flagOrder, name)
52+
return flag
53+
}
54+
55+
// GetListenAddr returns the string value of a validated listen address flag.
56+
func (pg *ParsedGroup) GetListenAddr(flagName string) (string, error) {
57+
value, exists := pg.Values[flagName]
58+
if !exists {
59+
return "", fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name)
60+
}
61+
if str, ok := value.(string); ok {
62+
return str, nil
63+
}
64+
return "", fmt.Errorf("flag '%s' is not a string listen address", flagName)
65+
}

listen_addr_slice.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package dynflags
2+
3+
import (
4+
"fmt"
5+
"net"
6+
"strings"
7+
)
8+
9+
type ListenAddrSlicesValue struct {
10+
Bound *[]string
11+
}
12+
13+
func (s *ListenAddrSlicesValue) GetBound() any {
14+
if s.Bound == nil {
15+
return nil
16+
}
17+
return *s.Bound
18+
}
19+
20+
func (s *ListenAddrSlicesValue) Parse(value string) (any, error) {
21+
_, err := net.ResolveTCPAddr("tcp", value)
22+
if err != nil {
23+
return nil, fmt.Errorf("invalid listen address: %w", err)
24+
}
25+
return value, nil
26+
}
27+
28+
func (s *ListenAddrSlicesValue) Set(value any) error {
29+
if addr, ok := value.(string); ok {
30+
*s.Bound = append(*s.Bound, addr)
31+
return nil
32+
}
33+
return fmt.Errorf("invalid value type: expected string listen address")
34+
}
35+
36+
// ListenAddrSlices defines a slice-of-listen-address flag with the specified name, default values, and usage.
37+
func (g *ConfigGroup) ListenAddrSlices(name string, value []string, usage string) *Flag {
38+
bound := &value
39+
defaultValue := strings.Join(value, ",")
40+
41+
// Validate all default addresses
42+
for _, v := range value {
43+
if _, err := net.ResolveTCPAddr("tcp", v); err != nil {
44+
panic(fmt.Sprintf("%s has an invalid default listen address '%s': %v", name, v, err))
45+
}
46+
}
47+
48+
flag := &Flag{
49+
Type: FlagTypeStringSlice,
50+
Default: defaultValue,
51+
Usage: usage,
52+
value: &ListenAddrSlicesValue{Bound: bound},
53+
}
54+
g.Flags[name] = flag
55+
g.flagOrder = append(g.flagOrder, name)
56+
return flag
57+
}
58+
59+
// GetListenAddrSlices returns the []string value of a listen address slice flag.
60+
func (pg *ParsedGroup) GetListenAddrSlices(flagName string) ([]string, error) {
61+
value, exists := pg.Values[flagName]
62+
if !exists {
63+
return nil, fmt.Errorf("flag '%s' not found in group '%s'", flagName, pg.Name)
64+
}
65+
66+
if list, ok := value.([]string); ok {
67+
return list, nil
68+
}
69+
70+
if str, ok := value.(string); ok {
71+
return []string{str}, nil
72+
}
73+
74+
return nil, fmt.Errorf("flag '%s' is not a []string listen address slice", flagName)
75+
}

listen_addr_slice_test.go

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package dynflags_test
2+
3+
import (
4+
"testing"
5+
6+
"github.com/containeroo/dynflags"
7+
"github.com/stretchr/testify/assert"
8+
)
9+
10+
func TestListenAddrSlicesValue(t *testing.T) {
11+
t.Parallel()
12+
13+
t.Run("Parse valid listen address", func(t *testing.T) {
14+
t.Parallel()
15+
16+
value := dynflags.ListenAddrSlicesValue{Bound: &[]string{}}
17+
parsed, err := value.Parse(":8080")
18+
assert.NoError(t, err)
19+
assert.Equal(t, ":8080", parsed)
20+
})
21+
22+
t.Run("Parse invalid listen address", func(t *testing.T) {
23+
t.Parallel()
24+
25+
value := dynflags.ListenAddrSlicesValue{Bound: &[]string{}}
26+
parsed, err := value.Parse("bad-address")
27+
assert.Error(t, err)
28+
assert.Nil(t, parsed)
29+
})
30+
31+
t.Run("Set valid listen address", func(t *testing.T) {
32+
t.Parallel()
33+
34+
bound := []string{":9090"}
35+
value := dynflags.ListenAddrSlicesValue{Bound: &bound}
36+
37+
err := value.Set(":8080")
38+
assert.NoError(t, err)
39+
assert.Equal(t, []string{":9090", ":8080"}, bound)
40+
})
41+
42+
t.Run("Set invalid type", func(t *testing.T) {
43+
t.Parallel()
44+
45+
bound := []string{}
46+
value := dynflags.ListenAddrSlicesValue{Bound: &bound}
47+
48+
err := value.Set(12345)
49+
assert.Error(t, err)
50+
assert.EqualError(t, err, "invalid value type: expected string listen address")
51+
})
52+
}
53+
54+
func TestGroupConfigListenAddrSlices(t *testing.T) {
55+
t.Parallel()
56+
57+
t.Run("Define listen address slices flag", func(t *testing.T) {
58+
t.Parallel()
59+
60+
group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)}
61+
defaultValue := []string{":9090", "127.0.0.1:9091"}
62+
group.ListenAddrSlices("listenSlice", defaultValue, "Multiple listen addresses")
63+
64+
assert.Contains(t, group.Flags, "listenSlice")
65+
assert.Equal(t, "Multiple listen addresses", group.Flags["listenSlice"].Usage)
66+
assert.Equal(t, ":9090,127.0.0.1:9091", group.Flags["listenSlice"].Default)
67+
})
68+
69+
t.Run("Define listen address slices with invalid default", func(t *testing.T) {
70+
t.Parallel()
71+
72+
group := &dynflags.ConfigGroup{Flags: make(map[string]*dynflags.Flag)}
73+
74+
assert.PanicsWithValue(t,
75+
"listenSlice has an invalid default listen address 'bad-address': address bad-address: missing port in address",
76+
func() {
77+
group.ListenAddrSlices("listenSlice", []string{":8080", "bad-address"}, "Invalid default")
78+
})
79+
})
80+
}
81+
82+
func TestGetListenAddrSlices(t *testing.T) {
83+
t.Parallel()
84+
85+
t.Run("Retrieve []string listen address slice", func(t *testing.T) {
86+
t.Parallel()
87+
88+
parsedGroup := &dynflags.ParsedGroup{
89+
Name: "testGroup",
90+
Values: map[string]any{"flag1": []string{":8080", "127.0.0.1:9090"}},
91+
}
92+
93+
result, err := parsedGroup.GetListenAddrSlices("flag1")
94+
assert.NoError(t, err)
95+
assert.Equal(t, []string{":8080", "127.0.0.1:9090"}, result)
96+
})
97+
98+
t.Run("Retrieve single string as []string", func(t *testing.T) {
99+
t.Parallel()
100+
101+
parsedGroup := &dynflags.ParsedGroup{
102+
Name: "testGroup",
103+
Values: map[string]any{"flag1": ":8080"},
104+
}
105+
106+
result, err := parsedGroup.GetListenAddrSlices("flag1")
107+
assert.NoError(t, err)
108+
assert.Equal(t, []string{":8080"}, result)
109+
})
110+
111+
t.Run("Flag not found", func(t *testing.T) {
112+
t.Parallel()
113+
114+
parsedGroup := &dynflags.ParsedGroup{
115+
Name: "testGroup",
116+
Values: map[string]any{},
117+
}
118+
119+
result, err := parsedGroup.GetListenAddrSlices("missingFlag")
120+
assert.Error(t, err)
121+
assert.Nil(t, result)
122+
assert.EqualError(t, err, "flag 'missingFlag' not found in group 'testGroup'")
123+
})
124+
125+
t.Run("Flag value is invalid type", func(t *testing.T) {
126+
t.Parallel()
127+
128+
parsedGroup := &dynflags.ParsedGroup{
129+
Name: "testGroup",
130+
Values: map[string]any{"flag1": 123},
131+
}
132+
133+
result, err := parsedGroup.GetListenAddrSlices("flag1")
134+
assert.Error(t, err)
135+
assert.Nil(t, result)
136+
assert.EqualError(t, err, "flag 'flag1' is not a []string listen address slice")
137+
})
138+
}
139+
140+
func TestListenAddrSlicesGetBound(t *testing.T) {
141+
t.Run("ListenAddrSlicesValue - GetBound", func(t *testing.T) {
142+
val := []string{":8080", "127.0.0.1:9090"}
143+
bound := &val
144+
145+
value := dynflags.ListenAddrSlicesValue{Bound: bound}
146+
assert.Equal(t, val, value.GetBound())
147+
148+
value = dynflags.ListenAddrSlicesValue{Bound: nil}
149+
assert.Nil(t, value.GetBound())
150+
})
151+
}

0 commit comments

Comments
 (0)