Skip to content

Commit

Permalink
feat: add POST support in SAML credential plugin
Browse files Browse the repository at this point in the history
* Add http method config for AWS SAML plugin
* Add http_method to config document of  SAMLCrossAccount
  • Loading branch information
RobertShan2000 authored Aug 24, 2023
1 parent 27a9e2b commit 802dc3c
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 8 deletions.
14 changes: 7 additions & 7 deletions src/awsrun/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def __init__(self, type_):
self.type = type_

def type_check(self, obj):
return type(obj) == self.type
return type(obj) == self.type # noqa: E721

def __str__(self):
return self.type.__name__
Expand All @@ -354,7 +354,7 @@ def __init__(self, pattern):
self.pattern = pattern

def type_check(self, obj):
if type(obj) != str:
if type(obj) != str: # noqa: E721
return False
return bool(re.search(self.pattern, obj))

Expand All @@ -366,7 +366,7 @@ class IpAddress(Type):
"""Represents a string matching an IP address (v4 or v6)."""

def type_check(self, obj):
if type(obj) != str:
if type(obj) != str: # noqa: E721
return False
try:
ipaddress.ip_address(obj)
Expand All @@ -382,7 +382,7 @@ class IpNetwork(Type):
"""Represents a string matching an IP network (v4 or v6)."""

def type_check(self, obj):
if type(obj) != str:
if type(obj) != str: # noqa: E721
return False
try:
ipaddress.ip_network(obj)
Expand All @@ -398,7 +398,7 @@ class FileType(Type):
"""Represents a string pointing to an existing file."""

def type_check(self, obj):
if type(obj) != str:
if type(obj) != str: # noqa: E721
return False
return Path(obj).exists()

Expand Down Expand Up @@ -462,7 +462,7 @@ def __init__(self, element_type):
self.element_type = element_type

def type_check(self, obj):
if type(obj) != list:
if type(obj) != list: # noqa: E721
return False
return all(self.element_type.type_check(e) for e in obj)

Expand All @@ -485,7 +485,7 @@ def __init__(self, key_type, value_type):
self.value_type = value_type

def type_check(self, obj):
if type(obj) != dict:
if type(obj) != dict: # noqa: E721
return False
return all(self.key_type.type_check(k) for k in obj.keys()) and all(
self.value_type.type_check(v) for v in obj.values()
Expand Down
13 changes: 13 additions & 0 deletions src/awsrun/plugins/creds/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class SAML(Plugin):
role: STRING*
url: STRING*
auth_type: ("basic" | "digest" | "ntlm")
http_method: ("GET"| "POST")
http_headers:
STRING: STRING
no_verify: BOOLEAN
Expand Down Expand Up @@ -175,6 +176,11 @@ class SAML(Plugin):
specified, it must be one of `basic`, `digest`, or `ntlm`. The default value
is `basic`. If using NTLM, username should be specified as `domain\\username`.
`http_method`
: The HTTP method to use when authenticating with the IdP. If
specified, it must be one of `GET`, `POST`. The default value
is `GET`.
`http_headers`
: Additional HTTP headers to send in the request to the IdP. If specified,
it must be a dictionary of `key: value` pairs, where keys and values are
Expand Down Expand Up @@ -275,6 +281,7 @@ def instantiate(self, args):
role=args.saml_role,
url=cfg("url", type=URL, must_exist=True),
auth=auth(args.saml_username, args.saml_password),
http_method=cfg("http_method", type=Choice("GET", "POST"), default="GET"),
headers=cfg("http_headers", type=Dict(Str, Str), default={}),
duration=args.saml_duration,
saml_duration=args.saml_assertion_duration,
Expand Down Expand Up @@ -458,6 +465,7 @@ class SAMLCrossAccount(AbstractCrossAccount):
role: STRING*
url: STRING*
auth_type: ("basic" | "digest" | "ntlm")
http_method: ("GET"| "POST")
http_headers:
STRING: STRING
no_verify: BOOLEAN
Expand Down Expand Up @@ -503,6 +511,11 @@ class SAMLCrossAccount(AbstractCrossAccount):
specified, it must be one of `basic`, `digest`, or `ntlm`. The default value
is `basic`. If using NTLM, username should be specified as `domain\\username`.
`http_method`
: The HTTP method to use when authenticating with the IdP. If
specified, it must be one of `GET`, `POST`. The default value
is `GET`.
`http_headers`
: Additional HTTP headers to send in the request to the IdP. If specified,
it must be a dictionary of `key: value` pairs, where keys and values are
Expand Down
12 changes: 11 additions & 1 deletion src/awsrun/session/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,7 @@ def __init__(
role,
url,
auth,
http_method,
headers=None,
duration=3600,
saml_duration=300,
Expand All @@ -392,6 +393,7 @@ def __init__(
super().__init__(role, duration)
self._url = url
self._auth = auth
self._http_method = http_method
self._headers = {} if headers is None else headers
self._cached_saml = ExpiringValue(self._request_assertion, saml_duration)
self._no_verify = no_verify
Expand All @@ -414,7 +416,15 @@ def _request_assertion(self):
with requests.Session() as s:
s.auth = self._auth
s.headers.update(self._headers)
resp = s.get(self._url, verify=not self._no_verify)
if self._http_method == "GET":
resp = s.get(self._url, verify=not self._no_verify)
else:
authData = {
"UserName": s.auth.username,
"Password": s.auth.password,
"AuthMethod": "FormsAuthentication",
}
resp = s.post(self._url, data=authData, verify=not self._no_verify)

if resp.status_code == 401:
raise IDPAccessDeniedException("Could not authenticate")
Expand Down

0 comments on commit 802dc3c

Please sign in to comment.