Skip to content

Commit

Permalink
Merge pull request #705 from b-kamphorst/refactor-huma-register
Browse files Browse the repository at this point in the history
refactor: split up huma.Register
  • Loading branch information
danielgtaylor authored Jan 16, 2025
2 parents fe78329 + 9d5259b commit 1c3924e
Show file tree
Hide file tree
Showing 3 changed files with 929 additions and 772 deletions.
133 changes: 66 additions & 67 deletions formdata.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,30 +85,51 @@ func (v MimeTypeValidator) Validate(fh *multipart.FileHeader, location string) (
}
}

func (m *MultipartFormFiles[T]) readFile(
fh *multipart.FileHeader,
location string,
validator MimeTypeValidator,
) (FormFile, *ErrorDetail) {
f, err := fh.Open()
if err != nil {
return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location}
}
contentType, validationErr := validator.Validate(fh, location)
if validationErr != nil {
return FormFile{}, validationErr
func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

// Decodes multipart.Form data into *T, returning []*ErrorDetail if any
// Schema is used to check for validation constraints
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
var (
dataType = reflect.TypeOf(m.data).Elem()
value = reflect.New(dataType)
errors []error
)
for i := 0; i < dataType.NumField(); i++ {
field := value.Elem().Field(i)
structField := dataType.Field(i)
key := structField.Tag.Get("form")
if key == "" {
key = structField.Name
}
fileHeaders := m.Form.File[key]
switch {
case field.Type() == reflect.TypeOf(FormFile{}):
file, err := readSingleFile(fileHeaders, key, opMediaType)
if err != nil {
errors = append(errors, err)
continue
}
field.Set(reflect.ValueOf(file))
case field.Type() == reflect.TypeOf([]FormFile{}):
files, errs := readMultipleFiles(fileHeaders, key, opMediaType)
if errs != nil {
errors = append(errors, errs...)
continue
}
field.Set(reflect.ValueOf(files))

default:
continue
}
}
return FormFile{
File: f,
ContentType: contentType,
IsSet: true,
Size: fh.Size,
Filename: fh.Filename,
}, nil
m.data = value.Interface().(*T)
return errors
}

func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaType) (FormFile, *ErrorDetail) {
fileHeaders := m.Form.File[key]
func readSingleFile(fileHeaders []*multipart.FileHeader, key string, opMediaType *MediaType) (FormFile, *ErrorDetail) {
if len(fileHeaders) == 0 {
if opMediaType.Schema.requiredMap[key] {
return FormFile{}, &ErrorDetail{Message: "File required", Location: key}
Expand All @@ -117,16 +138,15 @@ func (m *MultipartFormFiles[T]) readSingleFile(key string, opMediaType *MediaTyp
}
} else if len(fileHeaders) == 1 {
validator := NewMimeTypeValidator(opMediaType.Encoding[key])
return m.readFile(fileHeaders[0], key, validator)
return readFile(fileHeaders[0], key, validator)
}
return FormFile{}, &ErrorDetail{
Message: "Multiple files received but only one was expected",
Location: key,
}
}

func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *MediaType) ([]FormFile, []error) {
fileHeaders := m.Form.File[key]
func readMultipleFiles(fileHeaders []*multipart.FileHeader, key string, opMediaType *MediaType) ([]FormFile, []error) {
var (
files = make([]FormFile, len(fileHeaders))
errors []error
Expand All @@ -136,7 +156,7 @@ func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *Media
}
validator := NewMimeTypeValidator(opMediaType.Encoding[key])
for i, fh := range fileHeaders {
file, err := m.readFile(
file, err := readFile(
fh,
fmt.Sprintf("%s[%d]", key, i),
validator,
Expand All @@ -150,47 +170,26 @@ func (m *MultipartFormFiles[T]) readMultipleFiles(key string, opMediaType *Media
return files, errors
}

func (m *MultipartFormFiles[T]) Data() *T {
return m.data
}

// Decodes multipart.Form data into *T, returning []*ErrorDetail if any
// Schema is used to check for validation constraints
func (m *MultipartFormFiles[T]) Decode(opMediaType *MediaType) []error {
var (
dataType = reflect.TypeOf(m.data).Elem()
value = reflect.New(dataType)
errors []error
)
for i := 0; i < dataType.NumField(); i++ {
field := value.Elem().Field(i)
structField := dataType.Field(i)
key := structField.Tag.Get("form")
if key == "" {
key = structField.Name
}
switch {
case field.Type() == reflect.TypeOf(FormFile{}):
file, err := m.readSingleFile(key, opMediaType)
if err != nil {
errors = append(errors, err)
continue
}
field.Set(reflect.ValueOf(file))
case field.Type() == reflect.TypeOf([]FormFile{}):
files, errs := m.readMultipleFiles(key, opMediaType)
if errs != nil {
errors = append(errors, errs...)
continue
}
field.Set(reflect.ValueOf(files))

default:
continue
}
func readFile(
fh *multipart.FileHeader,
location string,
validator MimeTypeValidator,
) (FormFile, *ErrorDetail) {
f, err := fh.Open()
if err != nil {
return FormFile{}, &ErrorDetail{Message: "Failed to open file", Location: location}
}
m.data = value.Interface().(*T)
return errors
contentType, validationErr := validator.Validate(fh, location)
if validationErr != nil {
return FormFile{}, validationErr
}
return FormFile{
File: f,
ContentType: contentType,
IsSet: true,
Size: fh.Size,
Filename: fh.Filename,
}, nil
}

func formDataFieldName(f reflect.StructField) string {
Expand All @@ -208,7 +207,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
Properties: make(map[string]*Schema, nFields),
requiredMap: make(map[string]bool, nFields),
}
requiredFields := make([]string, nFields)
requiredFields := make([]string, 0, nFields)
for i := 0; i < nFields; i++ {
f := t.Field(i)
name := formDataFieldName(f)
Expand All @@ -227,7 +226,7 @@ func multiPartFormFileSchema(t reflect.Type) *Schema {
}

if _, ok := f.Tag.Lookup("required"); ok && boolTag(f, "required", false) {
requiredFields[i] = name
requiredFields = append(requiredFields, name)
schema.requiredMap[name] = true
}
}
Expand Down
Loading

0 comments on commit 1c3924e

Please sign in to comment.