Skip to content

Commit

Permalink
improve xml encode and decode
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac committed Nov 24, 2024
1 parent 055848f commit 01c8a4d
Show file tree
Hide file tree
Showing 9 changed files with 565 additions and 41 deletions.
4 changes: 4 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ linters-settings:
# Default: 5
min-complexity: 10

gocritic:
disabled-checks:
- appendAssign

issues:
exclude-files:
- ".*_test\\.go$"
18 changes: 15 additions & 3 deletions connector/internal/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -231,9 +231,21 @@ func (client *HTTPClient) sendSingle(ctx context.Context, request *RetryableRequ
contentType := parseContentType(resp.Header.Get(contentTypeHeader))
if resp.StatusCode >= 400 {
details := make(map[string]any)
if contentType == rest.ContentTypeJSON && json.Valid(errorBytes) {
details["error"] = json.RawMessage(errorBytes)
} else {
switch contentType {
case rest.ContentTypeJSON:
if json.Valid(errorBytes) {
details["error"] = json.RawMessage(errorBytes)
} else {
details["error"] = string(errorBytes)
}
case rest.ContentTypeXML:
errData, err := decodeArbitraryXML(bytes.NewReader(errorBytes))
if err != nil {
details["error"] = string(errorBytes)
} else {
details["error"] = errData
}
default:
details["error"] = string(errorBytes)
}

Expand Down
18 changes: 18 additions & 0 deletions connector/internal/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,21 @@ func getXMLName(xmlSchema *rest.XMLSchema, defaultName string) string {

return defaultName
}

func getArrayOrNamedType(schemaType schema.Type) (*schema.ArrayType, *schema.NamedType, error) {
rawType, err := schemaType.InterfaceT()
if err != nil {
return nil, nil, err
}

switch t := rawType.(type) {
case *schema.NullableType:
return getArrayOrNamedType(t.UnderlyingType)
case *schema.ArrayType:
return t, nil, nil
case *schema.NamedType:
return nil, t, nil
default:
return nil, nil, nil
}
}
177 changes: 139 additions & 38 deletions connector/internal/xml_decode.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (c *XMLDecoder) Decode(r io.Reader, resultType schema.Type) (any, error) {

if se, ok := token.(xml.StartElement); ok {
xmlTree := createXMLBlock(se)
if err := c.evalXMLTree(xmlTree); err != nil {
if err := evalXMLTree(c.decoder, xmlTree); err != nil {
return nil, fmt.Errorf("failed to decode the xml result: %w", err)
}

Expand Down Expand Up @@ -84,6 +84,19 @@ func (c *XMLDecoder) evalXMLField(block *xmlBlock, fieldName string, field rest.
}
}

func (c *XMLDecoder) getArrayItemObjectField(field rest.ObjectField, t *schema.ArrayType) rest.ObjectField {
fieldItem := rest.ObjectField{
ObjectField: schema.ObjectField{
Type: t.ElementType,
},
}

if field.HTTP != nil && field.HTTP.Items != nil {
fieldItem.HTTP = field.HTTP.Items
}

return fieldItem
}
func (c *XMLDecoder) evalArrayField(block *xmlBlock, fieldName string, field rest.ObjectField, t *schema.ArrayType, fieldPaths []string) (any, error) {
if block.Fields == nil {
return nil, nil
Expand All @@ -95,19 +108,12 @@ func (c *XMLDecoder) evalArrayField(block *xmlBlock, fieldName string, field res
var elements []xmlBlock
itemTokenName := fieldName
wrapped := len(fieldPaths) == 0
fieldItem := rest.ObjectField{
ObjectField: schema.ObjectField{
Type: t.ElementType,
},
}
fieldItem := c.getArrayItemObjectField(field, t)

if field.HTTP != nil {
wrapped = wrapped || (field.HTTP.XML != nil && field.HTTP.XML.Wrapped)
if field.HTTP.Items != nil {
fieldItem.HTTP = field.HTTP.Items
if field.HTTP.Items.XML != nil && field.HTTP.Items.XML.Name != "" {
itemTokenName = field.HTTP.Items.XML.Name
}
if field.HTTP.Items != nil && field.HTTP.Items.XML != nil && field.HTTP.Items.XML.Name != "" {
itemTokenName = field.HTTP.Items.XML.Name
}
}

Expand All @@ -123,6 +129,10 @@ func (c *XMLDecoder) evalArrayField(block *xmlBlock, fieldName string, field res
elements = elems
}

return c.evalArrayElements(elements, itemTokenName, fieldItem, fieldPaths)
}

func (c *XMLDecoder) evalArrayElements(elements []xmlBlock, itemTokenName string, fieldItem rest.ObjectField, fieldPaths []string) ([]any, error) {
if len(elements) == 0 {
return []any{}, nil
}
Expand All @@ -141,7 +151,7 @@ func (c *XMLDecoder) evalArrayField(block *xmlBlock, fieldName string, field res

func (c *XMLDecoder) evalNamedField(block *xmlBlock, t *schema.NamedType, fieldPaths []string) (any, error) {
if scalarType, ok := c.schema.ScalarTypes[t.Name]; ok {
return c.decodeSimpleScalarValue(block.Data, scalarType, fieldPaths)
return c.decodeSimpleScalarValue(block, scalarType, fieldPaths)
}

objectType, ok := c.schema.ObjectTypes[t.Name]
Expand Down Expand Up @@ -178,7 +188,7 @@ func (c *XMLDecoder) evalNamedField(block *xmlBlock, t *schema.NamedType, fieldP

_, textFieldName, isLeaf := findXMLLeafObjectField(objectType)
if isLeaf {
textValue, err := c.decodeSimpleScalarValue(block.Data, c.schema.ScalarTypes[string(rest.ScalarString)], fieldPaths)
textValue, err := c.decodeSimpleScalarValue(block, c.schema.ScalarTypes[string(rest.ScalarString)], fieldPaths)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -210,12 +220,40 @@ func (c *XMLDecoder) evalNamedField(block *xmlBlock, t *schema.NamedType, fieldP
case 0:
result[key] = []any{}
case 1:
fieldResult, err := c.evalXMLField(&fieldElems[0], xmlKey, objectField, append(fieldPaths, key))
propPaths := append(fieldPaths, key)
if objectField.HTTP.XML != nil && objectField.HTTP.XML.Wrapped {
// this can be a wrapped array
fieldResult, err := c.evalXMLField(&fieldElems[0], xmlKey, objectField, propPaths)
if err != nil {
return nil, err
}

result[key] = fieldResult

continue
}

at, nt, err := getArrayOrNamedType(objectField.Type)
if err != nil {
return nil, err
return nil, fmt.Errorf("%s: %w", strings.Join(propPaths, "."), err)
}

result[key] = fieldResult
if at != nil {
fieldItem := c.getArrayItemObjectField(objectField, at)
fieldResult, err := c.evalArrayElements(fieldElems, xmlKey, fieldItem, propPaths)
if err != nil {
return nil, err
}

result[key] = fieldResult
} else if nt != nil {
fieldResult, err := c.evalNamedField(&fieldElems[0], nt, propPaths)
if err != nil {
return nil, err
}

result[key] = fieldResult
}
default:
fieldResult, err := c.evalXMLField(&xmlBlock{
Start: fieldElems[0].Start,
Expand Down Expand Up @@ -252,7 +290,9 @@ func (c *XMLDecoder) evalAttribute(schemaType schema.Type, attr xml.Attr, fieldP
return result, nil
case *schema.NamedType:
if scalarType, ok := c.schema.ScalarTypes[t.Name]; ok {
return c.decodeSimpleScalarValue(attr.Value, scalarType, fieldPaths)
return c.decodeSimpleScalarValue(&xmlBlock{
Data: attr.Value,
}, scalarType, fieldPaths)
}

var result any
Expand All @@ -266,49 +306,45 @@ func (c *XMLDecoder) evalAttribute(schemaType schema.Type, attr xml.Attr, fieldP
}
}

func (c *XMLDecoder) decodeSimpleScalarValue(token string, scalarType schema.ScalarType, fieldPaths []string) (any, error) {
func (c *XMLDecoder) decodeSimpleScalarValue(block *xmlBlock, scalarType schema.ScalarType, fieldPaths []string) (any, error) {
respType, err := scalarType.Representation.InterfaceT()

var result any = nil
switch respType.(type) {
case *schema.TypeRepresentationString:
result = token
result = block.Data
case *schema.TypeRepresentationDate, *schema.TypeRepresentationTimestamp, *schema.TypeRepresentationTimestampTZ, *schema.TypeRepresentationUUID, *schema.TypeRepresentationEnum:
if len(token) > 0 {
result = token
if len(block.Data) > 0 {
result = block.Data
}
case *schema.TypeRepresentationBytes:
result = token
result = block.Data
case *schema.TypeRepresentationBoolean:
if len(token) == 0 {
if len(block.Data) == 0 {
break
}

result, err = strconv.ParseBool(token)
result, err = strconv.ParseBool(block.Data)
case *schema.TypeRepresentationBigDecimal, *schema.TypeRepresentationBigInteger:
if len(token) == 0 {
if len(block.Data) == 0 {
break
}

result = token
result = block.Data
case *schema.TypeRepresentationInteger, *schema.TypeRepresentationInt8, *schema.TypeRepresentationInt16, *schema.TypeRepresentationInt32, *schema.TypeRepresentationInt64: //nolint:all
if len(token) == 0 {
if len(block.Data) == 0 {
break
}

result, err = strconv.ParseInt(token, 10, 64)
result, err = strconv.ParseInt(block.Data, 10, 64)
case *schema.TypeRepresentationNumber, *schema.TypeRepresentationFloat32, *schema.TypeRepresentationFloat64: //nolint:all
if len(token) == 0 {
if len(block.Data) == 0 {
break
}

result, err = strconv.ParseFloat(token, 64)
result, err = strconv.ParseFloat(block.Data, 64)
case *schema.TypeRepresentationGeography, *schema.TypeRepresentationGeometry, *schema.TypeRepresentationJSON:
if len(token) == 0 {
break
}

result = token
result = decodeArbitraryXMLBlock(block)
}

if err != nil {
Expand All @@ -318,6 +354,44 @@ func (c *XMLDecoder) decodeSimpleScalarValue(token string, scalarType schema.Sca
return result, nil
}

func decodeArbitraryXMLBlock(block *xmlBlock) any {
if len(block.Start.Attr) == 0 && len(block.Fields) == 0 {
return block.Data
}

result := make(map[string]any)
if len(block.Start.Attr) > 0 {
attributes := make(map[string]string)
for _, attr := range block.Start.Attr {
attributes[attr.Name.Local] = attr.Value
}
result["attributes"] = attributes
}

if len(block.Fields) == 0 {
result["content"] = block.Data

return result
}

for key, field := range block.Fields {
switch len(field) {
case 0:
case 1:
// limitation: we can't know if the array is wrapped
result[key] = decodeArbitraryXMLBlock(&field[0])
default:
items := make([]any, len(field))
for i, f := range field {
items[i] = decodeArbitraryXMLBlock(&f)
}
result[key] = items
}
}

return result
}

type xmlBlock struct {
Start xml.StartElement
Data string
Expand All @@ -331,10 +405,10 @@ func createXMLBlock(start xml.StartElement) *xmlBlock {
}
}

func (c *XMLDecoder) evalXMLTree(block *xmlBlock) error {
func evalXMLTree(decoder *xml.Decoder, block *xmlBlock) error {
L:
for {
nextToken, err := c.decoder.Token()
nextToken, err := decoder.Token()
if err != nil {
return err
}
Expand All @@ -346,7 +420,7 @@ L:
switch tok := nextToken.(type) {
case xml.StartElement:
childBlock := createXMLBlock(tok)
if err := c.evalXMLTree(childBlock); err != nil {
if err := evalXMLTree(decoder, childBlock); err != nil {
return err
}
block.Fields[tok.Name.Local] = append(block.Fields[tok.Name.Local], *childBlock)
Expand All @@ -359,3 +433,30 @@ L:

return nil
}

func decodeArbitraryXML(r io.Reader) (any, error) {
decoder := xml.NewDecoder(r)

for {
token, err := decoder.Token()
if err != nil {
return nil, err
}
if token == nil {
break
}

if se, ok := token.(xml.StartElement); ok {
xmlTree := createXMLBlock(se)
if err := evalXMLTree(decoder, xmlTree); err != nil {
return nil, fmt.Errorf("failed to decode the xml result: %w", err)
}

result := decodeArbitraryXMLBlock(xmlTree)

return result, nil
}
}

return nil, nil
}
Loading

0 comments on commit 01c8a4d

Please sign in to comment.