Skip to content

Commit

Permalink
Add HTTP compression middleware
Browse files Browse the repository at this point in the history
  • Loading branch information
ok300 committed Nov 20, 2024
1 parent 901b167 commit f36b3d8
Show file tree
Hide file tree
Showing 3 changed files with 297 additions and 3 deletions.
48 changes: 47 additions & 1 deletion cashu/mint/middleware.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fastapi import FastAPI
from fastapi import FastAPI, Request, Response
from fastapi.exception_handlers import (
request_validation_exception_handler as _request_validation_exception_handler,
)
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from loguru import logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette.requests import Request

Expand All @@ -17,6 +18,10 @@
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware

import brotli
import gzip
import zstandard as zstd
import zlib

def add_middlewares(app: FastAPI):
app.add_middleware(
Expand All @@ -26,6 +31,7 @@ def add_middlewares(app: FastAPI):
allow_headers=["*"],
expose_headers=["*"],
)
app.add_middleware(CompressionMiddleware)

if settings.debug_profiling:
assert PyInstrumentProfilerMiddleware is not None
Expand Down Expand Up @@ -53,3 +59,43 @@ async def request_validation_exception_handler(
logger.error(detail)
# pass on
return await _request_validation_exception_handler(request, exc)


class CompressionMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
response = await call_next(request)

# Handle streaming responses differently
if response.__class__.__name__ == 'StreamingResponse':
return response

response_body = b''
async for chunk in response.body_iterator:
response_body += chunk

accept_encoding = request.headers.get("Accept-Encoding", "")
content = response_body

if "br" in accept_encoding:
content = brotli.compress(content)
response.headers["Content-Encoding"] = "br"
elif "zstd" in accept_encoding:
compressor = zstd.ZstdCompressor()
content = compressor.compress(content)
response.headers["Content-Encoding"] = "zstd"
elif "gzip" in accept_encoding:
content = gzip.compress(content)
response.headers["Content-Encoding"] = "gzip"
elif "deflate" in accept_encoding:
content = zlib.compress(content)
response.headers["Content-Encoding"] = "deflate"

response.headers["Content-Length"] = str(len(content))
response.headers["Vary"] = "Accept-Encoding"

return Response(
content=content,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type
)
Loading

0 comments on commit f36b3d8

Please sign in to comment.