Skip to content

Commit

Permalink
feat: csp and bake in csp and cors as asgi send middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
alexogeny committed Aug 19, 2024
1 parent baf7832 commit 1e08d87
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 63 deletions.
21 changes: 12 additions & 9 deletions src/zara/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@


class Config:
_instance = None
_config = configparser.ConfigParser()

def __new__(cls, config_file: str = "config.ini") -> "Config":
if cls._instance is None:
cls._instance = super().__new__(cls)
with open(config_file, "r") as f:
cls._config.read_file(f)
return cls._instance
def __init__(
self, config_file: str = None, config: dict[Any, Any] = None
) -> "Config":
self._config = configparser.ConfigParser()
if config:
self._config.read_dict(config)
else:
try:
with open(config_file, "r") as f:
self._config.read_file(f)
except FileNotFoundError:
self._config.read_dict({})

def __getattr__(self, section: str) -> Any:
if section in self._config:
Expand Down
35 changes: 7 additions & 28 deletions src/zara/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@


class SimpleASGIApp:
def __init__(self) -> None:
def __init__(self, config=None) -> None:
self.routers: list[Router] = []
self.before_request_handlers: List[AsgiHandlerType] = []
self.after_request_handlers: List[AsgiHandlerType] = []
self.startup_handlers: List[GenericHandlerType] = []
self.shutdown_handlers: List[GenericHandlerType] = []
self._config = None
self._raw_config = config
self.rate_limit = (100, 60) # 100 requests every 60 seconds

def add_router(self, router: Router) -> None:
Expand All @@ -27,32 +28,14 @@ def add_router(self, router: Router) -> None:
@property
def config(self):
if self._config is None:
self._config = Config("config.ini")
if self._raw_config is not None:
self._config = Config(config=self._raw_config)
else:
self._config = Config("config.ini")
return self._config

def add_cors_headers(self, response: Dict[str, Any], origin: str) -> None:
cors_config = self.config.cors
if origin in cors_config.allowed_origins:
response["headers"].extend(
[
(b"access-control-allow-origin", origin.encode("utf-8")),
(
b"access-control-allow-methods",
cors_config.allowed_methods.encode("utf-8"),
),
(
b"access-control-allow-headers",
cors_config.allowed_headers.encode("utf-8"),
),
]
)
if cors_config.allow_credentials.lower() == "true":
response["headers"].append(
(b"access-control-allow-credentials", b"true")
)

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
asgi = ASGI(scope, receive, send)
asgi = ASGI(scope, receive, send, self.config)
assert asgi.scope["type"] == "http"

path = asgi.scope["path"]
Expand All @@ -68,8 +51,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
"error_was_raised": False,
}

origin = dict(headers).get(b"origin", b"").decode("utf-8")

for handler in self.before_request_handlers:
await handler(request)

