Skip to content

Commit

Permalink
udpate
Browse files Browse the repository at this point in the history
  • Loading branch information
zcemycl committed Oct 28, 2023
1 parent cf9d61b commit 0a24662
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 127 deletions.
35 changes: 18 additions & 17 deletions app/database.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
import os
from typing import Iterator

from loguru import logger
from sqlalchemy.ext.asyncio import (
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.ext.asyncio import create_async_engine

db_async_url = (
os.environ["DB_ASYNC_URL"]
Expand All @@ -18,20 +13,26 @@
async_engine = create_async_engine(db_async_url)


async def get_async_engine():
async with async_engine.begin() as conn:
yield conn
async_engine.sync_engine.dispose()


# def make_async_engine() -> AsyncEngine:
# global async_engine
# if async_engine is None:
# async_engine = create_async_engine(db_async_url)


async def get_async_session() -> Iterator[AsyncSession]:
session = async_sessionmaker(
async_engine,
autocommit=False,
autoflush=False,
# bind=create_async_engine(db_async_url),
class_=AsyncSession,
expire_on_commit=False,
)
async with session() as sess:
yield sess
# async def get_async_session() -> Iterator[AsyncSession]:
# session = async_sessionmaker(
# async_engine,
# autocommit=False,
# autoflush=False,
# # bind=create_async_engine(db_async_url),
# class_=AsyncSession,
# expire_on_commit=False,
# )
# async with session() as sess:
# yield sess
16 changes: 7 additions & 9 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from fastapi import Depends, FastAPI
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.sql import select

# import example_package.dataclasses.orm as d
from example_package.dataclasses import person

from .database import get_async_session
from .database import get_async_engine

app = FastAPI()

Expand All @@ -29,10 +26,11 @@ async def read_root():

@app.get("/async/persons")
async def get_async_persons(
session: AsyncSession = Depends(get_async_session),
conn=Depends(get_async_engine),
):
stmt = select(person)
res = await session.execute(stmt)
res = res.all()
logger.info(res)
return res
res = await conn.execute(stmt)

return [
dict(zip(person.columns.keys(), list(tmp))) for tmp in res.fetchall()
]
77 changes: 35 additions & 42 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,11 @@

import pytest_asyncio
from httpx import AsyncClient
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)

import example_package.dataclasses.orm as d
from app.database import get_async_session
from app.main import app
from sqlalchemy.ext.asyncio import create_async_engine

# from example_package.dataclasses import metadata
from app.database import get_async_engine
from app.main import app
from example_package.dataclasses import metadata


@pytest.fixture(scope="session")
Expand All @@ -25,50 +18,50 @@ def event_loop() -> asyncio.AbstractEventLoop:
loop.close()


# @pytest_asyncio.fixture(autouse=True)
# async def get_engine():
# db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
# engine = create_async_engine(db_url)
@pytest_asyncio.fixture(autouse=True)
async def get_engine():
db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
engine = create_async_engine(db_url)

# async with engine.begin() as conn:
# await conn.run_sync(metadata.drop_all)
# await conn.run_sync(metadata.create_all)
async with engine.begin() as conn:
await conn.run_sync(metadata.drop_all)
await conn.run_sync(metadata.create_all)

# yield conn
# engine.sync_engine.dispose()
yield conn
engine.sync_engine.dispose()


@pytest.fixture(autouse=True)
def get_engine_orm() -> AsyncEngine:
db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
engine = create_async_engine(db_url)
# @pytest.fixture(autouse=True)
# def get_engine_orm() -> AsyncEngine:
# db_url = "postgresql+asyncpg://postgres:postgres@localhost/postgres"
# engine = create_async_engine(db_url)

yield engine
engine.sync_engine.dispose()
# yield engine
# engine.sync_engine.dispose()


@pytest_asyncio.fixture(autouse=True)
async def test_session_orm(get_engine_orm: AsyncEngine) -> AsyncSession:
async with get_engine_orm.begin() as conn:
await conn.run_sync(d.Base.metadata.drop_all)
await conn.run_sync(d.Base.metadata.create_all)
_local_async_session = async_sessionmaker(
expire_on_commit=False, class_=AsyncSession, bind=get_engine_orm
)
async with _local_async_session(bind=conn) as sess:
yield sess
await sess.flush()
await sess.rollback()
# @pytest_asyncio.fixture(autouse=True)
# async def test_session_orm(get_engine_orm: AsyncEngine) -> AsyncSession:
# async with get_engine_orm.begin() as conn:
# await conn.run_sync(d.Base.metadata.drop_all)
# await conn.run_sync(d.Base.metadata.create_all)
# _local_async_session = async_sessionmaker(
# expire_on_commit=False, class_=AsyncSession, bind=get_engine_orm
# )
# async with _local_async_session(bind=conn) as sess:
# yield sess
# await sess.flush()
# await sess.rollback()


@pytest_asyncio.fixture(autouse=True)
async def test_async_client(test_session_orm: AsyncSession) -> AsyncClient:
def _local_session():
async def test_async_client(get_engine) -> AsyncClient:
def _local_engine():
try:
yield test_session_orm
yield get_engine
finally:
pass

app.dependency_overrides[get_async_session] = _local_session
app.dependency_overrides[get_async_engine] = _local_engine
async with AsyncClient(app=app, base_url="http://test") as ac:
yield ac
7 changes: 7 additions & 0 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@ async def test_root(test_async_client: AsyncClient):
resp = await test_async_client.get("/")
assert resp.status_code == 200
assert resp.json() == {"Hello": "World"}


@pytest.mark.asyncio
async def test_async_persons(test_async_client: AsyncClient):
resp = await test_async_client.get("/async/persons")
assert resp.status_code == 200
assert len(resp.json()) == 0
116 changes: 57 additions & 59 deletions tests/test_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,74 +3,72 @@

import pytest

from sqlalchemy.orm import selectinload
from sqlalchemy.sql import select
from sqlalchemy.sql import delete, insert

import example_package.dataclasses.orm as d
from example_package.dataclasses import friendship, person

# from example_package.dataclasses import friendship, person

# @pytest.mark.asyncio
# async def test_db_core_link(get_engine):
# conn = get_engine
@pytest.mark.asyncio
async def test_db_core_link(get_engine):
conn = get_engine

# with open(Path("tests/test_data/base-persons.json"), "r") as f:
# jsons = json.load(f)
with open(Path("tests/test_data/base-persons.json"), "r") as f:
jsons = json.load(f)

# _ = await conn.execute(insert(person).values(jsons))
# _ = await conn.execute(
# insert(friendship).values(
# [{"parent_person_id": 1, "child_person_id": 2}]
# )
# )
# d = delete(person).where(person.c.id == 1)
# await conn.execute(d)
_ = await conn.execute(insert(person).values(jsons))
_ = await conn.execute(
insert(friendship).values(
[{"parent_person_id": 1, "child_person_id": 2}]
)
)
d = delete(person).where(person.c.id == 1)
await conn.execute(d)


@pytest.mark.asyncio
async def test_orm_link(test_session_orm):
ps = []
with open(Path("tests/test_data/base-persons.json"), "r") as f:
jsons = json.load(f)
# @pytest.mark.asyncio
# async def test_orm_link(test_session_orm):
# ps = []
# with open(Path("tests/test_data/base-persons.json"), "r") as f:
# jsons = json.load(f)

for tmpjson in jsons:
tmpp = d.person(**tmpjson)
ps.append(tmpp)
ps[1].parent_friendships = [ps[0]]
s1 = d.skill(name="python", persons=[ps[0]])
test_session_orm.add_all(ps + [s1])
await test_session_orm.commit()
# for tmpjson in jsons:
# tmpp = d.person(**tmpjson)
# ps.append(tmpp)
# ps[1].parent_friendships = [ps[0]]
# s1 = d.skill(name="python", persons=[ps[0]])
# test_session_orm.add_all(ps + [s1])
# await test_session_orm.commit()

p1_promise = await test_session_orm.execute(
select(d.person)
.where(d.person.id == ps[0].id)
.options(selectinload(d.person.parent_friendships))
.options(selectinload(d.person.child_friendships))
)
p1_ = p1_promise.scalar()
await test_session_orm.refresh(
ps[1], attribute_names=["child_friendships"]
)
# p1_promise = await test_session_orm.execute(
# select(d.person)
# .where(d.person.id == ps[0].id)
# .options(selectinload(d.person.parent_friendships))
# .options(selectinload(d.person.child_friendships))
# )
# p1_ = p1_promise.scalar()
# await test_session_orm.refresh(
# ps[1], attribute_names=["child_friendships"]
# )

assert s1.id == p1_.skills[0].id
assert p1_.parent_friendships == []
assert p1_.child_friendships[0].id == ps[1].id
assert ps[0].child_friendships[0].id == ps[1].id
assert ps[1].child_friendships == []
assert ps[1].parent_friendships == [ps[0]]
# assert s1.id == p1_.skills[0].id
# assert p1_.parent_friendships == []
# assert p1_.child_friendships[0].id == ps[1].id
# assert ps[0].child_friendships[0].id == ps[1].id
# assert ps[1].child_friendships == []
# assert ps[1].parent_friendships == [ps[0]]

s1_ = (
await test_session_orm.execute(
select(d.skill).where(d.skill.id == s1.id)
)
).scalar()
await test_session_orm.delete(s1_)
await test_session_orm.commit()
# s1_ = (
# await test_session_orm.execute(
# select(d.skill).where(d.skill.id == s1.id)
# )
# ).scalar()
# await test_session_orm.delete(s1_)
# await test_session_orm.commit()

p2_ = (
await test_session_orm.execute(
select(d.person).where(d.person.id == ps[1].id)
)
).scalar()
await test_session_orm.delete(p2_)
await test_session_orm.commit()
# p2_ = (
# await test_session_orm.execute(
# select(d.person).where(d.person.id == ps[1].id)
# )
# ).scalar()
# await test_session_orm.delete(p2_)
# await test_session_orm.commit()

0 comments on commit 0a24662

Please sign in to comment.