feat(datasets): Add option to async load and save in PartitionedDatasets #696

Open
wants to merge 11 commits into base: main
2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
@@ -1,5 +1,6 @@
# Upcoming Release
## Major features and improvements
* Added async functionality for loading and saving data in `PartitionedDataset` via `use_async` argument.

## Bug fixes and other changes
* Removed arbitrary upper bound for `s3fs`.
@@ -8,6 +9,7 @@
## Community contributions
Many thanks to the following Kedroids for contributing PRs to this release:
* [Charles Guan](https://github.com/charlesbmi)
* [Puneet Saini](https://github.com/puneeter)


# Release 3.0.0
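For context, a minimal usage sketch of the new flag through the Python API (the path and underlying dataset type here are hypothetical):

from kedro_datasets.partitions import PartitionedDataset

# Hypothetical example: load and save many CSV partitions with the new flag.
dataset = PartitionedDataset(
    path="data/01_raw/partitioned/",  # any local or fsspec-compatible location
    dataset="pandas.CSVDataset",
    use_async=True,  # new in this PR; defaults to False
)
partitions = dataset.load()  # dict: partition id -> zero-argument load callable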
5 changes: 1 addition & 4 deletions kedro-datasets/docs/source/conf.py
@@ -220,10 +220,7 @@
todo_include_todos = False

# -- Kedro specific configuration -----------------------------------------
KEDRO_MODULES = [
"kedro_datasets",
"kedro_datasets_experimental"
]
KEDRO_MODULES = ["kedro_datasets", "kedro_datasets_experimental"]


def get_classes(module):
66 changes: 66 additions & 0 deletions kedro-datasets/kedro_datasets/partitions/partitioned_dataset.py
@@ -4,6 +4,7 @@

from __future__ import annotations

import asyncio
import operator
from copy import deepcopy
from pathlib import PurePosixPath
@@ -152,6 +153,7 @@ def __init__( # noqa: PLR0913
fs_args: dict[str, Any] | None = None,
overwrite: bool = False,
metadata: dict[str, Any] | None = None,
use_async: bool = False,
) -> None:
"""Creates a new instance of ``PartitionedDataset``.

@@ -192,6 +194,8 @@ def __init__( # noqa: PLR0913
overwrite: If True, any existing partitions will be removed.
metadata: Any arbitrary metadata.
This is ignored by Kedro, but may be consumed by users or external plugins.
use_async: If True, partitions are loaded and saved concurrently using ``asyncio``.
Defaults to False.

Raises:
DatasetError: If versioning is enabled for the underlying dataset.
@@ -206,6 +210,7 @@ def __init__( # noqa: PLR0913
self._protocol = infer_storage_options(self._path)["protocol"]
self._partition_cache: Cache = Cache(maxsize=1)
self.metadata = metadata
self._use_async = use_async

dataset = dataset if isinstance(dataset, dict) else {"type": dataset}
self._dataset_type, self._dataset_config = parse_dataset_definition(dataset)
@@ -285,6 +290,12 @@ def _path_to_partition(self, path: str) -> str:
return path

def _load(self) -> dict[str, Callable[[], Any]]:
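# Dispatch on the use_async flag; asyncio.run creates (and closes) a fresh
# event loop per call, so the public load() API remains synchronous.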
if self._use_async:
return asyncio.run(self._async_load())
else:
return self._sync_load()

def _sync_load(self) -> dict[str, Callable[[], Any]]:
partitions = {}

for partition in self._list_partitions():
@@ -300,7 +311,32 @@

return partitions

async def _async_load(self) -> dict[str, Callable[[], Any]]:
partitions = {}

async def load_partition(partition: str) -> None:
kwargs = deepcopy(self._dataset_config)
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
partition_id = self._path_to_partition(partition)
partitions[partition_id] = dataset.load

await asyncio.gather(
*[load_partition(partition) for partition in self._list_partitions()]
)

if not partitions:
raise DatasetError(f"No partitions found in '{self._path}'")

return partitions
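The mapping returned here matches the synchronous path: each value is a zero-argument load callable, so partition contents are still read lazily; the gather above only parallelises building the per-partition dataset wrappers. A usage sketch (hypothetical node code following the usual lazy-loading pattern):

from typing import Any, Callable

def concat_partitions(partitioned_input: dict[str, Callable[[], Any]]) -> list[Any]:
    # Sort partition ids for determinism, then materialise each one on demand.
    return [load_func() for _, load_func in sorted(partitioned_input.items())]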

def _save(self, data: dict[str, Any]) -> None:
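# asyncio.run blocks until every scheduled partition save completes; it cannot
# be nested inside an already-running event loop, so save() must be called
# from synchronous code.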
if self._use_async:
asyncio.run(self._async_save(data))
else:
self._sync_save(data)

def _sync_save(self, data: dict[str, Any]) -> None:
if self._overwrite and self._filesystem.exists(self._normalized_path):
self._filesystem.rm(self._normalized_path, recursive=True)

@@ -315,6 +351,36 @@ def _save(self, data: dict[str, Any]) -> None:
dataset.save(partition_data)
self._invalidate_caches()

async def _async_save(self, data: dict[str, Any]) -> None:
if self._overwrite and await self._filesystem_exists(self._normalized_path):
await self._filesystem_rm(self._normalized_path, recursive=True)

async def save_partition(partition_id: str, partition_data: Any) -> None:
kwargs = deepcopy(self._dataset_config)
partition = self._partition_to_path(partition_id)
kwargs[self._filepath_arg] = self._join_protocol(partition)
dataset = self._dataset_type(**kwargs) # type: ignore
if callable(partition_data):
partition_data = partition_data() # noqa: PLW2901
await self._dataset_save(dataset, partition_data)

await asyncio.gather(
*[
save_partition(partition_id, partition_data)
for partition_id, partition_data in sorted(data.items())
]
)
self._invalidate_caches()
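A usage sketch of saving (reusing the hypothetical `dataset` instance from the earlier sketch): values may be eager objects or zero-argument callables, and callables are only resolved inside `save_partition`.

import pandas as pd

dataset.save(
    {
        "2024-01": pd.DataFrame({"x": [1]}),          # eager value
        "2024-02": lambda: pd.DataFrame({"x": [2]}),  # deferred until saved
    }
)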

# These thin coroutine wrappers make the blocking fsspec and dataset calls
# awaitable from _async_save; the underlying I/O itself remains synchronous
# and executes on the event-loop thread.
async def _filesystem_exists(self, path: str) -> bool:
    return self._filesystem.exists(path)

async def _filesystem_rm(self, path: str, recursive: bool) -> None:
    self._filesystem.rm(path, recursive=recursive)

async def _dataset_save(self, dataset: AbstractDataset, data: Any) -> None:
    dataset.save(data)

def _describe(self) -> dict[str, Any]:
clean_dataset_config = (
{k: v for k, v in self._dataset_config.items() if k != CREDENTIALS_KEY}