Skip to content

Commit

Permalink
fix: avoid modifying mutable request endpoints (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
hgiasac authored Jun 10, 2024
1 parent e8eac28 commit 6c9a562
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 23 deletions.
32 changes: 21 additions & 11 deletions rest/connector_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,19 +167,26 @@ func TestRESTConnector_authentication(t *testing.T) {
}
}
},
"arguments": {},
"arguments": {
"status": {
"type": "literal",
"value": "available"
}
},
"collection_relationships": {}
}`)

res, err := http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(reqBody))
assert.NilError(t, err)
assertHTTPResponse(t, res, http.StatusOK, schema.QueryResponse{
{
Rows: []map[string]any{
{"__value": map[string]any{}},
for i := 0; i < 2; i++ {
res, err := http.Post(fmt.Sprintf("%s/query", testServer.URL), "application/json", bytes.NewBuffer(reqBody))
assert.NilError(t, err)
assertHTTPResponse(t, res, http.StatusOK, schema.QueryResponse{
{
Rows: []map[string]any{
{"__value": map[string]any{}},
},
},
},
})
})
}
})

t.Run("retry", func(t *testing.T) {
Expand Down Expand Up @@ -513,8 +520,11 @@ func createMockServer(t *testing.T, apiKey string, bearerToken string) *httptest
switch r.Method {
case http.MethodGet:
if r.Header.Get("Authorization") != fmt.Sprintf("Bearer %s", bearerToken) {
t.Errorf("invalid bearer token, expected %s, got %s", bearerToken, r.Header.Get("Authorization"))
t.FailNow()
t.Fatalf("invalid bearer token, expected %s, got %s", bearerToken, r.Header.Get("Authorization"))
return
}
if r.URL.Query().Encode() != "status=available" {
t.Fatalf("expected query param: status=available, got: %s", r.URL.Query().Encode())
return
}
writeResponse(w)
Expand Down
10 changes: 4 additions & 6 deletions rest/internal/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,13 @@ func getHostFromServers(servers []rest.ServerConfig, serverIDs []string) (string
}

func buildDistributedRequestsWithOptions(request *RetryableRequest, restOptions *RESTOptions) ([]RetryableRequest, error) {
if strings.HasPrefix(request.RawRequest.URL, "http") {
if strings.HasPrefix(request.URL, "http") {
return []RetryableRequest{*request}, nil
}

if !restOptions.Distributed || len(restOptions.Settings.Servers) == 1 {
host, serverID := getHostFromServers(restOptions.Settings.Servers, restOptions.Servers)
request.URL = fmt.Sprintf("%s%s", host, request.RawRequest.URL)
request.URL = fmt.Sprintf("%s%s", host, request.URL)
request.ServerID = serverID
if err := request.applySettings(restOptions.Settings); err != nil {
return nil, err
Expand Down Expand Up @@ -118,7 +118,7 @@ func buildDistributedRequestsWithOptions(request *RetryableRequest, restOptions
}

req := RetryableRequest{
URL: fmt.Sprintf("%s%s", host, request.RawRequest.URL),
URL: fmt.Sprintf("%s%s", host, request.URL),
ServerID: serverID,
RawRequest: request.RawRequest,
ContentType: request.ContentType,
Expand Down Expand Up @@ -217,7 +217,7 @@ func (req *RetryableRequest) applySecurity(serverConfig *rest.ServerConfig) erro
case rest.APIKeyInQuery:
value := securityScheme.Value.Value()
if value != nil {
endpoint, err := url.Parse(req.RawRequest.URL)
endpoint, err := url.Parse(req.URL)
if err != nil {
return err
}
Expand All @@ -226,8 +226,6 @@ func (req *RetryableRequest) applySecurity(serverConfig *rest.ServerConfig) erro
q.Add(securityScheme.Name, *securityScheme.Value.Value())
endpoint.RawQuery = q.Encode()
req.URL = endpoint.String()
} else {
req.URL = req.RawRequest.URL
}
case rest.APIKeyInCookie:
if securityScheme.Value != nil {
Expand Down
3 changes: 1 addition & 2 deletions rest/mutation.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ func (c *RESTConnector) execProcedure(ctx context.Context, operation *schema.Mut

// 2. create and execute request
// 3. evaluate response selection
procedure.Request.URL = endpoint
restOptions.Settings = settings
httpRequest, err := c.createRequest(procedure.Request, headers, rawArgs)
httpRequest, err := c.createRequest(procedure.Request, endpoint, headers, rawArgs)
if err != nil {
return nil, err
}
Expand Down
3 changes: 1 addition & 2 deletions rest/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ func (c *RESTConnector) execQuery(ctx context.Context, request *schema.QueryRequ

// 2. create and execute request
// 3. evaluate response selection
function.Request.URL = endpoint
restOptions.Settings = settings
httpRequest, err := c.createRequest(function.Request, headers, nil)
httpRequest, err := c.createRequest(function.Request, endpoint, headers, nil)
if err != nil {
return nil, err
}
Expand Down
4 changes: 2 additions & 2 deletions rest/request.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import (
"github.com/hasura/ndc-sdk-go/utils"
)

func (c *RESTConnector) createRequest(rawRequest *rest.Request, headers http.Header, arguments map[string]any) (*internal.RetryableRequest, error) {
func (c *RESTConnector) createRequest(rawRequest *rest.Request, endpoint string, headers http.Header, arguments map[string]any) (*internal.RetryableRequest, error) {
var buffer io.ReadSeeker
contentType := contentTypeJSON
bodyData, ok := arguments["body"]
Expand Down Expand Up @@ -66,7 +66,7 @@ func (c *RESTConnector) createRequest(rawRequest *rest.Request, headers http.Hea
}

request := &internal.RetryableRequest{
URL: rawRequest.URL,
URL: endpoint,
RawRequest: rawRequest,
ContentType: contentType,
Headers: headers,
Expand Down

0 comments on commit 6c9a562

Please sign in to comment.