Skip to content

Commit

Permalink
feat(admin-auth): Add TTL cache for admin roles (#4421)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidtsuk authored Jun 26, 2023
1 parent 0a80f15 commit 198a5cd
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 2 deletions.
36 changes: 34 additions & 2 deletions snuba/admin/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
from typing import Sequence

import rapidjson
import structlog
from flask import request

Expand All @@ -11,9 +12,12 @@
from snuba.admin.google import CloudIdentityAPI
from snuba.admin.jwt import validate_assertion
from snuba.admin.user import AdminUser
from snuba.redis import RedisClientKey, get_redis_client

USER_HEADER_KEY = "X-Goog-Authenticated-User-Email"

redis_client = get_redis_client(RedisClientKey.ADMIN_AUTH)

logger = structlog.get_logger().bind(module=__name__)


Expand Down Expand Up @@ -41,7 +45,7 @@ def _is_member_of_group(user: AdminUser, group: str) -> bool:
return google_api.check_group_membership(group_email=group, member=user.email)


def get_iam_roles_from_file(user: AdminUser) -> Sequence[str]:
def get_iam_roles_from_user(user: AdminUser) -> Sequence[str]:
iam_roles = []
try:
with open(settings.ADMIN_IAM_POLICY_FILE, "r") as policy_file:
Expand All @@ -65,10 +69,38 @@ def get_iam_roles_from_file(user: AdminUser) -> Sequence[str]:
return iam_roles


def get_cached_iam_roles(user: AdminUser) -> Sequence[str]:
iam_roles_str = redis_client.get(f"roles-{user.email}")
if not iam_roles_str:
return []

iam_roles = rapidjson.loads(iam_roles_str)
if isinstance(iam_roles, list):
return iam_roles

return []


def _set_roles(user: AdminUser) -> AdminUser:
# todo: depending on provider convert user email
# to subset of DEFAULT_ROLES based on IAM roles
iam_roles = get_iam_roles_from_file(user)
iam_roles: Sequence[str] = []
try:
iam_roles = get_cached_iam_roles(user)
except Exception as e:
logger.exception("Failed to load roles from cache", exception=e)

if not iam_roles:
iam_roles = get_iam_roles_from_user(user)
try:
redis_client.set(
f"roles-{user.email}",
rapidjson.dumps(iam_roles),
ex=settings.ADMIN_ROLES_REDIS_TTL,
)
except Exception as e:
logger.exception(e)

user.roles = [*[ROLES[role] for role in iam_roles if role in ROLES], *DEFAULT_ROLES]
return user

Expand Down
4 changes: 4 additions & 0 deletions snuba/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class RedisClientKey(Enum):
CONFIG = "config"
DLQ = "dlq"
OPTIMIZE = "optimize"
ADMIN_AUTH = "admin_auth"


_redis_clients: Mapping[RedisClientKey, RedisClientType] = {
Expand All @@ -137,6 +138,9 @@ class RedisClientKey(Enum):
RedisClientKey.OPTIMIZE: _initialize_specialized_redis_cluster(
settings.REDIS_CLUSTERS["optimize"]
),
RedisClientKey.ADMIN_AUTH: _initialize_specialized_redis_cluster(
settings.REDIS_CLUSTERS["admin_auth"]
),
}


Expand Down
4 changes: 4 additions & 0 deletions snuba/settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@
os.environ.get("ADMIN_REPLAYS_SAMPLE_RATE_ON_ERROR", 1.0)
)

ADMIN_ROLES_REDIS_TTL = 600

######################
# End Admin Settings #
######################
Expand Down Expand Up @@ -154,6 +156,7 @@ class RedisClusters(TypedDict):
config: RedisClusterConfig | None
dlq: RedisClusterConfig | None
optimize: RedisClusterConfig | None
admin_auth: RedisClusterConfig | None


REDIS_CLUSTERS: RedisClusters = {
Expand All @@ -164,6 +167,7 @@ class RedisClusters(TypedDict):
"config": None,
"dlq": None,
"optimize": None,
"admin_auth": None,
}

# Query Recording Options
Expand Down
1 change: 1 addition & 0 deletions snuba/settings/settings_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,6 @@
(6, "config"),
(7, "dlq"),
(8, "optimize"),
(9, "admin_auth"),
]
}
126 changes: 126 additions & 0 deletions tests/admin/clickhouse_migrations/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from snuba.migrations.policies import MigrationPolicy
from snuba.migrations.runner import MigrationKey, Runner
from snuba.migrations.status import Status
from snuba.redis import RedisClientKey, get_redis_client


