Skip to content

Commit

Permalink
A generic mechanism for supporting new types (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
hujun-open authored May 10, 2023
1 parent 9de0c49 commit 4ff212d
Show file tree
Hide file tree
Showing 9 changed files with 147 additions and 242 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name: Test

on:
workflow_dispatch:
push:
branches: [ master ]
pull_request:
Expand All @@ -16,7 +17,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v4
with:
go-version: '1.16'
go-version: '1.19'

- name: Test
run: go test -v ./...
5 changes: 5 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,12 @@ import "github.com/itzg/go-flagsfiller"
- Beyond the standard types supported by flag.FlagSet also includes support for:
- `[]string` where repetition of the argument appends to the slice and/or an argument value can contain a comma-separated list of values. For example: `--arg one --arg two,three`
- `map[string]string` where each entry is a `key=value` and/or repetition of the arguments adds to the map or multiple entries can be comma-separated in a single argument value. For example: `--arg k1=v1 --arg k2=v2,k3=v3`
- `time.Time` parse via time.Parse(), with tag `layout` specify the layout string, default is "2006-01-02 15:04:05"
- `net.IP` parse via net.ParseIP()
- `net.IPNet` parse via net.ParseCIDR()
- `net.HardwareAddr` parse via net.ParseMAC()
- Optionally set flag values from environment variables. Similar to flag names, environment variable names are derived automatically from the field names
- New types could be supported via user code, via `RegisterSimpleType(ConvertFunc)`, check [time.go](time.go) and [net.go](net.go) to see how it works

## Quick example

Expand Down
26 changes: 7 additions & 19 deletions flagset.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,8 @@ func (f *FlagSetFiller) Fill(flagSet *flag.FlagSet, from interface{}) error {
}
}

// this is a list of supported struct, like time.Time, that walkFields() won't walk into,
// the key is the is string returned by the getTypeName(<struct_type>),
// each supported struct need to be added in this map in init()
var supportedStructList = make(map[string]struct{})

func isSupportedStruct(name string) bool {
_, ok := supportedStructList[name]
_, ok := extendedTypes[name]
return ok
}

