Skip to content

Commit

Permalink
Merge pull request #21 from hexfrost/feat/update-6
Browse files Browse the repository at this point in the history
Fix bug and tests
  • Loading branch information
kaziamov authored Mar 3, 2024
2 parents d3c9a84 + df22d25 commit 06fe87f
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 135 deletions.
2 changes: 1 addition & 1 deletion simplecrud/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ async def update_or_error(obj, params, conn: AsyncSession = None) -> object:

async def update_object_by_id(model, id: int, params, conn: AsyncSession = None) -> object:
"""Update object in db by id"""
obj = await get_object(model, dict(id=id))
obj = await get_object(model, dict(id=id), conn=conn)
updated_obj = await update_object(obj, params, conn=conn)
return updated_obj

Expand Down
6 changes: 1 addition & 5 deletions simplecrud/tests/factories.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker, AsyncSession

from simplecrud.settings import CRUDConfig


class AsyncConnFactory():

def __init__(self):
self.async_engine = create_async_engine("sqlite+aiosqlite:///./test.db")
self.async_session_maker = async_sessionmaker(self.async_engine, expire_on_commit=False, class_=AsyncSession)
self.config = CRUDConfig()
self.config.set_sessionmaker(self.async_session_maker)

def __call__(self, *args, **kwargs):
return self.config.sessionmaker
return self.async_session_maker()
149 changes: 76 additions & 73 deletions simplecrud/tests/test_crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,21 @@

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.ext.asyncio import create_async_engine, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker

from simplecrud.crud import *
from simplecrud.tests.factories import AsyncConnFactory
from simplecrud.utils import async_to_sync

database_url = "sqlite:///./test.db"
engine = create_engine(database_url)
session_maker = sessionmaker(bind=engine)


async_engine = create_async_engine("sqlite+aiosqlite:///./test.db")
AsyncSession = async_sessionmaker(async_engine, expire_on_commit=False, class_=AsyncSession)


class Base(DeclarativeBase):
pass

Expand All @@ -29,207 +33,206 @@ class TestAsyncCRUDFunctions(unittest.TestCase):

def setUp(self):
Base.metadata.create_all(engine)
self.session = AsyncConnFactory()

def tearDown(self):
Base.metadata.drop_all(engine)

@async_to_sync
async def test_async_create_obj(self):
params_1 = dict(name="test_async_create_obj1")
new_obj_1 = await create_object(ExampleModel, params_1)
new_obj_1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
self.assertEqual(new_obj_1.name, "test_async_create_obj1")

params_2 = dict(name="test_async_create_obj2")
new_obj_2 = await create_object(ExampleModel, params_2)
new_obj_2 = await create_object(ExampleModel, params_2, conn=AsyncSession())
self.assertEqual(new_obj_2.name, "test_async_create_obj2")

@async_to_sync
async def test_create_obj_params_error(self):
params_1 = dict(name="test_create_obj_params_error", wrong="wrong")
with self.assertRaises(TypeError):
new_obj_1 = await create_object(ExampleModel, **params_1)
new_obj_1 = await create_object(ExampleModel, **params_1, conn=AsyncSession())

