Skip to content

Commit

Permalink
upgrade starlette again (#2560)
Browse files Browse the repository at this point in the history
  • Loading branch information
jotare authored Oct 23, 2024
1 parent 5eebea5 commit f16f3d7
Show file tree
Hide file tree
Showing 5 changed files with 1,483 additions and 1,297 deletions.
1 change: 0 additions & 1 deletion nucliadb/tests/search/integration/api/v1/test_ask_audit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ async def test_ask_sends_only_one_audit(
resp = await client.post(
f"/{KB_PREFIX}/{kbid}/ask",
json={"query": "title"},
headers={"X-Synchronous": "True"},
)
assert resp.status_code == 200

Expand Down
19 changes: 10 additions & 9 deletions nucliadb_telemetry/src/nucliadb_telemetry/fastapi/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,27 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response

from nucliadb_telemetry import context

from .utils import get_path_template


class ContextInjectorMiddleware(BaseHTTPMiddleware):
class ContextInjectorMiddleware:
"""
Automatically inject context values for the current request's path parameters
For example:
- `/api/v1/kb/{kbid}` would inject a context value for `kbid`
"""

async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
found_path_template = get_path_template(request.scope)
if found_path_template.match:
context.add_context(found_path_template.scope.get("path_params", {})) # type: ignore
def __init__(self, app):
self.app = app

return await call_next(request)
async def __call__(self, scope, receive, send):
if scope["type"] == "http":
found_path_template = get_path_template(scope)
if found_path_template.match:
context.add_context(found_path_template.scope.get("path_params", {})) # type: ignore

return await self.app(scope, receive, send)
29 changes: 9 additions & 20 deletions nucliadb_telemetry/tests/unit/fastapi/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from unittest.mock import AsyncMock
from unittest.mock import patch

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient

from nucliadb_telemetry import context
from nucliadb_telemetry.fastapi.context import ContextInjectorMiddleware

app = FastAPI()
Expand All @@ -34,23 +34,12 @@ def get_kb(kbid: str):

@pytest.mark.asyncio
async def test_context_injected():
scope = {
"app": app,
"path": "/api/v1/kb/123",
"method": "GET",
"type": "http",
}
app.add_middleware(ContextInjectorMiddleware)

mdlw = ContextInjectorMiddleware(app)
transport = ASGITransport(app=app) # type: ignore
client = AsyncClient(transport=transport, base_url="http://test/api/v1")

found_ctx = {}

async def receive(*args, **kwargs):
found_ctx.update(context.get_context())
return {
"type": "http.disconnect",
}

await mdlw(scope, receive, AsyncMock())

assert found_ctx == {"kbid": "123"}
with patch("nucliadb_telemetry.fastapi.context.context.add_context") as add_context:
await client.get("/kb/123")
assert add_context.call_count == 1
assert add_context.call_args[0][0] == {"kbid": "123"}
38 changes: 10 additions & 28 deletions nucliadb_utils/src/nucliadb_utils/audit/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,9 @@
from fastapi import Request
from google.protobuf.timestamp_pb2 import Timestamp
from opentelemetry.trace import format_trace_id, get_current_span
from starlette.background import BackgroundTask
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from starlette.responses import Response, StreamingResponse
from starlette.responses import Response
from starlette.types import ASGIApp

from nucliadb_protos.audit_pb2 import AuditField, AuditRequest, ChatContext, ClientType, RetrievedContext
Expand Down Expand Up @@ -98,11 +99,15 @@ async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -

response = await call_next(request)

if isinstance(response, StreamingResponse):
response = self.wrap_streaming_response(response, context)
else:
self.enqueue_pending(context)
# This task will run when the response finishes streaming
# When dealing with streaming responses, AND if we depend on any state that only will be available once
# the request is fully finished, the response we have after the dispatch call_next is not enough, as
# there, no iteration of the streaming response has been done yet.
response.background = BackgroundTask(self.enqueue_pending, context)

# It is safe to reset the context here since the asyncio task for generating the streaming response is
# already running. If we want to spawn a different task during streaming and we want that task be able
# to read the context_var, we need to manually pass the context into that task.
request_context_var.reset(token)

return response
Expand All @@ -116,29 +121,6 @@ def enqueue_pending(self, context: RequestContext):
if self.audit_utility is not None:
self.audit_utility.send(context.audit_request)

def wrap_streaming_response(
self, response: StreamingResponse, context: RequestContext
) -> StreamingResponse:
"""
When dealing with streaming responses, AND if we depend on any state that only will be available once
the request is fully finished, the response we have after the dispatch call_next is not enough, as
there, no iteration of the streaming response has been done yet.
This is why we need to rewrap to be able to to the auditing at the _real_ request end without losing
any audit bits.
"""
original_body_iterator = response.body_iterator

async def custom_body_iterator():
try:
async for chunk in original_body_iterator:
yield chunk
finally:
self.enqueue_pending(context)

response.body_iterator = custom_body_iterator()
return response


KB_USAGE_STREAM_SUBJECT = "kb-usage.nuclia_db"

Expand Down
Loading

0 comments on commit f16f3d7

Please sign in to comment.