Skip to content

Commit

Permalink
Improve typing for IAM (#7091)
Browse files Browse the repository at this point in the history
  • Loading branch information
tungol authored Dec 5, 2023
1 parent 16b9f31 commit ff5256d
Show file tree
Hide file tree
Showing 12 changed files with 133 additions and 108 deletions.
2 changes: 1 addition & 1 deletion moto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ def __exit__(self, *exc: Any) -> None:

try:
# Need to monkey-patch botocore requests back to underlying urllib3 classes
from botocore.awsrequest import (
from botocore.awsrequest import ( # type: ignore[attr-defined]
HTTPConnection,
HTTPConnectionPool,
HTTPSConnectionPool,
Expand Down
2 changes: 1 addition & 1 deletion moto/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
CONFIG_RULE_PAGE_SIZE = 25

# Map the Config resource type to a backend:
RESOURCE_MAP: Dict[str, ConfigQueryModel] = {
RESOURCE_MAP: Dict[str, ConfigQueryModel[Any]] = {
"AWS::S3::Bucket": s3_config_query,
"AWS::S3::AccountPublicAccessBlock": s3_account_public_access_block_query,
"AWS::IAM::Role": role_config_query,
Expand Down
8 changes: 5 additions & 3 deletions moto/core/botocore_stubber.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from io import BytesIO
from typing import Any, Callable, Dict, List, Pattern, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Pattern, Tuple, Union

from botocore.awsrequest import AWSResponse

Expand Down Expand Up @@ -38,7 +38,9 @@ def register_response(
matchers = self.methods[method]
matchers.append((pattern, response))

def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse:
def __call__(
self, event_name: str, request: Any, **kwargs: Any
) -> Optional[AWSResponse]:
if not self.enabled:
return None

Expand Down Expand Up @@ -70,6 +72,6 @@ def __call__(self, event_name: str, request: Any, **kwargs: Any) -> AWSResponse:
headers = e.get_headers() # type: ignore[assignment]
body = e.get_body()
raw_response = MockRawResponse(body)
response = AWSResponse(request.url, status, headers, raw_response)
response = AWSResponse(request.url, status, headers, raw_response) # type: ignore[arg-type]

return response
8 changes: 4 additions & 4 deletions moto/core/common_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import abstractmethod
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Generic, List, Optional, Tuple

from .base_backend import InstanceTrackerMeta
from .base_backend import SERVICE_BACKEND, BackendDict, InstanceTrackerMeta


class BaseModel(metaclass=InstanceTrackerMeta):
Expand Down Expand Up @@ -94,8 +94,8 @@ def is_created(self) -> bool:
return True


class ConfigQueryModel:
def __init__(self, backends: Any):
class ConfigQueryModel(Generic[SERVICE_BACKEND]):
def __init__(self, backends: BackendDict[SERVICE_BACKEND]):
"""Inits based on the resource type's backends (1 for each region if applicable)"""
self.backends = backends

Expand Down
2 changes: 1 addition & 1 deletion moto/core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def patch_client(client: botocore.client.BaseClient) -> None:
if isinstance(client, botocore.client.BaseClient):
# Check if our event handler was already registered
try:
event_emitter = client._ruleset_resolver._event_emitter._emitter
event_emitter = client._ruleset_resolver._event_emitter._emitter # type: ignore[attr-defined]
all_handlers = event_emitter._handlers._root["children"]
handler_trie = list(all_handlers["before-send"].values())[1]
handlers_list = handler_trie.first + handler_trie.middle + handler_trie.last
Expand Down
2 changes: 1 addition & 1 deletion moto/iam/access_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def _raise_invalid_access_key(self, reason: str) -> None:
raise NotImplementedError()

@abstractmethod
def _create_auth(self, credentials: Credentials) -> SigV4Auth: # type: ignore[misc]
def _create_auth(self, credentials: Credentials) -> SigV4Auth:
raise NotImplementedError()

@staticmethod
Expand Down
120 changes: 65 additions & 55 deletions moto/iam/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@

from moto.core.common_models import ConfigQueryModel
from moto.core.exceptions import InvalidNextTokenException
from moto.iam import iam_backends
from moto.iam.models import IAMBackend, iam_backends


class RoleConfigQuery(ConfigQueryModel):
class RoleConfigQuery(ConfigQueryModel[IAMBackend]):
def list_config_service_resources(
self,
account_id: str,
Expand All @@ -32,26 +32,27 @@ def list_config_service_resources(
return [], None

# Filter by resource name or ids
if resource_name or resource_ids:
filtered_roles = []
# resource_name takes precedence over resource_ids
if resource_name:
for role in role_list:
if role.name == resource_name:
filtered_roles = [role]
break
# but if both are passed, it must be a subset
if filtered_roles and resource_ids:
if filtered_roles[0].id not in resource_ids:
return [], None
else:
for role in role_list:
if role.id in resource_ids: # type: ignore[operator]
filtered_roles.append(role)
# resource_name takes precedence over resource_ids
filtered_roles = []
if resource_name:
for role in role_list:
if role.name == resource_name:
filtered_roles = [role]
break
# but if both are passed, it must be a subset
if filtered_roles and resource_ids:
if filtered_roles[0].id not in resource_ids:
return [], None

# Filtered roles are now the subject for the listing
role_list = filtered_roles

elif resource_ids:
for role in role_list:
if role.id in resource_ids:
filtered_roles.append(role)
role_list = filtered_roles

if aggregator:
# IAM is a little special; Roles are created in us-east-1 (which AWS calls the "global" region)
# However, the resource will return in the aggregator (in duplicate) for each region in the aggregator
Expand All @@ -77,7 +78,6 @@ def list_config_service_resources(
duplicate_role_list.append(
{
"_id": f"{role.id}{region}", # this is only for sorting, isn't returned outside of this function
"type": "AWS::IAM::Role",
"id": role.id,
"name": role.name,
"region": region,
Expand All @@ -89,7 +89,10 @@ def list_config_service_resources(
else:
# Non-aggregated queries are in the else block, and we can treat these like a normal config resource
# Pagination logic, sort by role id
sorted_roles = sorted(role_list, key=lambda role: role.id) # type: ignore[attr-defined]
sorted_roles = [
{"_id": role.id, "id": role.id, "name": role.name, "region": "global"}
for role in sorted(role_list, key=lambda role: role.id)
]

new_token = None

Expand All @@ -102,27 +105,27 @@ def list_config_service_resources(
start = next(
index
for (index, r) in enumerate(sorted_roles)
if next_token == (r["_id"] if aggregator else r.id) # type: ignore[attr-defined]
if next_token == r["_id"]
)
except StopIteration:
raise InvalidNextTokenException()

# Get the list of items to collect:
role_list = sorted_roles[start : (start + limit)]
collected_role_list = sorted_roles[start : (start + limit)]

if len(sorted_roles) > (start + limit):
record = sorted_roles[start + limit]
new_token = record["_id"] if aggregator else record.id # type: ignore[attr-defined]
new_token = record["_id"]

return (
[
{
"type": "AWS::IAM::Role",
"id": role["id"] if aggregator else role.id, # type: ignore[attr-defined]
"name": role["name"] if aggregator else role.name, # type: ignore[attr-defined]
"region": role["region"] if aggregator else "global",
"id": role["id"],
"name": role["name"],
"region": role["region"],
}
for role in role_list
for role in collected_role_list
],
new_token,
)
Expand All @@ -136,7 +139,7 @@ def get_config_resource(
resource_region: Optional[str] = None,
) -> Optional[Dict[str, Any]]:

role = self.backends[account_id]["global"].roles.get(resource_id, {})
role = self.backends[account_id]["global"].roles.get(resource_id)

if not role:
return None
Expand All @@ -158,7 +161,7 @@ def get_config_resource(
return config_data


class PolicyConfigQuery(ConfigQueryModel):
class PolicyConfigQuery(ConfigQueryModel[IAMBackend]):
def list_config_service_resources(
self,
account_id: str,
Expand Down Expand Up @@ -194,27 +197,27 @@ def list_config_service_resources(
return [], None

# Filter by resource name or ids
if resource_name or resource_ids:
filtered_policies = []
# resource_name takes precedence over resource_ids
if resource_name:
for policy in policy_list:
if policy.name == resource_name:
filtered_policies = [policy]
break
# but if both are passed, it must be a subset
if filtered_policies and resource_ids:
if filtered_policies[0].id not in resource_ids:
return [], None

else:
for policy in policy_list:
if policy.id in resource_ids: # type: ignore[operator]
filtered_policies.append(policy)
# resource_name takes precedence over resource_ids
filtered_policies = []
if resource_name:
for policy in policy_list:
if policy.name == resource_name:
filtered_policies = [policy]
break
# but if both are passed, it must be a subset
if filtered_policies and resource_ids:
if filtered_policies[0].id not in resource_ids:
return [], None

# Filtered roles are now the subject for the listing
policy_list = filtered_policies

elif resource_ids:
for policy in policy_list:
if policy.id in resource_ids:
filtered_policies.append(policy)
policy_list = filtered_policies

if aggregator:
# IAM is a little special; Policies are created in us-east-1 (which AWS calls the "global" region)
# However, the resource will return in the aggregator (in duplicate) for each region in the aggregator
Expand All @@ -240,7 +243,6 @@ def list_config_service_resources(
duplicate_policy_list.append(
{
"_id": f"{policy.id}{region}", # this is only for sorting, isn't returned outside of this function
"type": "AWS::IAM::Policy",
"id": policy.id,
"name": policy.name,
"region": region,
Expand All @@ -255,7 +257,15 @@ def list_config_service_resources(
else:
# Non-aggregated queries are in the else block, and we can treat these like a normal config resource
# Pagination logic, sort by role id
sorted_policies = sorted(policy_list, key=lambda role: role.id) # type: ignore[attr-defined]
sorted_policies = [
{
"_id": policy.id,
"id": policy.id,
"name": policy.name,
"region": "global",
}
for policy in sorted(policy_list, key=lambda role: role.id)
]

new_token = None

Expand All @@ -268,27 +278,27 @@ def list_config_service_resources(
start = next(
index
for (index, p) in enumerate(sorted_policies)
if next_token == (p["_id"] if aggregator else p.id) # type: ignore[attr-defined]
if next_token == p["_id"]
)
except StopIteration:
raise InvalidNextTokenException()

# Get the list of items to collect:
policy_list = sorted_policies[start : (start + limit)]
collected_policy_list = sorted_policies[start : (start + limit)]

if len(sorted_policies) > (start + limit):
record = sorted_policies[start + limit]
new_token = record["_id"] if aggregator else record.id # type: ignore[attr-defined]
new_token = record["_id"]

return (
[
{
"type": "AWS::IAM::Policy",
"id": policy["id"] if aggregator else policy.id, # type: ignore[attr-defined]
"name": policy["name"] if aggregator else policy.name, # type: ignore[attr-defined]
"region": policy["region"] if aggregator else "global",
"id": policy["id"],
"name": policy["name"],
"region": policy["region"],
}
for policy in policy_list
for policy in collected_policy_list
],
new_token,
)
Expand Down
Loading

0 comments on commit ff5256d

Please sign in to comment.