Skip to content

Commit

Permalink
add error handler for starlette middleware to enable error logs in Op…
Browse files Browse the repository at this point in the history
…erations, fix proxy_base_url logic
  • Loading branch information
voidZXL committed Dec 13, 2024
1 parent 7a5712c commit 995d0c5
Show file tree
Hide file tree
Showing 7 changed files with 142 additions and 28 deletions.
14 changes: 13 additions & 1 deletion utilmeta/core/orm/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ def __init__(self, model: "ModelAdaptor" = None, **kwargs):
def reconstruct(self, model: "ModelAdaptor"):
return self.__class__(model, **self._kwargs)

def check_schema_cls(self, schema_cls):
if isinstance(schema_cls, type):
if self.model.qualify(schema_cls):
raise TypeError(f'You are using a model class: {schema_cls} to used as schema query class, '
f'which is invalid, you should make a schema class using '
f'orm.Schema[{schema_cls.__name__}]')

def get_query_schema(self):
parser = None
schema = None
Expand All @@ -69,6 +76,8 @@ def get_query_schema(self):
parser = cls_parser
schema = arg
break
else:
self.check_schema_cls(arg)
else:
if (
self.type.__origin__
Expand All @@ -83,6 +92,8 @@ def get_query_schema(self):
if cls_parser:
parser = cls_parser
schema = arg
else:
self.check_schema_cls(arg)

else:
self.related_single = True
Expand All @@ -93,7 +104,8 @@ def get_query_schema(self):
parser = cls_parser
schema = origin
break

else:
self.check_schema_cls(origin)
if parser:
if isinstance(parser, SchemaClassParser):
if parser.model:
Expand Down
2 changes: 2 additions & 0 deletions utilmeta/core/response/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ class ResponseAdaptor(BaseAdaptor):
json_decoder_cls = json.JSONDecoder

def __init__(self, response):
if not self.qualify(response):
raise TypeError(f'Invalid response: {response}')
self.response = response
# self.request = request
self._context = {}
Expand Down
5 changes: 4 additions & 1 deletion utilmeta/core/server/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
if TYPE_CHECKING:
from utilmeta import UtilMeta
from utilmeta.core.api import API
from utilmeta.utils import BaseAdaptor, exceptions, import_obj
from utilmeta.utils import BaseAdaptor, exceptions, import_obj, Error
import re
import inspect
from utilmeta.core.request import Request
Expand All @@ -24,6 +24,9 @@ def process_request(self, request: Request):
def process_response(self, response: Response):
pass

# def handle_error(self, error: Error):
# pass


class ServerAdaptor(BaseAdaptor):
# __backends_route__ = 'backends'
Expand Down
96 changes: 90 additions & 6 deletions utilmeta/core/server/backends/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
from starlette.requests import Request as StarletteRequest
from starlette.applications import Starlette
from starlette.concurrency import iterate_in_threadpool
from starlette.responses import Response as StarletteResponse
from starlette.middleware.base import _StreamingResponse
from .base import ServerAdaptor
from utilmeta.core.response import Response
from utilmeta.core.request.backends.starlette import StarletteRequestAdaptor
from utilmeta.core.response.backends.starlette import StarletteResponseAdaptor
from utilmeta.core.api import API
from utilmeta.core.request import Request
from utilmeta.utils import HAS_BODY_METHODS, RequestType, exceptions
from utilmeta.utils import HAS_BODY_METHODS, RequestType, exceptions, pop, Error
import contextvars
from typing import Optional
from urllib.parse import urlparse

_current_request = contextvars.ContextVar("_starlette.request")
# _current_response = contextvars.ContextVar('_starlette.response')
Expand Down Expand Up @@ -100,17 +100,71 @@ def setup_middlewares(self):
self.app.add_middleware(
BaseHTTPMiddleware, dispatch=self.get_middleware_func() # noqa
)
# self.app.add_exception_handler(Exception, handler=self.get_exception_handler)

@classmethod
async def get_response_body(cls, starlette_response: _StreamingResponse) -> bytes:
response_body = [chunk async for chunk in starlette_response.body_iterator]
starlette_response.body_iterator = iterate_in_threadpool(iter(response_body))
return b"".join(response_body)

# async def get_exception_handler(self, req, exc):
# # async def http_error_handler(_: Request, exc) -> JSONResponse:
# # return JSONResponse({"errors": [exc.detail]}, status_code=exc.status_code)
# handlers = dict(self.app.exception_handlers or {})
# pop(handlers, Exception)
# request = _current_request.get(None) or Request(self.request_adaptor_cls(req))
# error = Error(exc, request=request)
# for middleware in self.middlewares:
# middleware.handle_error(error)
# err_handler = None
# for cls in type(exc).__mro__:
# if cls in handlers:
# err_handler = handlers[cls]
# break
# if not err_handler:
# raise exc from exc
# from starlette.concurrency import run_in_threadpool
# try:
# if inspect.iscoroutinefunction(err_handler):
# resp = await err_handler(req, exc)
# else:
# resp = await run_in_threadpool(err_handler, req, exc)
# except Exception as handler_exc:
# handler_error = Error(handler_exc, request=request)
# for middleware in self.middlewares:
# middleware.handle_error(handler_error)
# middleware.process_response(Response(error=handler_error))
# # there is probably not response
# raise
#
# adaptor = self.response_adaptor_cls(resp)
# if (
# adaptor.content_length or 0
# ) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE:
# body = await self.get_response_body(resp)
# resp.body = body
# # set body
# response = Response(response=adaptor, request=request)
# response_updated = False
# for middleware in self.middlewares:
# _response = middleware.process_response(response)
# if inspect.isawaitable(_response):
# _response = await _response
# if isinstance(_response, Response):
# response = _response
# response_updated = True
#
# if not resp or response_updated:
# resp = self.response_adaptor_cls.reconstruct(response)
#
# return resp

def get_middleware_func(self):
async def utilmeta_middleware(starlette_request: StarletteRequest, call_next):
response = None
starlette_response = None
exc = None

request = Request(self.request_adaptor_cls(starlette_request))
for middleware in self.middlewares:
Expand All @@ -135,11 +189,38 @@ async def utilmeta_middleware(starlette_request: StarletteRequest, call_next):
# and you cannot read it after response is generated

_current_request.set(request)
starlette_response: Optional[_StreamingResponse] = await call_next(
starlette_request
)
try:
starlette_response: Optional[_StreamingResponse] = await call_next(
starlette_request
)
except Exception as e:
handlers = dict(self.app.exception_handlers or {}) # noqa
err_handler = None
for cls in type(e).__mro__:
if cls in handlers:
err_handler = handlers[cls]
break
error = Error(e, request=request)
if err_handler:
from starlette.concurrency import run_in_threadpool
if inspect.iscoroutinefunction(err_handler):
starlette_response = await err_handler(starlette_request, e)
else:
starlette_response = await run_in_threadpool(err_handler, starlette_request, e)
adaptor = self.response_adaptor_cls(starlette_response)
if (
adaptor.content_length or 0
) <= self.RECORD_RESPONSE_BODY_LENGTH_LTE:
body = await self.get_response_body(starlette_response)
starlette_response.body = body
response = Response(response=adaptor, error=error)
else:
starlette_response = StarletteResponse(status_code=500) # noqa: placeholder response
response = Response(response=response, error=error, request=request)
exc = e

_current_request.set(None)
response = request.adaptor.get_context("response")
response = response or request.adaptor.get_context("response")
# response = _current_response.get(None)
# _current_response.set(None)

Expand Down Expand Up @@ -170,6 +251,9 @@ async def utilmeta_middleware(starlette_request: StarletteRequest, call_next):
response = _response
response_updated = True

if exc:
raise exc from exc

if not starlette_response or response_updated:
starlette_response = self.response_adaptor_cls.reconstruct(response)

Expand Down
34 changes: 20 additions & 14 deletions utilmeta/ops/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,6 @@ def __init__(
self.local_scope = list(local_scope or [])
self.report_disabled = report_disabled

if base_url:
parsed = urlsplit(base_url)
if not parsed.scheme:
raise ValueError(
f"Operations base_url should be an absolute url, got {base_url}"
)
self._base_url = self.parse_base_url(base_url)

if self.HOST not in self.trusted_hosts:
self.trusted_hosts.append(self.HOST)
if not isinstance(monitor, self.Monitor):
Expand All @@ -239,12 +231,22 @@ def __init__(
)
self.proxy = proxy

@classmethod
def parse_base_url(cls, url: str):
if base_url:
parsed = urlsplit(base_url)
if not parsed.scheme:
raise ValueError(
f"Operations base_url should be an absolute url, got {base_url}"
)
self._base_url = self.parse_base_url(base_url)

def parse_base_url(self, url: str):
if not url:
return url
if "$IP" in url:
url = url.replace("$IP", get_server_ip())
ip = get_server_ip(private_only=bool(self.proxy))
if self.proxy:
ip = ip or get_server_ip() or '127.0.0.1'
url = url.replace("$IP", ip)
return url

def load_openapi(self, no_store: bool = False):
Expand Down Expand Up @@ -688,9 +690,13 @@ def proxy_base_url(self):
except ImportError:
return None
origin = self.proxy_origin
if not service.adaptor.backend_views_empty:
return origin
return url_join(origin, service.root_url)
route = service.root_url
if self._base_url:
parsed = urlsplit(self._base_url)
if parsed.scheme:
# is url
route = parsed.path
return url_join(origin, route)

# def check_host(self):
# parsed = urlsplit(self.ops_api)
Expand Down
6 changes: 6 additions & 0 deletions utilmeta/ops/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,12 @@ def process_response(self, response: Response):
if len(_responses_queue) >= self.config.max_backlog:
threading.Thread(target=batch_save_logs, kwargs=dict(close=True)).start()

# def handle_error(self, error: Error, response=None):
# logger: Logger = _logger.get(None)
# if not logger:
# raise error.throw()
# logger.commit_error(error)


class Logger(Property):
__context__ = ContextProperty(_logger)
Expand Down
13 changes: 7 additions & 6 deletions utilmeta/utils/functional/sys.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import datetime
from typing import Optional, List, Union, Tuple, Dict, Set
from .. import constant
from .data import distinct_add
from ipaddress import ip_address, ip_network

posix_os = os.name == "posix"
Expand Down Expand Up @@ -169,16 +170,16 @@ def get_network_ip(ifname: str):
return None


def get_server_ips(max_devices: int = 3) -> Set[str]:
def get_server_ips(max_devices: int = 3) -> List[str]:
ip = socket.gethostbyname(socket.gethostname())
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

ips = set()
ips = []
for i in range(0, max_devices):
if_ip = get_network_ip(f"eth{i}")
if if_ip:
if not if_ip.startswith("127."):
ips.add(if_ip)
distinct_add(ips, if_ip)
else:
break

Expand All @@ -187,10 +188,10 @@ def get_server_ips(max_devices: int = 3) -> Set[str]:
s.connect(("8.8.8.8", 53))
ip = str(s.getsockname()[0])
if ip:
ips.add(ip)
distinct_add(ips, ip)
s.close()
else:
ips.add(ip)
distinct_add(ips, ip)

return ips

Expand Down Expand Up @@ -226,7 +227,7 @@ def get_server_ip(private_only: bool = False) -> Optional[str]:
except ValueError:
continue

_SERVER_IP = ips.pop() if ips else constant.LOCAL_IP
_SERVER_IP = ips[0] if ips else constant.LOCAL_IP
if private_only:
_SERVER_PRIVATE_IP = constant.LOCAL_IP
return _SERVER_PRIVATE_IP
Expand Down

0 comments on commit 995d0c5

Please sign in to comment.