diff --git a/fastcrud/crud/fast_crud.py b/fastcrud/crud/fast_crud.py index 158e8e0..49a57f9 100644 --- a/fastcrud/crud/fast_crud.py +++ b/fastcrud/crud/fast_crud.py @@ -831,6 +831,7 @@ async def upsert_multi( self, db: AsyncSession, instances: list[Union[UpdateSchemaType, CreateSchemaType]], + commit: bool = False, return_columns: Optional[list[str]] = None, schema_to_select: Optional[type[BaseModel]] = None, return_as_model: bool = False, @@ -843,6 +844,7 @@ async def upsert_multi( Args: db: The database session to use for the operation. instances: A list of Pydantic schemas representing the instances to upsert. + commit: If True, commits the transaction immediately. Default is False. return_columns: Optional list of column names to return after the upsert operation. schema_to_select: Optional Pydantic schema for selecting specific columns. Required if return_as_model is True. return_as_model: If True, returns data as instances of the specified Pydantic model. @@ -891,6 +893,8 @@ async def upsert_multi( if return_columns: statement = statement.returning(*[column(name) for name in return_columns]) db_row = await db.execute(statement, params) + if commit: + await db.commit() return self._as_multi_response( db_row, schema_to_select=schema_to_select, @@ -898,6 +902,8 @@ async def upsert_multi( ) await db.execute(statement, params) + if commit: + await db.commit() return None async def _upsert_multi_postgresql( diff --git a/tests/sqlalchemy/crud/test_upsert.py b/tests/sqlalchemy/crud/test_upsert.py index 74883da..cbaed6e 100644 --- a/tests/sqlalchemy/crud/test_upsert.py +++ b/tests/sqlalchemy/crud/test_upsert.py @@ -299,16 +299,16 @@ async def test_upsert_multi_successful( crud = FastCRUD(test_model) new_data = read_schema(id=1, name="New Record", tier_id=1, category_id=1) fetched_records = await crud.upsert_multi( - async_session, [new_data], **insert["kwargs"] + async_session, [new_data], commit=True, **insert["kwargs"] ) - + assert not async_session.in_transaction() assert fetched_records == insert["expected_result"] updated_new_data = new_data.model_copy(update={"name": "New name"}) updated_fetched_records = await crud.upsert_multi( - async_session, [updated_new_data], **update["kwargs"] + async_session, [updated_new_data], commit=True, **update["kwargs"] ) - + assert not async_session.in_transaction() assert updated_fetched_records == update["expected_result"]