@async_to_sync
async def test_bulk_create(self):
all_ = await get_all(ExampleModel)
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(len(all_), 0)
data = [dict(name=f"test_bulk_create{i}") for i in range(1, 11)]
objects = await bulk_create(ExampleModel, data)
all_ = await get_all(ExampleModel)
objects = await bulk_create(ExampleModel, data, conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(len(all_), 10)
for i in range(1, 11):
self.assertEqual(all_[i - 1].name, f"test_bulk_create{i}")

@async_to_sync
async def test_get_object(self):
params_1 = dict(name="test_get_object")
await create_object(ExampleModel, params_1)
obj = await get_object(ExampleModel, params_1)
await create_object(ExampleModel, params_1, conn=AsyncSession())
obj = await get_object(ExampleModel, params_1, conn=AsyncSession())
self.assertEqual(obj.name, "test_get_object")

@async_to_sync
async def test_get_object_not_exist(self):
params_1 = dict(name="test_get_object_not_exist1")
obj = await create_object(ExampleModel, params_1)
none_expected = await get_object(ExampleModel, filters=dict(name="test_get_object_not_exist0"))
obj = await create_object(ExampleModel, params_1, conn=AsyncSession())
none_expected = await get_object(ExampleModel, filters=dict(name="test_get_object_not_exist0"), conn=AsyncSession())
self.assertEqual(none_expected, None)

@async_to_sync
async def test_get_object_error(self):
params_1 = dict(name="test_get_object_error")
obj = await create_object(ExampleModel, params_1)
obj = await create_object(ExampleModel, params_1, conn=AsyncSession())

with self.assertRaises(InvalidRequestError):
error_expected = await get_object(ExampleModel, filters=dict(wrong="wrong"))

@async_to_sync
async def test_get_all_objects(self):
all_ = await get_all(ExampleModel)
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(len(all_), 0)
for i in range(5):
params_1 = dict(name=f"test_get_all_objects{i}")
await create_object(ExampleModel, params_1)
all_ = await get_all(ExampleModel)
await create_object(ExampleModel, params_1, conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(5, len(all_))
self.assertTrue(isinstance(all_, list))
await delete_object(all_[0])
await delete_object(all_[0], conn=AsyncSession())

@async_to_sync
async def test_get_all_if_objects_not_exist(self):
all_ = await get_all(ExampleModel)
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(len(all_), 0)
self.assertTrue(isinstance(all_, list))

@async_to_sync
async def test_get_all_with_filter(self):
for i in range(1, 6):
params_1 = dict(name=f"test_get_all_with_filter{i}")
await create_object(ExampleModel, params_1)
all_ = await get_all_with_filter(ExampleModel, dict(name="test_get_all_with_filter1"))
await create_object(ExampleModel, params_1, conn=AsyncSession())
all_ = await get_all_with_filter(ExampleModel, dict(name="test_get_all_with_filter1"), conn=AsyncSession())
self.assertEqual(len(all_), 1)

# TODO: Add test for multiple filter parameters

@async_to_sync
async def test_get_all_with_filter_error(self):
await create_object(ExampleModel, dict(name="test_get_all_with_filter_negative"))
await create_object(ExampleModel, dict(name="test_get_all_with_filter_negative"), conn=AsyncSession())
with self.assertRaises(InvalidRequestError):
await get_all_with_filter(ExampleModel, dict(wrong="wrong"))
await get_all_with_filter(ExampleModel, dict(wrong="wrong"), conn=AsyncSession())

@async_to_sync
async def test_get_all_with_filter_negative(self):
for i in range(1, 6):
params_1 = dict(name=f"test_get_all_with_filter{i}")
await create_object(ExampleModel, params_1)
all_ = await get_all_with_filter(ExampleModel, dict(name="not_exist"))
await create_object(ExampleModel, params_1, conn=AsyncSession())
all_ = await get_all_with_filter(ExampleModel, dict(name="not_exist"), conn=AsyncSession())
self.assertEqual(len(all_), 0)

@async_to_sync
async def test_get_objects_with_limit_and_per_page(self):
for i in range(1, 30):
params_1 = dict(name=f"test_get_objects_with_limit_and_ofset{i}")
await create_object(ExampleModel, params_1)
ten = await get_objects(ExampleModel, {}, limit=10, offset=10)
await create_object(ExampleModel, params_1, conn=AsyncSession())
ten = await get_objects(ExampleModel, {}, limit=10, offset=10, conn=AsyncSession())
self.assertEqual(len(ten), 10)
one = await get_objects(ExampleModel, {}, limit=1, offset=1)
one = await get_objects(ExampleModel, {}, limit=1, offset=1, conn=AsyncSession())
self.assertEqual(len(one), 1)
second = await get_objects(ExampleModel, {}, limit=1, offset=2)
second = await get_objects(ExampleModel, {}, limit=1, offset=2, conn=AsyncSession())
self.assertEqual(len(second), 1)
self.assertEqual(second[0].id, 3)

@async_to_sync
async def test_get_object_by_filters(self):
params_1 = dict(name="test_get_object_by_filters")
new_ = await create_object(ExampleModel, params_1)
obj = await get_object(ExampleModel, filters=dict(id=new_.id))
new_ = await create_object(ExampleModel, params_1, conn=AsyncSession())
obj = await get_object(ExampleModel, filters=dict(id=new_.id), conn=AsyncSession())
self.assertEqual(obj.name, "test_get_object_by_filters")

@async_to_sync
async def test_get_object_by_filters_negative(self):
params_1 = dict(name="test_get_object_by_filters_negative")
new_ = await create_object(ExampleModel, params_1)
new_ = await create_object(ExampleModel, params_1, conn=AsyncSession())
with self.assertRaises(InvalidRequestError):
obj = await get_object(ExampleModel, filters=dict(pk=new_.id))

@async_to_sync
async def test_get_or_create_object(self):
self.assertEqual(len(await get_all(ExampleModel)), 0)
self.assertEqual(len(await get_all(ExampleModel, conn=AsyncSession())), 0)
params_1 = dict(name="test_get_or_create_object")
new_1 = await get_or_create_object(ExampleModel, params_1)
self.assertEqual(len(await get_all(ExampleModel)), 1)
new_2 = await get_or_create_object(ExampleModel, params_1)
self.assertEqual(len(await get_all(ExampleModel)), 1)
new_1 = await get_or_create_object(ExampleModel, params_1, conn=AsyncSession())
self.assertEqual(len(await get_all(ExampleModel, conn=AsyncSession())), 1)
new_2 = await get_or_create_object(ExampleModel, params_1, conn=AsyncSession())
self.assertEqual(len(await get_all(ExampleModel, conn=AsyncSession())), 1)
self.assertEqual(new_1.id, new_2.id)

@async_to_sync
async def test_update_object(self):
params_1 = dict(name="test_update_object")
obj1 = await create_object(ExampleModel, params_1)
obj1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
params_2 = dict(name="test_update_object2")
obj2 = await update_object(obj1, params_2)
obj2 = await update_object(obj1, params_2, conn=AsyncSession())
self.assertEqual(obj2.name, "test_update_object2")
self.assertEqual(obj1.id, obj2.id)

@async_to_sync
async def test_update_or_error(self):
obj = await create_object(ExampleModel, dict(name="test_update_or_error"))
await update_or_error(obj, dict(name="test_update_or_error_updated"))
upd_obj = await get_object(ExampleModel, dict(id=obj.id))
obj = await create_object(ExampleModel, dict(name="test_update_or_error"), conn=AsyncSession())
await update_or_error(obj, dict(name="test_update_or_error_updated"), conn=AsyncSession())
upd_obj = await get_object(ExampleModel, dict(id=obj.id), conn=AsyncSession())
self.assertEqual(upd_obj.id, obj.id)
self.assertEqual(upd_obj.name, "test_update_or_error_updated")

@async_to_sync
async def test_update_or_error_error(self):
obj = await create_object(ExampleModel, dict(name="test_update_or_error"))
obj = await create_object(ExampleModel, dict(name="test_update_or_error"), conn=AsyncSession())
with self.assertRaises(AttributeError):
await update_or_error(obj, dict(wrong="wrong"))
await update_or_error(obj, dict(wrong="wrong"), conn=AsyncSession())

@async_to_sync
async def test_update_or_error_negative(self):
params_1 = dict(name="test_update_object")
obj1 = await create_object(ExampleModel, params_1)
obj1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
wrong_params = dict(wrong="test_update_object2")
with self.assertRaises(AttributeError) as error:
obj2 = await update_or_error(obj1, wrong_params)
obj2 = await update_or_error(obj1, wrong_params, conn=AsyncSession())
error_msg = "Attribute wrong not exists in ExampleModel"
self.assertEqual(error.exception.args[0], error_msg)

@async_to_sync
async def test_soft_update_without_error(self):
params_1 = dict(name="test_update_object_negative")
obj1 = await create_object(ExampleModel, params_1)
obj1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
self.assertFalse(hasattr(obj1, "wrong"))
wrong_params = dict(wrong="wrong")
obj2 = await update_object(obj1, wrong_params)
obj2 = await update_object(obj1, wrong_params, conn=AsyncSession())
self.assertEqual(obj1.id, obj2.id)
self.assertFalse(hasattr(obj2, "wrong"))

@async_to_sync
async def test_update_or_error(self):
params_1 = dict(name="test1")
obj1 = await create_object(ExampleModel, params_1)
obj1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
params_2 = dict(name="test2")
obj2 = await update_or_error(obj1, params_2)
obj2 = await update_or_error(obj1, params_2, conn=AsyncSession())
self.assertEqual(obj2.name, "test2")

@async_to_sync
async def test_update_by_id(self):
params_1 = dict(name="test1")
obj1 = await create_object(ExampleModel, params_1)
obj1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
id_ = obj1.id
params_2 = dict(name="test2")
obj2 = await update_object_by_id(ExampleModel, id_, params_2)
obj2 = await update_object_by_id(ExampleModel, id_, params_2, conn=AsyncSession())
self.assertEqual(obj2.name, "test2")

@async_to_sync
async def test_update_or_create_object(self):
self.assertEqual(len(await get_all(ExampleModel)), 0)
self.assertEqual(len(await get_all(ExampleModel, conn=AsyncSession())), 0)
params_1 = dict(name="test_update_or_create_object1")
new_1 = await update_or_create_object(ExampleModel, params_1, params_1)
self.assertEqual(len(await get_all(ExampleModel)), 1)
new_1 = await update_or_create_object(ExampleModel, params_1, params_1, conn=AsyncSession())
self.assertEqual(len(await get_all(ExampleModel, conn=AsyncSession())), 1)
params_2 = dict(name="test_update_or_create_object2")
new_2 = await update_or_create_object(ExampleModel, params_1, params_2)
new_2 = await update_or_create_object(ExampleModel, params_1, params_2, conn=AsyncSession())
self.assertEqual(new_2.name, "test_update_or_create_object2")
self.assertEqual(new_1.id, new_2.id)

Expand All @@ -252,26 +255,26 @@ async def test_update_or_create_object(self):

@async_to_sync
async def test_delete_object(self):
all_ = await get_all(ExampleModel)
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(len(all_), 0)
params_1 = dict(name="test_delete_object")
new_1 = await create_object(ExampleModel, params_1)
all_ = await get_all(ExampleModel)
new_1 = await create_object(ExampleModel, params_1, conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(len(all_), 1)
result = await delete_object(new_1)
result = await delete_object(new_1, conn=AsyncSession())
self.assertEqual(result, True)
all_ = await get_all(ExampleModel)
get_ = await get_object(ExampleModel, filters=dict(id=new_1.id))
all_ = await get_all(ExampleModel, conn=AsyncSession())
get_ = await get_object(ExampleModel, filters=dict(id=new_1.id), conn=AsyncSession())
self.assertEqual(len(all_), 0)

@async_to_sync
async def test_delete_objects(self):
for i in range(1, 12):
params_1 = dict(name=f"test_delete_objects{i}")
await create_object(ExampleModel, params_1)
all_ = await get_all(ExampleModel)
result = await bulk_delete(all_[0:10])
all_ = await get_all(ExampleModel)
await create_object(ExampleModel, params_1, conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
result = await bulk_delete(all_[0:10], conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(1, len(all_))
self.assertEqual(all_[0].name, "test_delete_objects11")
self.assertEqual(all_[0].id, 11)
Expand All @@ -280,20 +283,20 @@ async def test_delete_objects(self):
async def test_bulk_delete(self):
for i in range(1, 12):
params_1 = dict(name=f"test_delete_objects{i}")
await create_object(ExampleModel, params_1)
objects = await get_all(ExampleModel)
await create_object(ExampleModel, params_1, conn=AsyncSession())
objects = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(11, len(objects))
await bulk_delete(objects)
all_ = await get_all(ExampleModel)
await bulk_delete(objects, conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(0, len(all_))

@async_to_sync
async def test_bulk_delete_by_id(self):
for i in range(1, 12):
params_1 = dict(name=f"test_delete_objects{i}")
await create_object(ExampleModel, params_1)
ids = [i.id for i in await get_all(ExampleModel)]
await create_object(ExampleModel, params_1, conn=AsyncSession())
ids = [i.id for i in await get_all(ExampleModel, conn=AsyncSession())]
self.assertEqual(11, len(ids))
await bulk_delete_by_id(ExampleModel, ids)
all_ = await get_all(ExampleModel)
await bulk_delete_by_id(ExampleModel, ids, conn=AsyncSession())
all_ = await get_all(ExampleModel, conn=AsyncSession())
self.assertEqual(0, len(all_))
Loading

0 comments on commit 06fe87f

Please sign in to comment.