diff --git a/app/database.py b/app/database.py index acf4f39..8e947b0 100644 --- a/app/database.py +++ b/app/database.py @@ -1,12 +1,7 @@ import os -from typing import Iterator from loguru import logger -from sqlalchemy.ext.asyncio import ( - AsyncSession, - async_sessionmaker, - create_async_engine, -) +from sqlalchemy.ext.asyncio import create_async_engine db_async_url = ( os.environ["DB_ASYNC_URL"] @@ -18,20 +13,26 @@ async_engine = create_async_engine(db_async_url) +async def get_async_engine(): + async with async_engine.begin() as conn: + yield conn + async_engine.sync_engine.dispose() + + # def make_async_engine() -> AsyncEngine: # global async_engine # if async_engine is None: # async_engine = create_async_engine(db_async_url) -async def get_async_session() -> Iterator[AsyncSession]: - session = async_sessionmaker( - async_engine, - autocommit=False, - autoflush=False, - # bind=create_async_engine(db_async_url), - class_=AsyncSession, - expire_on_commit=False, - ) - async with session() as sess: - yield sess +# async def get_async_session() -> Iterator[AsyncSession]: +# session = async_sessionmaker( +# async_engine, +# autocommit=False, +# autoflush=False, +# # bind=create_async_engine(db_async_url), +# class_=AsyncSession, +# expire_on_commit=False, +# ) +# async with session() as sess: +# yield sess diff --git a/app/main.py b/app/main.py index ce7df69..9f29591 100644 --- a/app/main.py +++ b/app/main.py @@ -1,12 +1,9 @@ from fastapi import Depends, FastAPI -from loguru import logger -from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.sql import select -# import example_package.dataclasses.orm as d from example_package.dataclasses import person -from .database import get_async_session +from .database import get_async_engine app = FastAPI() @@ -29,10 +26,11 @@ async def read_root(): @app.get("/async/persons") async def get_async_persons( - session: AsyncSession = Depends(get_async_session), + conn=Depends(get_async_engine), ): stmt = select(person) - res = await session.execute(stmt) - res = res.all() - logger.info(res) - return res + res = await conn.execute(stmt) + + return [ + dict(zip(person.columns.keys(), list(tmp))) for tmp in res.fetchall() + ] diff --git a/tests/conftest.py b/tests/conftest.py index 565dd30..d8a9a60 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,18 +4,11 @@ import pytest_asyncio from httpx import AsyncClient -from sqlalchemy.ext.asyncio import ( - AsyncEngine, - AsyncSession, - async_sessionmaker, - create_async_engine, -) - -import example_package.dataclasses.orm as d -from app.database import get_async_session -from app.main import app +from sqlalchemy.ext.asyncio import create_async_engine -# from example_package.dataclasses import metadata +from app.database import get_async_engine +from app.main import app +from example_package.dataclasses import metadata @pytest.fixture(scope="session") @@ -25,50 +18,50 @@ def event_loop() -> asyncio.AbstractEventLoop: loop.close() -# @pytest_asyncio.fixture(autouse=True) -# async def get_engine(): -# db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres" -# engine = create_async_engine(db_url) +@pytest_asyncio.fixture(autouse=True) +async def get_engine(): + db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres" + engine = create_async_engine(db_url) -# async with engine.begin() as conn: -# await conn.run_sync(metadata.drop_all) -# await conn.run_sync(metadata.create_all) + async with engine.begin() as conn: + await conn.run_sync(metadata.drop_all) + await conn.run_sync(metadata.create_all) -# yield conn -# engine.sync_engine.dispose() + yield conn + engine.sync_engine.dispose() -@pytest.fixture(autouse=True) -def get_engine_orm() -> AsyncEngine: - db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres" - engine = create_async_engine(db_url) +# @pytest.fixture(autouse=True) +# def get_engine_orm() -> AsyncEngine: +# db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres" +# engine = create_async_engine(db_url) - yield engine - engine.sync_engine.dispose() +# yield engine +# engine.sync_engine.dispose() -@pytest_asyncio.fixture(autouse=True) -async def test_session_orm(get_engine_orm: AsyncEngine) -> AsyncSession: - async with get_engine_orm.begin() as conn: - await conn.run_sync(d.Base.metadata.drop_all) - await conn.run_sync(d.Base.metadata.create_all) - _local_async_session = async_sessionmaker( - expire_on_commit=False, class_=AsyncSession, bind=get_engine_orm - ) - async with _local_async_session(bind=conn) as sess: - yield sess - await sess.flush() - await sess.rollback() +# @pytest_asyncio.fixture(autouse=True) +# async def test_session_orm(get_engine_orm: AsyncEngine) -> AsyncSession: +# async with get_engine_orm.begin() as conn: +# await conn.run_sync(d.Base.metadata.drop_all) +# await conn.run_sync(d.Base.metadata.create_all) +# _local_async_session = async_sessionmaker( +# expire_on_commit=False, class_=AsyncSession, bind=get_engine_orm +# ) +# async with _local_async_session(bind=conn) as sess: +# yield sess +# await sess.flush() +# await sess.rollback() @pytest_asyncio.fixture(autouse=True) -async def test_async_client(test_session_orm: AsyncSession) -> AsyncClient: - def _local_session(): +async def test_async_client(get_engine) -> AsyncClient: + def _local_engine(): try: - yield test_session_orm + yield get_engine finally: pass - app.dependency_overrides[get_async_session] = _local_session + app.dependency_overrides[get_async_engine] = _local_engine async with AsyncClient(app=app, base_url="http://test") as ac: yield ac diff --git a/tests/test_app.py b/tests/test_app.py index 5ff7fac..daf3d6f 100644 --- a/tests/test_app.py +++ b/tests/test_app.py @@ -8,3 +8,10 @@ async def test_root(test_async_client: AsyncClient): resp = await test_async_client.get("/") assert resp.status_code == 200 assert resp.json() == {"Hello": "World"} + + +@pytest.mark.asyncio +async def test_async_persons(test_async_client: AsyncClient): + resp = await test_async_client.get("/async/persons") + assert resp.status_code == 200 + assert len(resp.json()) == 0 diff --git a/tests/test_db.py b/tests/test_db.py index af94b53..26623c1 100644 --- a/tests/test_db.py +++ b/tests/test_db.py @@ -3,74 +3,72 @@ import pytest -from sqlalchemy.orm import selectinload -from sqlalchemy.sql import select +from sqlalchemy.sql import delete, insert -import example_package.dataclasses.orm as d +from example_package.dataclasses import friendship, person -# from example_package.dataclasses import friendship, person -# @pytest.mark.asyncio -# async def test_db_core_link(get_engine): -# conn = get_engine +@pytest.mark.asyncio +async def test_db_core_link(get_engine): + conn = get_engine -# with open(Path("tests/test_data/base-persons.json"), "r") as f: -# jsons = json.load(f) + with open(Path("tests/test_data/base-persons.json"), "r") as f: + jsons = json.load(f) -# _ = await conn.execute(insert(person).values(jsons)) -# _ = await conn.execute( -# insert(friendship).values( -# [{"parent_person_id": 1, "child_person_id": 2}] -# ) -# ) -# d = delete(person).where(person.c.id == 1) -# await conn.execute(d) + _ = await conn.execute(insert(person).values(jsons)) + _ = await conn.execute( + insert(friendship).values( + [{"parent_person_id": 1, "child_person_id": 2}] + ) + ) + d = delete(person).where(person.c.id == 1) + await conn.execute(d) -@pytest.mark.asyncio -async def test_orm_link(test_session_orm): - ps = [] - with open(Path("tests/test_data/base-persons.json"), "r") as f: - jsons = json.load(f) +# @pytest.mark.asyncio +# async def test_orm_link(test_session_orm): +# ps = [] +# with open(Path("tests/test_data/base-persons.json"), "r") as f: +# jsons = json.load(f) - for tmpjson in jsons: - tmpp = d.person(**tmpjson) - ps.append(tmpp) - ps[1].parent_friendships = [ps[0]] - s1 = d.skill(name="python", persons=[ps[0]]) - test_session_orm.add_all(ps + [s1]) - await test_session_orm.commit() +# for tmpjson in jsons: +# tmpp = d.person(**tmpjson) +# ps.append(tmpp) +# ps[1].parent_friendships = [ps[0]] +# s1 = d.skill(name="python", persons=[ps[0]]) +# test_session_orm.add_all(ps + [s1]) +# await test_session_orm.commit() - p1_promise = await test_session_orm.execute( - select(d.person) - .where(d.person.id == ps[0].id) - .options(selectinload(d.person.parent_friendships)) - .options(selectinload(d.person.child_friendships)) - ) - p1_ = p1_promise.scalar() - await test_session_orm.refresh( - ps[1], attribute_names=["child_friendships"] - ) +# p1_promise = await test_session_orm.execute( +# select(d.person) +# .where(d.person.id == ps[0].id) +# .options(selectinload(d.person.parent_friendships)) +# .options(selectinload(d.person.child_friendships)) +# ) +# p1_ = p1_promise.scalar() +# await test_session_orm.refresh( +# ps[1], attribute_names=["child_friendships"] +# ) - assert s1.id == p1_.skills[0].id - assert p1_.parent_friendships == [] - assert p1_.child_friendships[0].id == ps[1].id - assert ps[0].child_friendships[0].id == ps[1].id - assert ps[1].child_friendships == [] - assert ps[1].parent_friendships == [ps[0]] +# assert s1.id == p1_.skills[0].id +# assert p1_.parent_friendships == [] +# assert p1_.child_friendships[0].id == ps[1].id +# assert ps[0].child_friendships[0].id == ps[1].id +# assert ps[1].child_friendships == [] +# assert ps[1].parent_friendships == [ps[0]] - s1_ = ( - await test_session_orm.execute( - select(d.skill).where(d.skill.id == s1.id) - ) - ).scalar() - await test_session_orm.delete(s1_) - await test_session_orm.commit() +# s1_ = ( +# await test_session_orm.execute( +# select(d.skill).where(d.skill.id == s1.id) +# ) +# ).scalar() +# await test_session_orm.delete(s1_) +# await test_session_orm.commit() - p2_ = ( - await test_session_orm.execute( - select(d.person).where(d.person.id == ps[1].id) - ) - ).scalar() - await test_session_orm.delete(p2_) - await test_session_orm.commit() +# p2_ = ( +# await test_session_orm.execute( +# select(d.person).where(d.person.id == ps[1].id) +# ) +# ).scalar() +# await test_session_orm.delete(p2_) +# await test_session_orm.commit()