def generate_migration_test_role(
Expand Down Expand Up @@ -60,6 +61,7 @@ def admin_api() -> FlaskClient:
return application.test_client()


@pytest.mark.redis_db
@pytest.mark.clickhouse_db
def test_migration_groups(admin_api: FlaskClient) -> None:
runner = Runner()
Expand Down Expand Up @@ -105,6 +107,7 @@ def get_migration_ids(
]


@pytest.mark.redis_db
@pytest.mark.clickhouse_db
def test_list_migration_status(admin_api: FlaskClient) -> None:
with patch(
Expand Down Expand Up @@ -166,6 +169,7 @@ def sort_by_migration_id(migration: Any) -> Any:
assert sorted_response == sorted_expected_json


@pytest.mark.redis_db
@pytest.mark.clickhouse_db
@pytest.mark.parametrize("action", ["run", "reverse"])
def test_run_reverse_migrations(admin_api: FlaskClient, action: str) -> None:
Expand Down Expand Up @@ -310,6 +314,7 @@ def print_something(*args: Any, **kwargs: Any) -> None:
assert mock_run_migration.call_count == 1


@pytest.mark.redis_db
def test_get_iam_roles(caplog: Any) -> None:
system_role = generate_migration_test_role("system", "all")
tool_role = generate_tool_test_role("snql-to-sql")
Expand Down Expand Up @@ -388,6 +393,8 @@ def test_get_iam_roles(caplog: Any) -> None:
tool_role,
]

iam_file.close()

with patch(
"snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", "file_not_exists.json"
):
Expand All @@ -398,3 +405,122 @@ def test_get_iam_roles(caplog: Any) -> None:
assert "IAM policy file not found file_not_exists.json" in str(
log.calls
)


@pytest.mark.redis_db
def test_get_iam_roles_cache() -> None:
system_role = generate_migration_test_role("system", "all")
tool_role = generate_tool_test_role("snql-to-sql")
with patch(
"snuba.admin.auth.DEFAULT_ROLES",
[system_role, tool_role],
):
iam_file = tempfile.NamedTemporaryFile()
iam_file.write(
json.dumps(
{
"bindings": [
{
"members": [
"group:team-sns@sentry.io",
"user:test_user1@sentry.io",
],
"role": "roles/NonBlockingMigrationsExecutor",
},
{
"members": [
"group:team-sns@sentry.io",
"user:test_user1@sentry.io",
"user:test_user2@sentry.io",
],
"role": "roles/TestMigrationsExecutor",
},
{
"members": [
"group:team-sns@sentry.io",
"user:test_user1@sentry.io",
"user:test_user2@sentry.io",
],
"role": "roles/owner",
},
{
"members": [
"group:team-sns@sentry.io",
"user:test_user1@sentry.io",
],
"role": "roles/AllTools",
},
]
}
).encode("utf-8")
)

iam_file.flush()
with patch("snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", iam_file.name):

user1 = AdminUser(email="test_user1@sentry.io", id="unknown")
_set_roles(user1)

assert user1.roles == [
ROLES["NonBlockingMigrationsExecutor"],
ROLES["TestMigrationsExecutor"],
ROLES["AllTools"],
system_role,
tool_role,
]

iam_file = tempfile.NamedTemporaryFile()
iam_file.write(json.dumps({"bindings": []}).encode("utf-8"))
iam_file.flush()

with patch("snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", iam_file.name):
_set_roles(user1)

assert user1.roles == [
ROLES["NonBlockingMigrationsExecutor"],
ROLES["TestMigrationsExecutor"],
ROLES["AllTools"],
system_role,
tool_role,
]

redis_client = get_redis_client(RedisClientKey.ADMIN_AUTH)
redis_client.delete(f"roles-{user1.email}")
_set_roles(user1)

assert user1.roles == [
system_role,
tool_role,
]


@pytest.mark.redis_db
@patch("redis.Redis")
def test_get_iam_roles_cache_fail(mock_redis: Any) -> None:
mock_redis.get.side_effect = Exception("Test exception")
mock_redis.set.side_effect = Exception("Test exception")
system_role = generate_migration_test_role("system", "all")
tool_role = generate_tool_test_role("snql-to-sql")
with patch(
"snuba.admin.auth.DEFAULT_ROLES",
[system_role, tool_role],
):
iam_file = tempfile.NamedTemporaryFile()
iam_file.write(json.dumps({"bindings": []}).encode("utf-8"))
iam_file.flush()

with patch("snuba.admin.auth.settings.ADMIN_IAM_POLICY_FILE", iam_file.name):
user1 = AdminUser(email="test_user1@sentry.io", id="unknown")
_set_roles(user1)

assert user1.roles == [
system_role,
tool_role,
]

_set_roles(user1)

assert user1.roles == [
system_role,
tool_role,
]
8 changes: 8 additions & 0 deletions tests/admin/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def test_config_descriptions(admin_api: FlaskClient) -> None:
}


