Skip to content

Commit

Permalink
Merge pull request #84 from glyph/arity
Browse files Browse the repository at this point in the history
much better error reporting on arity mismatches in loaders
  • Loading branch information
glyph authored Mar 28, 2024
2 parents 2152e7b + 50f16a1 commit 77c3bf1
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 19 deletions.
137 changes: 119 additions & 18 deletions src/dbxs/_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,15 @@
from __future__ import annotations

from dataclasses import dataclass, field
from inspect import BoundArguments, isawaitable, signature
from inspect import (
BoundArguments,
currentframe,
getsourcefile,
getsourcelines,
isawaitable,
signature,
)
from types import FrameType, TracebackType
from typing import (
Any,
AsyncIterable,
Expand All @@ -12,6 +20,7 @@
Dict,
Iterable,
List,
NoReturn,
Optional,
Sequence,
Tuple,
Expand Down Expand Up @@ -61,22 +70,114 @@ class ExtraneousMethods(Exception):
"""


def one(
load: Callable[..., T],
) -> Callable[[object, AsyncCursor], Coroutine[object, object, T]]:
class WrongRowShape(TypeError):
"""
Fetch a single result with a translator function.
The row was the wrong shape for the given callable.
"""

async def translateOne(db: object, cursor: AsyncCursor) -> T:

@dataclass
class _ExceptionFixer:
loader: Callable[..., object]
definitionLine: int
decorationLine: int
decorationFrame: FrameType
definitionFrame: FrameType

def reraise(self, row: object, e: Exception) -> NoReturn:
withDecorationAdded = TracebackType(
None, self.decorationFrame, 0, self.decorationLine
)
withDefinitionAdded = TracebackType(
withDecorationAdded, self.definitionFrame, 0, self.definitionLine
)
raise WrongRowShape(
f"loader {self.loader.__module__}.{self.loader.__name__}"
f" could not handle {row}"
).with_traceback(withDefinitionAdded) from e

@classmethod
def create(cls, loader: Callable[..., T]) -> _ExceptionFixer:
subFrame = currentframe()
assert subFrame is not None
frameworkFrame = subFrame.f_back # the caller; 'one' or 'many'
assert frameworkFrame is not None
realDecorationFrame = frameworkFrame.f_back
assert realDecorationFrame is not None
wholeSource, definitionLine = getsourcelines(loader)

# coverage is tricked by the __code__ modifications below, so we have
# to explicitly ignore the gap

def decoratedHere() -> FrameType | None:
return currentframe() # pragma: no cover

def definedHere() -> FrameType | None:
return currentframe() # pragma: no cover

decoratedHere.__code__ = decoratedHere.__code__.replace(
co_name="<<decorated here>>",
co_filename=realDecorationFrame.f_code.co_filename,
co_firstlineno=realDecorationFrame.f_lineno,
)

definedSourceFile = getsourcefile(loader)
definedHere.__code__ = definedHere.__code__.replace(
co_name="<<defined here>>",
co_filename=definedSourceFile or "unknown definition",
co_firstlineno=definitionLine,
)

fakeDecorationFrame = decoratedHere()
definitionFrame = definedHere()
assert realDecorationFrame is not None
assert definitionFrame is not None
assert fakeDecorationFrame is not None

return cls(
loader=loader,
definitionFrame=definitionFrame,
definitionLine=definitionLine,
decorationFrame=fakeDecorationFrame,
decorationLine=realDecorationFrame.f_lineno,
)


_NR = TypeVar("_NR")


def _makeTranslator(
fixer: _ExceptionFixer,
load: Callable[..., T],
noResults: Callable[[], _NR],
) -> Callable[[object, AsyncCursor], Coroutine[object, object, T | _NR]]:
async def translator(db: object, cursor: AsyncCursor) -> T | _NR:
rows = await cursor.fetchall()
if len(rows) < 1:
raise NotEnoughResults()
return noResults()
if len(rows) > 1:
raise TooManyResults()
return load(db, *rows[0])
[row] = rows
try:
return load(db, *row)
except TypeError as e:
fixer.reraise(row, e)

return translateOne
return translator


def one(
load: Callable[..., T],
) -> Callable[[object, AsyncCursor], Coroutine[object, object, T]]:
"""
Fetch a single result with a translator function.
"""
fixer = _ExceptionFixer.create(load)

def noResults() -> NoReturn:
raise NotEnoughResults()

return _makeTranslator(fixer, load, noResults)


def maybe(
Expand All @@ -86,16 +187,12 @@ def maybe(
Fetch a single result and pass it to a translator function, but return None
if it's not found.
"""
fixer = _ExceptionFixer.create(load)

async def translateMaybe(db: object, cursor: AsyncCursor) -> Optional[T]:
rows = await cursor.fetchall()
if len(rows) < 1:
return None
if len(rows) > 1:
raise TooManyResults()
return load(db, *rows[0])
def noResults() -> None:
return None

return translateMaybe
return _makeTranslator(fixer, load, noResults)


def many(
Expand All @@ -104,6 +201,7 @@ def many(
"""
Fetch multiple results with a function to translate rows.
"""
fixer = _ExceptionFixer.create(load)

async def translateMany(
db: object, cursor: AsyncCursor
Expand All @@ -112,7 +210,10 @@ async def translateMany(
row = await cursor.fetchone()
if row is None:
return
yield load(db, *row)
try:
yield load(db, *row)
except TypeError as e:
fixer.reraise(row, e)

return translateMany

Expand Down
78 changes: 77 additions & 1 deletion src/dbxs/test/test_access.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from __future__ import annotations

import traceback
from dataclasses import dataclass
from typing import Optional
from typing import AsyncIterable, Optional
from unittest import TestCase

from .. import (
Expand All @@ -10,6 +11,7 @@
ParamMismatch,
TooManyResults,
accessor,
many,
maybe,
one,
query,
Expand All @@ -30,11 +32,39 @@ class Foo:
baz: int


def oops( # point at this definition(one)
db: FooAccessPattern,
bar: int,
baz: int,
extra: str,
) -> str:
return extra # pragma: no cover


# duplicate definition comment on different lines below because
# inspect.getsourcelines changed behavior from 3.8 to 3.9


@dataclass # point at this definition(many)
class Oops2: # point at this definition(many)
db: FooAccessPattern
bar: int
baz: int
extra: str


class FooAccessPattern(Protocol):
@query(sql="select bar, baz from foo where bar = {bar}", load=one(Foo))
async def getFoo(self, bar: int) -> Foo:
...

@query(
sql="select bar, baz from foo order by bar asc",
load=many(Foo),
)
def allFoos(self) -> AsyncIterable[Foo]:
...

@query(sql="select bar, baz from foo where bar = {bar}", load=maybe(Foo))
async def maybeFoo(self, bar: int) -> Optional[Foo]:
...
Expand All @@ -43,6 +73,20 @@ async def maybeFoo(self, bar: int) -> Optional[Foo]:
async def oneFooByBaz(self, baz: int) -> Foo:
...

@query(
sql="select bar, baz from foo where baz = {baz}",
load=one(oops), # point at this decoration(one)
)
async def wrongArityOne(self, baz: int) -> str:
...

@query(
sql="select bar, baz from foo",
load=many(Oops2), # point at this decoration(many)
)
def wrongArityMany(self) -> AsyncIterable[Oops2]:
...

@query(sql="select bar, baz from foo where baz = {baz}", load=maybe(Foo))
async def maybeFooByBaz(self, baz: int) -> Optional[Foo]:
...
Expand Down Expand Up @@ -102,8 +146,40 @@ async def test_happyPath(self, pool: MemoryPool) -> None:
db = accessFoo(c)
result = await db.getFoo(1)
result2 = await db.maybeFoo(1)
result3 = [ # pragma: no branch
each async for each in db.allFoos()
]
self.assertEqual(result, Foo(db, 1, 3))
self.assertEqual(result, result2)
self.assertEqual(result3, [Foo(db, 1, 3), Foo(db, 2, 4)])

@immediateTest()
async def test_wrongResultArity(self, pool: MemoryPool) -> None:
"""
If the signature of the callable provided to C{query(load=one(...))} or
C{query(load=many(...))} does not match with the number of arguments
returned by the database for a row in a particular query, the error
will explain well enough to debug.
"""
async with transaction(pool.connectable) as c:
await schemaAndData(c)
db = accessFoo(c)
try:
await db.wrongArityOne(3)
except TypeError:
tbf1 = traceback.format_exc()
try:
[ # pragma: no branch
each async for each in db.wrongArityMany()
]
except TypeError:
tbf2 = traceback.format_exc()
# print(tbf1)
# print(tbf2)
self.assertIn("point at this definition(one)", tbf1)
self.assertIn("point at this decoration(one)", tbf1)
self.assertIn("point at this definition(many)", tbf2)
self.assertIn("point at this decoration(many)", tbf2)

def test_argumentExhaustiveness(self) -> None:
"""
Expand Down

0 comments on commit 77c3bf1

Please sign in to comment.