Skip to content

Commit

Permalink
AddFlags test
Browse files Browse the repository at this point in the history
  • Loading branch information
unixpickle committed Apr 2, 2024
1 parent 8b1b07a commit 0f154b6
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 5 deletions.
15 changes: 10 additions & 5 deletions toolbox3d/params.go → toolbox3d/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package toolbox3d

import (
"flag"
"fmt"
"reflect"
"strconv"
"unicode"
Expand All @@ -16,20 +17,20 @@ import (
// indicate the default integer or floating point values.
//
// This may panic() if the default for a field is
// incorrectly formatted.
// incorrectly formatted, or if a field is not supported.
func AddFlags(obj any, f *flag.FlagSet) {
if f == nil {
f = flag.CommandLine
}
val := reflect.ValueOf(obj)
val := reflect.ValueOf(obj).Elem()
fields := reflect.VisibleFields(val.Type())
for _, field := range fields {
flagName := flagNameForField(field.Name)
defaultStr := field.Tag.Get("default")
usageStr := field.Tag.Get("usage")
var err error
switch field.Type.Kind() {
case reflect.Int:
switch field.Type {
case reflect.TypeOf(int(0)):
var defaultVal int
if defaultStr != "" {
defaultVal, err = strconv.Atoi(defaultStr)
Expand All @@ -39,7 +40,7 @@ func AddFlags(obj any, f *flag.FlagSet) {
}
f.IntVar(val.FieldByIndex(field.Index).Addr().Interface().(*int),
flagName, defaultVal, usageStr)
case reflect.Float64:
case reflect.TypeOf(float64(0)):
var defaultVal float64
if defaultStr != "" {
defaultVal, err = strconv.ParseFloat(defaultStr, 64)
Expand All @@ -49,6 +50,8 @@ func AddFlags(obj any, f *flag.FlagSet) {
}
f.Float64Var(val.FieldByIndex(field.Index).Addr().Interface().(*float64),
flagName, defaultVal, usageStr)
default:
panic(fmt.Sprintf("unsupported type: %v", field.Type))
}
}
}
Expand All @@ -61,6 +64,8 @@ func flagNameForField(field string) string {
result = append(result, '-')
}
result = append(result, unicode.ToLower(x))
} else {
result = append(result, x)
}
}
return string(result)
Expand Down
45 changes: 45 additions & 0 deletions toolbox3d/flags_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package toolbox3d

import (
"flag"
"testing"
)

type TestAddFlagsSubObj struct {
}

type TestAddFlagsObj struct {
IntField int
OtherIntField int `default:"123"`
FloatField float64 `default:"3.14"`
}

func TestAddFlags(t *testing.T) {
var obj TestAddFlagsObj
fs := flag.NewFlagSet("foo", flag.PanicOnError)
AddFlags(&obj, fs)
fs.Parse([]string{})
if obj.OtherIntField != 123 {
t.Errorf("incorrect OtherIntField: %v", obj.OtherIntField)
}
if obj.IntField != 0 {
t.Errorf("incorrect IntField: %v", obj.IntField)
}
if obj.FloatField != 3.14 {
t.Errorf("incorrect FloatField: %v", obj.FloatField)
}

var obj1 TestAddFlagsObj
fs = flag.NewFlagSet("foo", flag.PanicOnError)
AddFlags(&obj1, fs)
fs.Parse([]string{"-int-field", "4", "-other-int-field", "5", "-float-field", "3.14"})
if obj1.OtherIntField != 5 {
t.Errorf("incorrect OtherIntField: %v", obj.OtherIntField)
}
if obj1.IntField != 4 {
t.Errorf("incorrect IntField: %v", obj.IntField)
}
if obj1.FloatField != 3.14 {
t.Errorf("incorrect FloatField: %v", obj.FloatField)
}
}

0 comments on commit 0f154b6

Please sign in to comment.