diff --git a/snuba/admin/auth.py b/snuba/admin/auth.py index 92092888c1..5bb5bb3df6 100644 --- a/snuba/admin/auth.py +++ b/snuba/admin/auth.py @@ -3,6 +3,7 @@ import json from typing import Sequence +import rapidjson import structlog from flask import request @@ -11,9 +12,12 @@ from snuba.admin.google import CloudIdentityAPI from snuba.admin.jwt import validate_assertion from snuba.admin.user import AdminUser +from snuba.redis import RedisClientKey, get_redis_client USER_HEADER_KEY = "X-Goog-Authenticated-User-Email" +redis_client = get_redis_client(RedisClientKey.ADMIN_AUTH) + logger = structlog.get_logger().bind(module=__name__) @@ -41,7 +45,7 @@ def _is_member_of_group(user: AdminUser, group: str) -> bool: return google_api.check_group_membership(group_email=group, member=user.email) -def get_iam_roles_from_file(user: AdminUser) -> Sequence[str]: +def get_iam_roles_from_user(user: AdminUser) -> Sequence[str]: iam_roles = [] try: with open(settings.ADMIN_IAM_POLICY_FILE, "r") as policy_file: @@ -65,10 +69,38 @@ def get_iam_roles_from_file(user: AdminUser) -> Sequence[str]: return iam_roles +def get_cached_iam_roles(user: AdminUser) -> Sequence[str]: + iam_roles_str = redis_client.get(f"roles-{user.email}") + if not iam_roles_str: + return [] + + iam_roles = rapidjson.loads(iam_roles_str) + if isinstance(iam_roles, list): + return iam_roles + + return [] + + def _set_roles(user: AdminUser) -> AdminUser: # todo: depending on provider convert user email # to subset of DEFAULT_ROLES based on IAM roles - iam_roles = get_iam_roles_from_file(user) + iam_roles: Sequence[str] = [] + try: + iam_roles = get_cached_iam_roles(user) + except Exception as e: + logger.exception("Failed to load roles from cache", exception=e) + + if not iam_roles: + iam_roles = get_iam_roles_from_user(user) + try: + redis_client.set( + f"roles-{user.email}", + rapidjson.dumps(iam_roles), + ex=settings.ADMIN_ROLES_REDIS_TTL, + ) + except Exception as e: + logger.exception(e) + user.roles = [*[ROLES[role] for role in iam_roles if role in ROLES], *DEFAULT_ROLES] return user diff --git a/snuba/redis.py b/snuba/redis.py index d8efe15043..575948018e 100644 --- a/snuba/redis.py +++ b/snuba/redis.py @@ -113,6 +113,7 @@ class RedisClientKey(Enum): CONFIG = "config" DLQ = "dlq" OPTIMIZE = "optimize" + ADMIN_AUTH = "admin_auth" _redis_clients: Mapping[RedisClientKey, RedisClientType] = { @@ -137,6 +138,9 @@ class RedisClientKey(Enum): RedisClientKey.OPTIMIZE: _initialize_specialized_redis_cluster( settings.REDIS_CLUSTERS["optimize"] ), + RedisClientKey.ADMIN_AUTH: _initialize_specialized_redis_cluster( + settings.REDIS_CLUSTERS["admin_auth"] + ), } diff --git a/snuba/settings/__init__.py b/snuba/settings/__init__.py index 74a02f5f0f..0ba5c09d61 100644 --- a/snuba/settings/__init__.py +++ b/snuba/settings/__init__.py @@ -55,6 +55,8 @@ os.environ.get("ADMIN_REPLAYS_SAMPLE_RATE_ON_ERROR", 1.0) ) +ADMIN_ROLES_REDIS_TTL = 600 + ###################### # End Admin Settings # ###################### @@ -154,6 +156,7 @@ class RedisClusters(TypedDict): config: RedisClusterConfig | None dlq: RedisClusterConfig | None optimize: RedisClusterConfig | None + admin_auth: RedisClusterConfig | None REDIS_CLUSTERS: RedisClusters = { @@ -164,6 +167,7 @@ class RedisClusters(TypedDict): "config": None, "dlq": None, "optimize": None, + "admin_auth": None, } # Query Recording Options diff --git a/snuba/settings/settings_test.py b/snuba/settings/settings_test.py index d4f0fd5c36..e1780651c4 100644 --- a/snuba/settings/settings_test.py +++ b/snuba/settings/settings_test.py @@ -55,5 +55,6 @@ (6, "config"), (7, "dlq"), (8, "optimize"), + (9, "admin_auth"), ] } diff --git a/tests/admin/clickhouse_migrations/test_api.py b/tests/admin/clickhouse_migrations/test_api.py index 1651b59f26..01450cccd3 100644 --- a/tests/admin/clickhouse_migrations/test_api.py +++ b/tests/admin/clickhouse_migrations/test_api.py @@ -22,6 +22,7 @@ from snuba.migrations.policies import MigrationPolicy from snuba.migrations.runner import MigrationKey, Runner from snuba.migrations.status import Status +from snuba.redis import RedisClientKey, get_redis_client @pytest.fixture @@ -31,6 +32,7 @@ def admin_api() -> FlaskClient: return application.test_client() +@pytest.mark.redis_db @pytest.mark.clickhouse_db def test_migration_groups(admin_api: FlaskClient) -> None: runner = Runner() @@ -76,6 +78,7 @@ def get_migration_ids( ] +@pytest.mark.redis_db @pytest.mark.clickhouse_db def test_list_migration_status(admin_api: FlaskClient) -> None: with patch( @@ -137,6 +140,7 @@ def sort_by_migration_id(migration: Any) -> Any: assert sorted_response == sorted_expected_json +@pytest.mark.redis_db @pytest.mark.clickhouse_db @pytest.mark.parametrize("action", ["run", "reverse"]) def test_run_reverse_migrations(admin_api: FlaskClient, action: str) -> None: @@ -281,6 +285,7 @@ def print_something(*args: Any, **kwargs: Any) -> None: assert mock_run_migration.call_count == 1 +@pytest.mark.redis_db def test_get_iam_roles(caplog: Any) -> None: system_role = generate_migration_test_role("system", "all") tool_role = generate_tool_test_role("snql-to-sql") @@ -359,6 +364,8 @@ def test_get_iam_roles(caplog: Any) -> None: tool_role, ] + iam_file.close() + with patch( "snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", "file_not_exists.json" ): @@ -369,3 +376,122 @@ def test_get_iam_roles(caplog: Any) -> None: assert "IAM policy file not found file_not_exists.json" in str( log.calls ) + + +@pytest.mark.redis_db +def test_get_iam_roles_cache() -> None: + system_role = generate_migration_test_role("system", "all") + tool_role = generate_tool_test_role("snql-to-sql") + with patch( + "snuba.admin.auth.DEFAULT_ROLES", + [system_role, tool_role], + ): + iam_file = tempfile.NamedTemporaryFile() + iam_file.write( + json.dumps( + { + "bindings": [ + { + "members": [ + "group:team-sns@sentry.io", + "user:test_user1@sentry.io", + ], + "role": "roles/NonBlockingMigrationsExecutor", + }, + { + "members": [ + "group:team-sns@sentry.io", + "user:test_user1@sentry.io", + "user:test_user2@sentry.io", + ], + "role": "roles/TestMigrationsExecutor", + }, + { + "members": [ + "group:team-sns@sentry.io", + "user:test_user1@sentry.io", + "user:test_user2@sentry.io", + ], + "role": "roles/owner", + }, + { + "members": [ + "group:team-sns@sentry.io", + "user:test_user1@sentry.io", + ], + "role": "roles/AllTools", + }, + ] + } + ).encode("utf-8") + ) + + iam_file.flush() + with patch("snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", iam_file.name): + + user1 = AdminUser(email="test_user1@sentry.io", id="unknown") + _set_roles(user1) + + assert user1.roles == [ + ROLES["NonBlockingMigrationsExecutor"], + ROLES["TestMigrationsExecutor"], + ROLES["AllTools"], + system_role, + tool_role, + ] + + iam_file = tempfile.NamedTemporaryFile() + iam_file.write(json.dumps({"bindings": []}).encode("utf-8")) + iam_file.flush() + + with patch("snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", iam_file.name): + _set_roles(user1) + + assert user1.roles == [ + ROLES["NonBlockingMigrationsExecutor"], + ROLES["TestMigrationsExecutor"], + ROLES["AllTools"], + system_role, + tool_role, + ] + + redis_client = get_redis_client(RedisClientKey.ADMIN_AUTH) + redis_client.delete(f"roles-{user1.email}") + _set_roles(user1) + + assert user1.roles == [ + system_role, + tool_role, + ] + + +@pytest.mark.redis_db +@patch("redis.Redis") +def test_get_iam_roles_cache_fail(mock_redis: Any) -> None: + mock_redis.get.side_effect = Exception("Test exception") + mock_redis.set.side_effect = Exception("Test exception") + system_role = generate_migration_test_role("system", "all") + tool_role = generate_tool_test_role("snql-to-sql") + with patch( + "snuba.admin.auth.DEFAULT_ROLES", + [system_role, tool_role], + ): + iam_file = tempfile.NamedTemporaryFile() + iam_file.write(json.dumps({"bindings": []}).encode("utf-8")) + iam_file.flush() + + with patch("snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", iam_file.name): + user1 = AdminUser(email="test_user1@sentry.io", id="unknown") + _set_roles(user1) + + assert user1.roles == [ + system_role, + tool_role, + ] + + _set_roles(user1) + + assert user1.roles == [ + system_role, + tool_role, + ] diff --git a/tests/admin/test_api.py b/tests/admin/test_api.py index 5f1c4d4bb3..3dca7e6d50 100644 --- a/tests/admin/test_api.py +++ b/tests/admin/test_api.py @@ -166,6 +166,7 @@ def test_config_descriptions(admin_api: FlaskClient) -> None: } +@pytest.mark.redis_db def get_node_for_table( admin_api: FlaskClient, storage_name: str ) -> tuple[str, str, int]: @@ -204,6 +205,7 @@ def test_system_query(admin_api: FlaskClient) -> None: assert data["rows"] == [] +@pytest.mark.redis_db def test_predefined_system_queries(admin_api: FlaskClient) -> None: response = admin_api.get( "/clickhouse_queries", @@ -249,6 +251,7 @@ def test_query_trace_bad_query(admin_api: FlaskClient) -> None: assert "clickhouse" == data["error"]["type"] +@pytest.mark.redis_db @pytest.mark.clickhouse_db def test_query_trace_invalid_query(admin_api: FlaskClient) -> None: table, _, _ = get_node_for_table(admin_api, "errors_ro") @@ -279,6 +282,7 @@ def test_querylog_query(admin_api: FlaskClient) -> None: assert "column_names" in data and data["column_names"] == ["count()"] +@pytest.mark.redis_db @pytest.mark.clickhouse_db def test_querylog_invalid_query(admin_api: FlaskClient) -> None: table, _, _ = get_node_for_table(admin_api, "errors_ro") @@ -301,6 +305,7 @@ def test_querylog_describe(admin_api: FlaskClient) -> None: assert "column_names" in data and "rows" in data +@pytest.mark.redis_db def test_predefined_querylog_queries(admin_api: FlaskClient) -> None: response = admin_api.get( "/querylog_queries", @@ -313,6 +318,7 @@ def test_predefined_querylog_queries(admin_api: FlaskClient) -> None: assert data[0]["name"] == "QueryByID" +@pytest.mark.redis_db def test_get_snuba_datasets(admin_api: FlaskClient) -> None: response = admin_api.get("/snuba_datasets") assert response.status_code == 200 @@ -320,6 +326,7 @@ def test_get_snuba_datasets(admin_api: FlaskClient) -> None: assert set(data) == set(get_enabled_dataset_names()) +@pytest.mark.redis_db def test_convert_SnQL_to_SQL_invalid_dataset(admin_api: FlaskClient) -> None: response = admin_api.post( "/snql_to_sql", data=json.dumps({"dataset": "", "query": ""}) @@ -329,6 +336,7 @@ def test_convert_SnQL_to_SQL_invalid_dataset(admin_api: FlaskClient) -> None: assert data["error"]["message"] == "dataset '' does not exist" +@pytest.mark.redis_db @pytest.mark.redis_db def test_convert_SnQL_to_SQL_invalid_query(admin_api: FlaskClient) -> None: response = admin_api.post( diff --git a/tests/admin/test_authorization.py b/tests/admin/test_authorization.py index 923902a61e..d07614435f 100644 --- a/tests/admin/test_authorization.py +++ b/tests/admin/test_authorization.py @@ -16,6 +16,7 @@ def admin_api() -> FlaskClient: return application.test_client() +@pytest.mark.redis_db def test_tools(admin_api: FlaskClient) -> None: response = admin_api.get("/tools") assert response.status_code == 200 @@ -25,6 +26,7 @@ def test_tools(admin_api: FlaskClient) -> None: assert "all" in data["tools"] +@pytest.mark.redis_db @patch("snuba.admin.auth.DEFAULT_ROLES", [ROLES["ProductTools"]]) def test_product_tools_role( admin_api: FlaskClient,