Expand Down Expand Up @@ -181,20 +176,13 @@ func (f *FlagSetFiller) processField(flagSet *flag.FlagSet, fieldRef interface{}
renamed = f.options.renameLongName(name)
}
typeName := getTypeName(t)
switch {
//check the typeName
case typeName == "net.IP":
f.processIP(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)
case typeName == "net.IPNet":
f.processIPNet(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)
case typeName == "net.HardwareAddr":
f.processMAC(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)

case typeName == "time.Time":
layoutStr, _ := tag.Lookup("layout")
f.processTime(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases, layoutStr)
//end of check typeName

// go through all supported structs
if handler, ok := extendedTypes[typeName]; ok {
err = handler(tag, fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)
}

switch {
case t.Kind() == reflect.String:
f.processString(fieldRef, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)

Expand Down
51 changes: 51 additions & 0 deletions general.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package flagsfiller

/*
The code in this file could be opened up in future if more complex implementation is needed
*/

import (
"flag"
"fmt"
"reflect"
"strings"
)

// this is a list of addtional supported types(include struct), like time.Time, that walkFields() won't walk into,
// the key is the is string returned by the getTypeName(<type>),
// each supported type need to be added in this map in init()
var extendedTypes = make(map[string]handlerFunc)

type handlerFunc func(tag reflect.StructTag, fieldRef interface{},
hasDefaultTag bool, tagDefault string,
flagSet *flag.FlagSet, renamed string,
usage string, aliases string) error

type flagVal[T any] interface {
flag.Value
StrConverter(string) (T, error)
SetRef(*T)
}

func processGeneral[T any](fieldRef interface{}, val flagVal[T],
hasDefaultTag bool, tagDefault string,
flagSet *flag.FlagSet, renamed string,
usage string, aliases string) (err error) {

casted := fieldRef.(*T)
if hasDefaultTag {
*casted, err = val.StrConverter(tagDefault)
if err != nil {
return fmt.Errorf("failed to parse default into %T: %w", *new(T), err)
}
}
val.SetRef(casted)
flagSet.Var(val, renamed, usage)
if aliases != "" {
for _, alias := range strings.Split(aliases, ",") {
flagSet.Var(val, alias, usage)
}
}
return nil

}
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/itzg/go-flagsfiller

go 1.16
go 1.19

require (
github.com/iancoleman/strcase v0.2.0
Expand Down
57 changes: 0 additions & 57 deletions mac.go

This file was deleted.

120 changes: 15 additions & 105 deletions net.go
Original file line number Diff line number Diff line change
@@ -1,123 +1,33 @@
package flagsfiller

import (
"flag"
"fmt"
"net"
"strings"
"reflect"
)

func init() {
supportedStructList["net.IPNet"] = struct{}{}
RegisterSimpleType(ipConverter)
RegisterSimpleType(ipnetConverter)
RegisterSimpleType(macConverter)
}

type ipValue struct {
addr *net.IP
}

func (v *ipValue) String() string {
if v.addr == nil {
return fmt.Sprint(nil)
}
return v.addr.String()
}

func (v *ipValue) Set(s string) error {
*v.addr = net.ParseIP(s)
if *v.addr == nil {
return fmt.Errorf("invalid ip addr %v", s)
}
return nil
}

func (f *FlagSetFiller) processIP(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) {
casted, ok := fieldRef.(*net.IP)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
value := net.ParseIP(s)
if value == nil {
return nil, fmt.Errorf("invalid IP address %s", s)
}
return value, nil
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
aliases,
)
}

if hasDefaultTag {
*casted = net.ParseIP(tagDefault)
if *casted == nil {
return fmt.Errorf("failed to parse default into net.IP: %s", tagDefault)
}
func ipConverter(s string, tag reflect.StructTag) (net.IP, error) {
addr := net.ParseIP(s)
if addr == nil {
return nil, fmt.Errorf("%s is not a valid IP address", s)
}
flagSet.Var(&ipValue{casted}, renamed, usage)
if aliases != "" {
for _, alias := range strings.Split(aliases, ",") {
flagSet.Var(&ipValue{casted}, alias, usage)
}
}
return nil
}

type ipnetValue struct {
prefix *net.IPNet
}

func (v *ipnetValue) String() string {
if v.prefix == nil {
return fmt.Sprint(nil)
}
return v.prefix.String()
return addr, nil
}

func (v *ipnetValue) Set(s string) error {
_, pr, err := net.ParseCIDR(s)
func ipnetConverter(s string, tag reflect.StructTag) (net.IPNet, error) {
_, prefix, err := net.ParseCIDR(s)
if err != nil {
return fmt.Errorf("invalid ip prefix %v", s)
return net.IPNet{}, err
}
*v.prefix = *pr
return nil
return *prefix, nil
}

func (f *FlagSetFiller) processIPNet(fieldRef interface{}, hasDefaultTag bool, tagDefault string, flagSet *flag.FlagSet, renamed string, usage string, aliases string) (err error) {
casted, ok := fieldRef.(*net.IPNet)
if !ok {
return f.processCustom(
fieldRef,
func(s string) (interface{}, error) {
_, value, err := net.ParseCIDR(s)
if err != nil {
return nil, fmt.Errorf("invalid IP prefix %s, %w", s, err)
}
return *value, nil
},
hasDefaultTag,
tagDefault,
flagSet,
renamed,
usage,
aliases,
)
}

if hasDefaultTag {
_, casted, err = net.ParseCIDR(tagDefault)
if err != nil {
return fmt.Errorf("failed to parse default into net.IPNet: %s, %w", tagDefault, err)
}
}
flagSet.Var(&ipnetValue{casted}, renamed, usage)
if aliases != "" {
for _, alias := range strings.Split(aliases, ",") {
flagSet.Var(&ipnetValue{casted}, alias, usage)
}
}
return nil
func macConverter(s string, tag reflect.StructTag) (net.HardwareAddr, error) {
return net.ParseMAC(s)
}
61 changes: 61 additions & 0 deletions simple.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package flagsfiller

import (
"flag"
"fmt"
"reflect"
)

// RegisterSimpleType register a new type,
// should be called in init(),
// see time.go and net.go for implementation examples
func RegisterSimpleType[T any](c ConvertFunc[T]) {
base := simpleType[T]{converter: c}
extendedTypes[getTypeName(reflect.TypeOf(*new(T)))] = base.Process
}

// ConvertFunc is a function convert string s into a specific type T, the tag is the struct field tag, as addtional input.
// see time.go and net.go for implementation examples
type ConvertFunc[T any] func(s string, tag reflect.StructTag) (T, error)

type simpleType[T any] struct {
val *T
tags reflect.StructTag
converter ConvertFunc[T]
}

func newSimpleType[T any](c ConvertFunc[T], tag reflect.StructTag) simpleType[T] {
return simpleType[T]{val: new(T), converter: c, tags: tag}
}

func (v *simpleType[T]) String() string {
if v.val == nil {
return fmt.Sprint(nil)
}
return fmt.Sprintf("%v", *v.val)
}

func (v *simpleType[T]) StrConverter(s string) (T, error) {
return v.converter(s, v.tags)
}

func (v *simpleType[T]) Set(s string) error {
var err error
*v.val, err = v.converter(s, v.tags)
if err != nil {
return fmt.Errorf("failed to parse %s into %T, %w", s, *(new(T)), err)
}
return nil
}

func (v *simpleType[T]) SetRef(t *T) {
v.val = t
}

func (v *simpleType[T]) Process(tag reflect.StructTag, fieldRef interface{},
hasDefaultTag bool, tagDefault string,
flagSet *flag.FlagSet, renamed string,
usage string, aliases string) error {
val := newSimpleType(v.converter, tag)
return processGeneral[T](fieldRef, &val, hasDefaultTag, tagDefault, flagSet, renamed, usage, aliases)
}
Loading

0 comments on commit 4ff212d

Please sign in to comment.