Skip to content

Commit

Permalink
Add HTTP compression middleware (#676)
Browse files Browse the repository at this point in the history
* Add HTTP compression middleware

* Apply fixes from `make format`
  • Loading branch information
ok300 authored Nov 25, 2024
1 parent 2b233fd commit ee90d84
Show file tree
Hide file tree
Showing 3 changed files with 298 additions and 4 deletions.
50 changes: 48 additions & 2 deletions cashu/mint/middleware.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from fastapi import FastAPI
import gzip
import zlib

import brotli
import zstandard as zstd
from fastapi import FastAPI, Request, Response
from fastapi.exception_handlers import (
request_validation_exception_handler as _request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request

from ..core.settings import settings
from .limit import _rate_limit_exceeded_handler, limiter_global
Expand All @@ -26,6 +31,7 @@ def add_middlewares(app: FastAPI):
allow_headers=["*"],
expose_headers=["*"],
)
app.add_middleware(CompressionMiddleware)

if settings.debug_profiling:
assert PyInstrumentProfilerMiddleware is not None
Expand Down Expand Up @@ -53,3 +59,43 @@ async def request_validation_exception_handler(
logger.error(detail)
# pass on
return await _request_validation_exception_handler(request, exc)


class CompressionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)

# Handle streaming responses differently
if response.__class__.__name__ == 'StreamingResponse':
return response

response_body = b''
async for chunk in response.body_iterator:
response_body += chunk

accept_encoding = request.headers.get("Accept-Encoding", "")
content = response_body

if "br" in accept_encoding:
content = brotli.compress(content)
response.headers["Content-Encoding"] = "br"
elif "zstd" in accept_encoding:
compressor = zstd.ZstdCompressor()
content = compressor.compress(content)
response.headers["Content-Encoding"] = "zstd"
elif "gzip" in accept_encoding:
content = gzip.compress(content)
response.headers["Content-Encoding"] = "gzip"
elif "deflate" in accept_encoding:
content = zlib.compress(content)
response.headers["Content-Encoding"] = "deflate"

response.headers["Content-Length"] = str(len(content))
response.headers["Vary"] = "Accept-Encoding"

return Response(
content=content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
Loading

0 comments on commit ee90d84

Please sign in to comment.