Skip to content

Commit

Permalink
Merge pull request #76 from Colin-b/bugfix/aws_memory_leak
Browse files Browse the repository at this point in the history
Ensure non regression on current behavior of AWS4Auth
  • Loading branch information
Colin-b authored Feb 10, 2024
2 parents 138435b + 02ab487 commit 03ec866
Show file tree
Hide file tree
Showing 4 changed files with 674 additions and 90 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Fixed
- Remove deprecation warnings due to usage of `utcnow` and `utcfromtimestamp`. Thanks to [`Raphael Krupinski`](https://github.com/rafalkrupinski).
- `httpx_auth.AWS4Auth.default_include_headers` value kept growing in size every time a new `httpx_auth.AWS4Auth` instance was created with `security_token` parameter provided. Thanks to [`Miikka Koskinen`](https://github.com/miikka).
- `httpx_auth.AWS4Auth` is now processing included headers without spaces in value faster.

### Changed
- `httpx_auth.AWS4Auth.default_include_headers` is not available anymore, use `httpx_auth.AWS4Auth` `include_headers` parameter instead to change the list of included headers if the default does not fit your need ().
- `httpx_auth.AWS4Auth.default_include_headers` is not available anymore, use `httpx_auth.AWS4Auth` `include_headers` parameter instead to change the list of included headers if the default does not fit your need (refer to documentation for an exhaustive list).
- `httpx_auth.AWS4Auth` `include_headers` values will not be stripped anymore, meaning that you can now include headers prefixed and/or suffixed with blank spaces.

## [0.19.0] - 2024-01-09
Expand Down
142 changes: 59 additions & 83 deletions httpx_auth/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,29 +64,27 @@ def __init__(
def auth_flow(
self, request: httpx.Request
) -> Generator[httpx.Request, httpx.Response, None]:
"""
Add x-amz-date, x-amz-content-sha256 and Authorization headers to the request.
"""
date = datetime.datetime.now(datetime.timezone.utc)
scope = f"{date.strftime('%Y%m%d')}/{self.region}/{self.service}/aws4_request"
signing_key = generate_key(
self.secret_key, self.region, self.service, date.strftime("%Y%m%d")
)

request.headers["x-amz-date"] = date.strftime("%Y%m%dT%H%M%SZ")

# encode body and generate body hash
request.headers["x-amz-content-sha256"] = hashlib.sha256(
request.read()
).hexdigest()
if self.security_token:
request.headers["x-amz-security-token"] = self.security_token

cano_headers, signed_headers = self._get_canonical_headers(request)
cano_req = self._get_canonical_request(request, cano_headers, signed_headers)
sig_string = self._get_sig_string(request, cano_req, scope)
sig_string = sig_string.encode("utf-8")
signature = hmac.new(signing_key, sig_string, hashlib.sha256).hexdigest()
canonical_headers, signed_headers = self._canonical_headers(request)
canonical_request = self._canonical_request(
request, canonical_headers, signed_headers
)
scope = f"{date.strftime('%Y%m%d')}/{self.region}/{self.service}/aws4_request"
string_to_sign = self._string_to_sign(request, canonical_request, scope)
signing_key = _signing_key(
self.secret_key, self.region, self.service, date.strftime("%Y%m%d")
)
signature = hmac.new(
signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
).hexdigest()

auth_str = "AWS4-HMAC-SHA256 "
auth_str += f"Credential={self.access_id}/{scope}, "
Expand All @@ -95,41 +93,24 @@ def auth_flow(
request.headers["Authorization"] = auth_str
yield request

def _get_canonical_request(
self, req: httpx.Request, cano_headers: str, signed_headers: str
def _canonical_request(
self, req: httpx.Request, canonical_headers: str, signed_headers: str
) -> str:
return "\n".join(
[
req.method.upper(),
self._canonical_uri(req.url),
self._canonical_query_string(req.url),
canonical_headers,
signed_headers,
# Hashed payload
req.headers["x-amz-content-sha256"],
]
)

def _canonical_headers(self, req: httpx.Request) -> Tuple[str, str]:
"""
Create the AWS authentication Canonical Request string.
req -- Should already include an x-amz-content-sha256 header
cano_headers -- Canonical Headers section of Canonical Request, as
returned by get_canonical_headers()
signed_headers -- Signed Headers, as returned by
get_canonical_headers()
"""
url_str = str(req.url)
url = urlparse(url_str)
path = self._amz_cano_path(url.path)
# AWS handles "extreme" querystrings differently to urlparse
# (see post-vanilla-query-nonunreserved test in aws_testsuite)
split = url_str.split("?", 1)
qs = split[1] if len(split) == 2 else ""
qs = self._amz_cano_querystring(qs)
payload_hash = req.headers["x-amz-content-sha256"]
req_parts = [
req.method.upper(),
path,
qs,
cano_headers,
signed_headers,
payload_hash,
]
return "\n".join(req_parts)

def _get_canonical_headers(self, req: httpx.Request) -> Tuple[str, str]:
"""
Generate the Canonical Headers section of the Canonical Request.
Return the Canonical Headers and the Signed Headers strs as a tuple
(canonical_headers, signed_headers).
:return: (canonical_headers, signed_headers)
"""
included_headers = {}
for header, header_value in req.headers.items():
Expand All @@ -152,66 +133,59 @@ def _get_canonical_headers(self, req: httpx.Request) -> Tuple[str, str]:
return canonical_headers, signed_headers

@staticmethod
def _get_sig_string(req: httpx.Request, cano_req: str, scope: str) -> str:
"""
Generate the AWS4 auth string to sign for the request.
req -- This should already include an x-amz-date header.
cano_req -- The Canonical Request, as returned by
get_canonical_request()
"""
amz_date = req.headers["x-amz-date"]
hsh = hashlib.sha256(cano_req.encode())
sig_items = ["AWS4-HMAC-SHA256", amz_date, scope, hsh.hexdigest()]
return "\n".join(sig_items)
def _string_to_sign(req: httpx.Request, canonical_request: str, scope: str) -> str:
hsh = hashlib.sha256(canonical_request.encode())
return "\n".join(
["AWS4-HMAC-SHA256", req.headers["x-amz-date"], scope, hsh.hexdigest()]
)

def _amz_cano_path(self, path: str) -> str:
def _canonical_uri(self, url: httpx.URL) -> str:
"""
Generate the canonical path as per AWS4 auth requirements.
Not documented anywhere, determined from aws4_testsuite examples,
problem reports and testing against the live services.
path -- request path
"""
url_str = str(url)
url = urlparse(url_str)
path = url.path
if len(path) == 0:
path = "/"
safe_chars = "/~"
fixed_path = path
fixed_path = posixpath.normpath(fixed_path)
fixed_path = posixpath.normpath(path)
# Prevent multi /
fixed_path = re.sub("/+", "/", fixed_path)
if path.endswith("/") and not fixed_path.endswith("/"):
fixed_path += "/"
full_path = fixed_path
# S3 seems to require unquoting first.
if self.service == "s3":
full_path = unquote(full_path)
return quote(full_path, safe=safe_chars)
return quote(full_path, safe="/~")

@staticmethod
def _amz_cano_querystring(qs: str) -> str:
def _canonical_query_string(url: httpx.URL) -> str:
"""
Parse and format querystring as per AWS4 auth requirements.
Perform percent quoting as needed.
qs -- querystring
"""
safe_qs_amz_chars = "&=+"
safe_qs_unresvd = "-_.~"
url_str = str(url)
# TODO Now that we have test_aws_auth_query_reserved to ensure non regression on this, check if this is still required
split = url_str.split("?", 1)
qs = split[1] if len(split) == 2 else ""
qs = unquote(qs)
space = " "
qs = qs.split(space)[0]
qs = quote(qs, safe=safe_qs_amz_chars)
qs = qs.split(" ")[0]
qs = quote(qs, safe="&=+")

qs_items = {}
for name, vals in parse_qs(qs, keep_blank_values=True).items():
name = quote(name, safe=safe_qs_unresvd)
vals = [quote(val, safe=safe_qs_unresvd) for val in vals]
name = quote(name, safe="-_.~")
vals = [quote(val, safe="-_.~") for val in vals]
qs_items[name] = vals
qs_strings = []
for name, vals in qs_items.items():
for val in vals:
qs_strings.append("=".join([name, val]))
qs = "&".join(sorted(qs_strings))
return qs

qs_strings = sorted(
["=".join([name, val]) for name, vals in qs_items.items() for val in vals]
)
return "&".join(qs_strings)


def generate_key(secret_key: str, region: str, service: str, date: str) -> bytes:
def _signing_key(secret_key: str, region: str, service: str, date: str) -> bytes:
init_key = f"AWS4{secret_key}".encode("utf-8")
date_key = sign_sha256(init_key, date)
region_key = sign_sha256(date_key, region)
Expand All @@ -228,4 +202,6 @@ def _amz_norm_whitespace(text: str) -> str:
Replace runs of whitespace with a single space.
Ignore text enclosed in quotes.
"""
return " ".join(shlex.split(text, posix=False)).strip()
if re.search(r"\s", text):
return " ".join(shlex.split(text, posix=False)).strip()
return text
Loading

0 comments on commit 03ec866

Please sign in to comment.