Skip to content

Commit

Permalink
Make FlyteFile and FlyteDirectory pickleable (flyteorg#3030)
Browse files Browse the repository at this point in the history
* make _downloader function in FlyteFile/Directory pickleable

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* make FlyteFile and Directory pickleable

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* remove unnecessary helper functions

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* fix lint

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* use partials instead of lambda

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* fix lint

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* remove unneeded helper function

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* update FlyteFilePathTransformer.downloader method

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* remove downloader staticmethod

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

* fix lint

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>

---------

Signed-off-by: Niels Bantilan <niels.bantilan@gmail.com>
Signed-off-by: Shuying Liang <shuying.liang@gmail.com>
  • Loading branch information
cosmicBboy authored and shuyingliang committed Jan 11, 2025
1 parent ad75523 commit fe31067
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 28 deletions.
11 changes: 4 additions & 7 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import typing
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, Dict, Generator, Tuple
from uuid import UUID
Expand Down Expand Up @@ -374,23 +375,20 @@ def listdir(cls, directory: FlyteDirectory) -> typing.List[typing.Union[FlyteDir
paths.append(FlyteDirectory(joined_path))
return paths

def create_downloader(_remote_path: str, _local_path: str, is_multipart: bool):
return lambda: file_access.get_data(_remote_path, _local_path, is_multipart=is_multipart)

fs = file_access.get_filesystem_for_path(final_path)
for key in fs.listdir(final_path):
remote_path = os.path.join(final_path, key["name"].split(os.sep)[-1])
if key["type"] == "file":
local_path = file_access.get_random_local_path()
os.makedirs(pathlib.Path(local_path).parent, exist_ok=True)
downloader = create_downloader(remote_path, local_path, is_multipart=False)
downloader = partial(file_access.get_data, remote_path, local_path, is_multipart=False)

flyte_file: FlyteFile = FlyteFile(local_path, downloader=downloader)
flyte_file._remote_source = remote_path
paths.append(flyte_file)
else:
local_folder = file_access.get_random_local_directory()
downloader = create_downloader(remote_path, local_folder, is_multipart=True)
downloader = partial(file_access.get_data, remote_path, local_folder, is_multipart=True)

flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=downloader)
flyte_directory._remote_source = remote_path
Expand Down Expand Up @@ -665,8 +663,7 @@ async def async_to_python_value(

batch_size = get_batch_size(expected_python_type)

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)
_downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size)

expected_format = self.get_format(expected_python_type)

Expand Down
24 changes: 7 additions & 17 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, cast
from urllib.parse import unquote

Expand Down Expand Up @@ -307,7 +308,8 @@ def __init__(
if ctx.file_access.is_remote(self.path):
self._remote_source = self.path
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
self._downloader = lambda: FlyteFilePathTransformer.downloader(
self._downloader = partial(
ctx.file_access.get_data,
ctx=ctx,
remote_path=self._remote_source, # type: ignore
local_path=self._local_path,
Expand Down Expand Up @@ -732,26 +734,14 @@ async def async_to_python_value(

# For the remote case, return an FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

_downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(
path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path)
)
ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader)
ff._remote_source = uri

return ff

@staticmethod
def downloader(
ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike]
) -> None:
"""
Download data from remote_path to local_path.
We design the downloader as a static method because its behavior is logically
related to this class but don't need to interact with class or instance data.
"""
ctx.file_access.get_data(remote_path, local_path, is_multipart=False)

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if (
literal_type.blob is not None
Expand Down
30 changes: 27 additions & 3 deletions tests/flytekit/unit/core/test_flyte_directory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
import pickle
import shutil
import tempfile
import typing
Expand All @@ -20,7 +21,7 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -407,8 +408,7 @@ def my_wf(path: SvgDirectory) -> DC:
assert dc1 == dc2


def test_input_from_flyte_console_attribute_access_flytefile(
local_dummy_directory):
def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_directory):
# Flyte Console will send the input data as protobuf Struct

dict_obj = {"path": local_dummy_directory}
Expand All @@ -422,3 +422,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(
FlyteContextManager.current_context(), upstream_output, FlyteDirectory)
assert isinstance(downstream_input, FlyteDirectory)
assert downstream_input == FlyteDirectory(local_dummy_directory)


def test_flyte_directory_is_pickleable():
upstream_output = Literal(
scalar=Scalar(
blob=Blob(
uri="s3://sample-path/directory",
metadata=BlobMetadata(
type=BlobType(
dimensionality=BlobType.BlobDimensionality.MULTIPART,
format=""
)
)
)
)
)
downstream_input = TypeEngine.to_python_value(
FlyteContextManager.current_context(), upstream_output, FlyteDirectory
)

# test round trip pickling
pickled_input = pickle.dumps(downstream_input)
unpickled_input = pickle.loads(pickled_input)
assert downstream_input == unpickled_input
27 changes: 26 additions & 1 deletion tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import pathlib
import pickle
import tempfile
import typing
from unittest.mock import MagicMock, patch
Expand All @@ -19,7 +20,7 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -782,3 +783,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file):
downstream_input = TypeEngine.to_python_value(
FlyteContextManager.current_context(), upstream_output, FlyteFile)
assert downstream_input == FlyteFile(local_dummy_file)


def test_flyte_file_is_pickleable():
upstream_output = Literal(
scalar=Scalar(
blob=Blob(
uri="s3://sample-path/file",
metadata=BlobMetadata(
type=BlobType(
dimensionality=BlobType.BlobDimensionality.SINGLE,
format="txt"
)
)
)
)
)
downstream_input = TypeEngine.to_python_value(
FlyteContextManager.current_context(), upstream_output, FlyteFile
)

# test round trip pickling
pickled_input = pickle.dumps(downstream_input)
unpickled_input = pickle.loads(pickled_input)
assert downstream_input == unpickled_input

0 comments on commit fe31067

Please sign in to comment.