diff --git a/toolbox3d/params.go b/toolbox3d/flags.go similarity index 82% rename from toolbox3d/params.go rename to toolbox3d/flags.go index 03a7808..80f337b 100644 --- a/toolbox3d/params.go +++ b/toolbox3d/flags.go @@ -2,6 +2,7 @@ package toolbox3d import ( "flag" + "fmt" "reflect" "strconv" "unicode" @@ -16,20 +17,20 @@ import ( // indicate the default integer or floating point values. // // This may panic() if the default for a field is -// incorrectly formatted. +// incorrectly formatted, or if a field is not supported. func AddFlags(obj any, f *flag.FlagSet) { if f == nil { f = flag.CommandLine } - val := reflect.ValueOf(obj) + val := reflect.ValueOf(obj).Elem() fields := reflect.VisibleFields(val.Type()) for _, field := range fields { flagName := flagNameForField(field.Name) defaultStr := field.Tag.Get("default") usageStr := field.Tag.Get("usage") var err error - switch field.Type.Kind() { - case reflect.Int: + switch field.Type { + case reflect.TypeOf(int(0)): var defaultVal int if defaultStr != "" { defaultVal, err = strconv.Atoi(defaultStr) @@ -39,7 +40,7 @@ func AddFlags(obj any, f *flag.FlagSet) { } f.IntVar(val.FieldByIndex(field.Index).Addr().Interface().(*int), flagName, defaultVal, usageStr) - case reflect.Float64: + case reflect.TypeOf(float64(0)): var defaultVal float64 if defaultStr != "" { defaultVal, err = strconv.ParseFloat(defaultStr, 64) @@ -49,6 +50,8 @@ func AddFlags(obj any, f *flag.FlagSet) { } f.Float64Var(val.FieldByIndex(field.Index).Addr().Interface().(*float64), flagName, defaultVal, usageStr) + default: + panic(fmt.Sprintf("unsupported type: %v", field.Type)) } } } @@ -61,6 +64,8 @@ func flagNameForField(field string) string { result = append(result, '-') } result = append(result, unicode.ToLower(x)) + } else { + result = append(result, x) } } return string(result) diff --git a/toolbox3d/flags_test.go b/toolbox3d/flags_test.go new file mode 100644 index 0000000..1ceafe1 --- /dev/null +++ b/toolbox3d/flags_test.go @@ -0,0 +1,45 @@ +package toolbox3d + +import ( + "flag" + "testing" +) + +type TestAddFlagsSubObj struct { +} + +type TestAddFlagsObj struct { + IntField int + OtherIntField int `default:"123"` + FloatField float64 `default:"3.14"` +} + +func TestAddFlags(t *testing.T) { + var obj TestAddFlagsObj + fs := flag.NewFlagSet("foo", flag.PanicOnError) + AddFlags(&obj, fs) + fs.Parse([]string{}) + if obj.OtherIntField != 123 { + t.Errorf("incorrect OtherIntField: %v", obj.OtherIntField) + } + if obj.IntField != 0 { + t.Errorf("incorrect IntField: %v", obj.IntField) + } + if obj.FloatField != 3.14 { + t.Errorf("incorrect FloatField: %v", obj.FloatField) + } + + var obj1 TestAddFlagsObj + fs = flag.NewFlagSet("foo", flag.PanicOnError) + AddFlags(&obj1, fs) + fs.Parse([]string{"-int-field", "4", "-other-int-field", "5", "-float-field", "3.14"}) + if obj1.OtherIntField != 5 { + t.Errorf("incorrect OtherIntField: %v", obj.OtherIntField) + } + if obj1.IntField != 4 { + t.Errorf("incorrect IntField: %v", obj.IntField) + } + if obj1.FloatField != 3.14 { + t.Errorf("incorrect FloatField: %v", obj.FloatField) + } +}