Skip to content

Commit

Permalink
Add rest-xml handling and better Path matching
Browse files Browse the repository at this point in the history
  • Loading branch information
iann0036 committed Mar 18, 2021
1 parent b154fd2 commit 908cd98
Showing 1 changed file with 167 additions and 5 deletions.
172 changes: 167 additions & 5 deletions proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"embed"
"encoding/json"
"encoding/pem"
"encoding/xml"
"fmt"
"io"
"io/ioutil"
Expand Down Expand Up @@ -271,13 +272,18 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
host := req.Host
uri := req.RequestURI

var endpointUriPrefix string

var serviceDef ServiceDefinition
hostSplit := strings.Split(host, ".")
if hostSplit[len(hostSplit)-1] == "com" && hostSplit[len(hostSplit)-2] == "amazonaws" {
endpointPrefix := hostSplit[len(hostSplit)-3]
if len(hostSplit) > 3 {
endpointPrefix = hostSplit[len(hostSplit)-4]
}
if len(hostSplit) > 4 {
endpointUriPrefix = strings.Join(hostSplit[:len(hostSplit)-4], ".")
}
for _, serviceDefinition := range serviceDefinitions {
if serviceDefinition.Metadata.EndpointPrefix == endpointPrefix { // TODO: Ensure latest version
serviceDef = serviceDefinition
Expand All @@ -300,14 +306,61 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
vals := urlobj.Query()

// path part
longestPath := ""

OperationLoop:
for operationName, operation := range serviceDef.Operations {
templateMatches := regexp.MustCompile(`{([^/]+?)}`).FindAllStringSubmatch(operation.Http.RequestURI, -1)
regexStr := fmt.Sprintf("^%s$", regexp.MustCompile(`{([^/]+?)}`).ReplaceAllString(operation.Http.RequestURI, "([^/]+)"))
pathMatchSuccess := regexp.MustCompile(regexStr).Match([]byte(urlobj.Path))
path := urlobj.Path
if operation.Http.RequestURI == "" || operation.Http.RequestURI[0] != '/' {
operation.Http.RequestURI = "/" + operation.Http.RequestURI
}

if strings.Contains(operation.Http.RequestURI, "?") {
path += "?"

operationurlobj, err := url.ParseRequestURI(operation.Http.RequestURI)
if err != nil {
continue
}

operationquery := operationurlobj.Query()
for operationquerykey, operationqueryvalue := range operationquery {
if _, ok := vals[operationquerykey]; ok {
if operationqueryvalue[0] == "" {
path += operationquerykey + "&"
} else if len(vals[operationquerykey]) > 0 {
path += operationquerykey + "=" + vals[operationquerykey][0] + "&"
} else {
continue OperationLoop
}
} else {
continue OperationLoop
}
}

if path[len(path)-1] == '&' {
path = path[:len(path)-1]
}
}

templateMatches := regexp.MustCompile(`{([^}]+?)\+?}`).FindAllStringSubmatch(operation.Http.RequestURI, -1)
regexStr := regexp.MustCompile(`\\{([^}]+?\\\+)\\}`).ReplaceAllString(regexp.QuoteMeta(operation.Http.RequestURI), `([^?]+)`) // {Key+}
regexStr = fmt.Sprintf("^%s$", regexp.MustCompile(`\\{(.+?)\\}`).ReplaceAllString(regexStr, `([^/?]+?)`)) // {Bucket}
pathMatchSuccess := regexp.MustCompile(regexStr).Match([]byte(path))

if operation.Http.Method == "" {
operation.Http.Method = "POST"
}

if operation.Http.Method == req.Method && pathMatchSuccess {
if len(path) > len(longestPath) {
longestPath = path
} else {
continue
}

action = operationName
pathMatches := regexp.MustCompile(regexStr).FindAllStringSubmatch(urlobj.Path, -1)
pathMatches := regexp.MustCompile(regexStr).FindAllStringSubmatch(path, -1)

if len(pathMatches) > 0 && len(pathMatches) > 0 && len(templateMatches) == len(pathMatches[0])-1 {
for i := 0; i < len(templateMatches); i++ {
Expand Down Expand Up @@ -391,13 +444,122 @@ func handleAWSRequest(req *http.Request, body []byte, respCode int) {
}
}
}
} else if serviceDef.Metadata.Protocol == "rest-xml" {
// URL param schema
urlobj, err := url.ParseRequestURI(uri)
if err != nil {
return
}
vals := urlobj.Query()

// path part
longestPath := ""

OperationLoop2:
for operationName, operation := range serviceDef.Operations {
path := urlobj.Path
if serviceDef.Metadata.EndpointPrefix == "s3" && strings.HasPrefix(operation.Http.RequestURI, "/{Bucket}") && endpointUriPrefix != "" { // https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html#VirtualHostingSpecifyBucket
if len(urlobj.Path) > 1 {
path = "/" + endpointUriPrefix + "/" + urlobj.Path[1:]
} else {
path = "/" + endpointUriPrefix
}
}
if operation.Http.RequestURI == "" || operation.Http.RequestURI[0] != '/' {
operation.Http.RequestURI = "/" + operation.Http.RequestURI
}

if strings.Contains(operation.Http.RequestURI, "?") {
path += "?"

operationurlobj, err := url.ParseRequestURI(operation.Http.RequestURI)
if err != nil {
continue
}

operationquery := operationurlobj.Query()
for operationquerykey, operationqueryvalue := range operationquery {
if _, ok := vals[operationquerykey]; ok {
if operationqueryvalue[0] == "" {
path += operationquerykey + "&"
} else if len(vals[operationquerykey]) > 0 {
path += operationquerykey + "=" + vals[operationquerykey][0] + "&"
} else {
continue OperationLoop2
}
} else {
continue OperationLoop2
}
}

if path[len(path)-1] == '&' {
path = path[:len(path)-1]
}
}

templateMatches := regexp.MustCompile(`{([^}]+?)\+?}`).FindAllStringSubmatch(operation.Http.RequestURI, -1)
regexStr := regexp.MustCompile(`\\{([^}]+?\\\+)\\}`).ReplaceAllString(regexp.QuoteMeta(operation.Http.RequestURI), `([^?]+)`) // {Key+}
regexStr = fmt.Sprintf("^%s$", regexp.MustCompile(`\\{(.+?)\\}`).ReplaceAllString(regexStr, `([^/?]+?)`)) // {Bucket}
pathMatchSuccess := regexp.MustCompile(regexStr).Match([]byte(path))

if operation.Http.Method == "" {
operation.Http.Method = "POST"
}

if operation.Http.Method == req.Method && pathMatchSuccess {
if len(path) > len(longestPath) {
longestPath = path
} else {
continue
}

action = operationName
pathMatches := regexp.MustCompile(regexStr).FindAllStringSubmatch(path, -1)

if len(pathMatches) > 0 && len(pathMatches) > 0 && len(templateMatches) == len(pathMatches[0])-1 {
for i := 0; i < len(templateMatches); i++ {
uriparams[templateMatches[i][1]] = pathMatches[0][1:][i]
}
}
}
}

// query part
for k, v := range vals {
normalizedK := regexp.MustCompile(`\.member\.[0-9]+`).ReplaceAllString(k, "[]")
normalizedK = regexp.MustCompile(`\.[0-9]+`).ReplaceAllString(normalizedK, "[]")

resolvedPropertyName := resolvePropertyName(serviceDef.Operations[action].Input, normalizedK, "", "", serviceDef.Shapes)
if resolvedPropertyName != "" {
normalizedK = resolvedPropertyName
}

if len(params[normalizedK]) > 0 {
params[normalizedK] = append(params[normalizedK], v...)
} else {
params[normalizedK] = v
}
}

// body part
if len(body) > 0 {
var bodyXML interface{}
err := xml.Unmarshal(body, &bodyXML)
if err != nil {
return
}

flatten(true, params, bodyXML, "")
}
}

region := "us-east-1"
re, _ := regexp.Compile(`\.(.+)\.amazonaws\.com(?:\.cn)?$`)
matches := re.FindStringSubmatch(host)
if len(matches) == 2 {
region = matches[1]
if matches[1] != "s3" { // https://docs.aws.amazon.com/AmazonS3/latest/userguide/VirtualHosting.html#VirtualHostingBackwardsCompatibility
region = matches[1]
}
}

callLog = append(callLog, Entry{
Expand Down

0 comments on commit 908cd98

Please sign in to comment.