@pytest.mark.redis_db
def get_node_for_table(
admin_api: FlaskClient, storage_name: str
) -> tuple[str, str, int]:
Expand Down Expand Up @@ -204,6 +205,7 @@ def test_system_query(admin_api: FlaskClient) -> None:
assert data["rows"] == []


@pytest.mark.redis_db
def test_predefined_system_queries(admin_api: FlaskClient) -> None:
response = admin_api.get(
"/clickhouse_queries",
Expand Down Expand Up @@ -249,6 +251,7 @@ def test_query_trace_bad_query(admin_api: FlaskClient) -> None:
assert "clickhouse" == data["error"]["type"]


@pytest.mark.redis_db
@pytest.mark.clickhouse_db
def test_query_trace_invalid_query(admin_api: FlaskClient) -> None:
table, _, _ = get_node_for_table(admin_api, "errors_ro")
Expand Down Expand Up @@ -279,6 +282,7 @@ def test_querylog_query(admin_api: FlaskClient) -> None:
assert "column_names" in data and data["column_names"] == ["count()"]


@pytest.mark.redis_db
@pytest.mark.clickhouse_db
def test_querylog_invalid_query(admin_api: FlaskClient) -> None:
table, _, _ = get_node_for_table(admin_api, "errors_ro")
Expand All @@ -301,6 +305,7 @@ def test_querylog_describe(admin_api: FlaskClient) -> None:
assert "column_names" in data and "rows" in data


@pytest.mark.redis_db
def test_predefined_querylog_queries(admin_api: FlaskClient) -> None:
response = admin_api.get(
"/querylog_queries",
Expand All @@ -313,13 +318,15 @@ def test_predefined_querylog_queries(admin_api: FlaskClient) -> None:
assert data[0]["name"] == "QueryByID"


@pytest.mark.redis_db
def test_get_snuba_datasets(admin_api: FlaskClient) -> None:
response = admin_api.get("/snuba_datasets")
assert response.status_code == 200
data = json.loads(response.data)
assert set(data) == set(get_enabled_dataset_names())


@pytest.mark.redis_db
def test_convert_SnQL_to_SQL_invalid_dataset(admin_api: FlaskClient) -> None:
response = admin_api.post(
"/snql_to_sql", data=json.dumps({"dataset": "", "query": ""})
Expand All @@ -329,6 +336,7 @@ def test_convert_SnQL_to_SQL_invalid_dataset(admin_api: FlaskClient) -> None:
assert data["error"]["message"] == "dataset '' does not exist"


@pytest.mark.redis_db
@pytest.mark.redis_db
def test_convert_SnQL_to_SQL_invalid_query(admin_api: FlaskClient) -> None:
response = admin_api.post(
Expand Down
2 changes: 2 additions & 0 deletions tests/admin/test_authorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def admin_api() -> FlaskClient:
return application.test_client()


@pytest.mark.redis_db
def test_tools(admin_api: FlaskClient) -> None:
response = admin_api.get("/tools")
assert response.status_code == 200
Expand All @@ -25,6 +26,7 @@ def test_tools(admin_api: FlaskClient) -> None:
assert "all" in data["tools"]


@pytest.mark.redis_db
@patch("snuba.admin.auth.DEFAULT_ROLES", [ROLES["ProductTools"]])
def test_product_tools_role(
admin_api: FlaskClient,
Expand Down

0 comments on commit 198a5cd

Please sign in to comment.