Expand All @@ -89,8 +70,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if request["error_was_raised"] is False:
await send_http_error(asgi.send, HTTPStatus.NOT_FOUND)
else:
if origin:
self.add_cors_headers(response, origin)
await asgi.send(
Http.Response.Start(
status=response["status"],
Expand Down
65 changes: 62 additions & 3 deletions src/zara/types/asgi.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Awaitable, Callable, TypeAlias
from typing import Any, Awaitable, Callable, Dict, TypeAlias

Scope: TypeAlias = dict
Receive: TypeAlias = Callable[[], Awaitable[bytes]]
Expand All @@ -8,8 +8,67 @@
CallableAwaitable = Callable[..., Awaitable[Any]]


def make_csp(
script="https://trustedscripts.example.com",
style="https://trustedstyles.example.com",
image="https://trustedimages.example.com",
):
return (
"default-src 'self'; "
f"script-src 'self' {script}; "
f"style-src 'self' {style}; "
f"img-src 'self' data: {image}; "
"frame-ancestors 'self'; "
"form-action 'self'; "
"block-all-mixed-content; "
"upgrade-insecure-requests"
).encode("utf-8")


class ASGI:
def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
def __init__(self, scope: Scope, receive: Receive, send: Send, config):
self.scope = scope
self.receive = receive
self.send = send
self.original_send = send
self.config = config

async def send(self, message: Dict[str, Any]) -> None:
if message["type"] == "http.response.body":
return await self.original_send(message)

headers = dict(self.scope.get("headers", []))
if message["type"] == "http.response.start" and hasattr(self.config, "csp"):
message["headers"].append(
(
b"content-security-policy",
make_csp(
script=self.config.csp.script,
style=self.config.csp.style,
image=self.config.csp.image,
),
)
)
if (origin := headers.get(b"origin", b"").decode("utf-8")) and hasattr(
self.config, "cors"
):
cors = self.config.cors
if origin in cors.allowed_origins:
message["headers"].extend(
[
(b"access-control-allow-origin", origin.encode("utf-8")),
(
b"access-control-allow-methods",
cors.allowed_methods.encode("utf-8"),
),
(
b"access-control-allow-headers",
cors.allowed_headers.encode("utf-8"),
),
]
)
if cors.allow_credentials.lower() == "true":
message["headers"].append(
(b"access-control-allow-credentials", b"true")
)

await self.original_send(message)
4 changes: 0 additions & 4 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,3 @@ def test_singleton_behavior(self):

self.assertIs(config1, config2)
self.assertEqual(config1.server.port, config2.server.port)

def test_no_such_file(self):
with self.assertRaises(FileNotFoundError):
Config("non_existent_file.ini")
4 changes: 2 additions & 2 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,8 @@ async def async_ratelimited(request):
}


def make_test_app():
app = SimpleASGIApp()
def make_test_app(**kwargs):
app = SimpleASGIApp(**kwargs)
app.rate_limit = (3, 5)
router = Router()
router.rate_limit = (2, 5)
Expand Down
48 changes: 48 additions & 0 deletions tests/security/test_content_security_policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import unittest
from unittest.mock import AsyncMock

from ..helpers import assert_status_code_with_response_body, make_scope, make_test_app


class TestCSP(unittest.IsolatedAsyncioTestCase):
def setUp(self):
self.config = {
"server": {"port": "8080", "host": "127.0.0.1"},
"csp": {
"script": "https://trustedscripts.example.com",
"style": "https://trustedstyles.example.com",
"image": "https://trustedimages.example.com",
},
}
self.app = make_test_app(config=self.config)

async def test_csp_headers(self):
send_mock = AsyncMock()

async def run_app_with_scope(scope):
await self.app(scope, None, send_mock)

scope = make_scope()

await run_app_with_scope(scope)
assert_status_code_with_response_body(
send_mock,
200,
b"Hello, World!",
headers=[
(b"content-type", b"text/plain"),
(
b"content-security-policy",
(
"default-src 'self'; "
"script-src 'self' https://trustedscripts.example.com; "
"style-src 'self' https://trustedstyles.example.com; "
"img-src 'self' data: https://trustedimages.example.com; "
"frame-ancestors 'self'; "
"form-action 'self'; "
"block-all-mixed-content; "
"upgrade-insecure-requests"
).encode("utf-8"),
),
],
)
28 changes: 11 additions & 17 deletions tests/config/test_cors.py → tests/security/test_cors.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,21 @@
import unittest
from unittest.mock import AsyncMock, mock_open, patch

from zara.config.config import Config
from unittest.mock import AsyncMock

from ..helpers import assert_status_code_with_response_body, make_scope, make_test_app


class TestCORS(unittest.IsolatedAsyncioTestCase):
def setUp(self):
config_content = """
[server]
port = 8080
host = 127.0.0.1
[cors]
allowed_origins = https://example.com, https://another.com
allowed_methods = GET, POST, OPTIONS
allowed_headers = Content-Type, Authorization
allow_credentials = true
"""
with patch("builtins.open", mock_open(read_data=config_content)):
self.config = Config("config.ini")
self.app = make_test_app()
self.config = {
"server": {"port": "8080", "host": "127.0.0.1"},
"cors": {
"allowed_origins": "https://example.com, https://another.com",
"allowed_methods": "GET, POST, OPTIONS",
"allowed_headers": "Content-Type, Authorization",
"allow_credentials": "true",
},
}
self.app = make_test_app(config=self.config)

async def test_cors_headers(self):
send_mock = AsyncMock()
Expand Down

0 comments on commit 1e08d87

Please sign in to comment.