diff --git a/.gitignore b/.gitignore index e8cd0ba7..a879595c 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,4 @@ mem.out docs/site docs/__pycache__ docs/.cache +.idea/ diff --git a/go.mod b/go.mod index 800eaf29..6011babf 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/uptrace/bunrouter v1.0.21 golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc golang.org/x/net v0.17.0 + golang.org/x/text v0.14.0 google.golang.org/protobuf v1.30.0 ) @@ -83,7 +84,6 @@ require ( golang.org/x/arch v0.3.0 // indirect golang.org/x/crypto v0.17.0 // indirect golang.org/x/sys v0.15.0 // indirect - golang.org/x/text v0.14.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/registry.go b/registry.go index 813a3de6..dcf62518 100644 --- a/registry.go +++ b/registry.go @@ -4,13 +4,11 @@ import ( "encoding/json" "fmt" "reflect" - "regexp" "strings" -) -// reGenericName helps to convert `MyType[path/to.SubType]` to `MyTypeSubType` -// when using the default schema namer. -var reGenericName = regexp.MustCompile(`\[[^\]]+\]`) + "golang.org/x/text/cases" + "golang.org/x/text/language" +) // Registry creates and stores schemas and their references, and supports // marshalling to JSON/YAML for use as an OpenAPI #/components/schemas object. @@ -34,15 +32,23 @@ type Registry interface { func DefaultSchemaNamer(t reflect.Type, hint string) string { name := deref(t).Name() - // Fix up generics, if used, for nicer refs & URLs. - name = reGenericName.ReplaceAllStringFunc(name, func(s string) string { - // Convert `MyType[path/to.SubType]` to `MyType[SubType]`. - parts := strings.Split(s, ".") - return parts[len(parts)-1] - }) - // Remove square brackets. - name = strings.ReplaceAll(name, "[", "") - name = strings.ReplaceAll(name, "]", "") + // Better support for lists, so e.g. `[]int` becomes `ListInt`. + name = strings.ReplaceAll(name, "[]", "List[") + + result := "" + for _, part := range strings.FieldsFunc(name, func(r rune) bool { + // Split on special characters. Note that `,` is used when there are + // multiple inputs to a generic type. + return r == '[' || r == ']' || r == '*' || r == ',' + }) { + // Split fully qualified names like `github.com/foo/bar.Baz` into `Baz`. + fqn := strings.Split(part, ".") + base := fqn[len(fqn)-1] + + // Add to result, and uppercase for better scalar support (`int` -> `Int`). + result += cases.Title(language.Und, cases.NoLower).String(base) + } + name = result if name == "" { name = hint diff --git a/registry_test.go b/registry_test.go new file mode 100644 index 00000000..38109af0 --- /dev/null +++ b/registry_test.go @@ -0,0 +1,55 @@ +package huma + +import ( + "reflect" + "testing" + "time" + + "github.com/danielgtaylor/huma/v2/examples/protodemo/protodemo" + "github.com/stretchr/testify/assert" +) + +type Output[T any] struct{} + +type Embedded[P any] struct{} + +type EmbeddedTwo[P, V any] struct{} + +type S struct{} + +type ü struct{} + +type MP4 struct{} + +func TestDefaultSchemaNamer(t *testing.T) { + type Renamed Output[*[]Embedded[protodemo.User]] + + for _, example := range []struct { + typ any + name string + }{ + {int(0), "Int"}, + {int64(0), "Int64"}, + {S{}, "S"}, + {time.Time{}, "Time"}, + {Output[int]{}, "OutputInt"}, + {Output[*int]{}, "OutputInt"}, + {Output[[]int]{}, "OutputListInt"}, + {Output[[]*int]{}, "OutputListInt"}, + {Output[[][]int]{}, "OutputListListInt"}, + {Output[map[string]int]{}, "OutputMapStringInt"}, + {Output[map[string][]*int]{}, "OutputMapStringListInt"}, + {Output[S]{}, "OutputS"}, + {Output[ü]{}, "OutputÜ"}, + {Output[MP4]{}, "OutputMP4"}, + {Output[Embedded[*protodemo.User]]{}, "OutputEmbeddedUser"}, + {Output[*[]Embedded[protodemo.User]]{}, "OutputListEmbeddedUser"}, + {Output[EmbeddedTwo[[]protodemo.User, **time.Time]]{}, "OutputEmbeddedTwoListUserTime"}, + {Renamed{}, "Renamed"}, + } { + t.Run(example.name, func(t *testing.T) { + name := DefaultSchemaNamer(reflect.TypeOf(example.typ), "hint") + assert.Equal(t, example.name, name) + }) + } +} diff --git a/schema_test.go b/schema_test.go index 2e6b1128..c5baaa1d 100644 --- a/schema_test.go +++ b/schema_test.go @@ -501,7 +501,7 @@ func TestSchemaGenericNaming(t *testing.T) { b, _ := json.Marshal(s) assert.JSONEq(t, `{ - "$ref": "#/components/schemas/SchemaGenericint" + "$ref": "#/components/schemas/SchemaGenericInt" }`, string(b)) }