diff --git a/flytekit/types/directory/types.py b/flytekit/types/directory/types.py index 3bcbf77c97..82f35fd336 100644 --- a/flytekit/types/directory/types.py +++ b/flytekit/types/directory/types.py @@ -6,6 +6,7 @@ import random import typing from dataclasses import dataclass, field +from functools import partial from pathlib import Path from typing import Any, Dict, Generator, Tuple from uuid import UUID @@ -374,23 +375,20 @@ def listdir(cls, directory: FlyteDirectory) -> typing.List[typing.Union[FlyteDir paths.append(FlyteDirectory(joined_path)) return paths - def create_downloader(_remote_path: str, _local_path: str, is_multipart: bool): - return lambda: file_access.get_data(_remote_path, _local_path, is_multipart=is_multipart) - fs = file_access.get_filesystem_for_path(final_path) for key in fs.listdir(final_path): remote_path = os.path.join(final_path, key["name"].split(os.sep)[-1]) if key["type"] == "file": local_path = file_access.get_random_local_path() os.makedirs(pathlib.Path(local_path).parent, exist_ok=True) - downloader = create_downloader(remote_path, local_path, is_multipart=False) + downloader = partial(file_access.get_data, remote_path, local_path, is_multipart=False) flyte_file: FlyteFile = FlyteFile(local_path, downloader=downloader) flyte_file._remote_source = remote_path paths.append(flyte_file) else: local_folder = file_access.get_random_local_directory() - downloader = create_downloader(remote_path, local_folder, is_multipart=True) + downloader = partial(file_access.get_data, remote_path, local_folder, is_multipart=True) flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=downloader) flyte_directory._remote_source = remote_path @@ -665,8 +663,7 @@ async def async_to_python_value( batch_size = get_batch_size(expected_python_type) - def _downloader(): - return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size) + _downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size) expected_format = self.get_format(expected_python_type) diff --git a/flytekit/types/file/file.py b/flytekit/types/file/file.py index 7f681af3ca..b66c48443c 100644 --- a/flytekit/types/file/file.py +++ b/flytekit/types/file/file.py @@ -7,6 +7,7 @@ import typing from contextlib import contextmanager from dataclasses import dataclass, field +from functools import partial from typing import Dict, cast from urllib.parse import unquote @@ -307,7 +308,8 @@ def __init__( if ctx.file_access.is_remote(self.path): self._remote_source = self.path self._local_path = ctx.file_access.get_random_local_path(self._remote_source) - self._downloader = lambda: FlyteFilePathTransformer.downloader( + self._downloader = partial( + ctx.file_access.get_data, ctx=ctx, remote_path=self._remote_source, # type: ignore local_path=self._local_path, @@ -732,26 +734,14 @@ async def async_to_python_value( # For the remote case, return an FlyteFile object that can download local_path = ctx.file_access.get_random_local_path(uri) + + _downloader = partial(ctx.file_access.get_data, remote_path=uri, local_path=local_path, is_multipart=False) + expected_format = FlyteFilePathTransformer.get_format(expected_python_type) - ff = FlyteFile.__class_getitem__(expected_format)( - path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path) - ) + ff = FlyteFile.__class_getitem__(expected_format)(path=local_path, downloader=_downloader) ff._remote_source = uri - return ff - @staticmethod - def downloader( - ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike] - ) -> None: - """ - Download data from remote_path to local_path. - - We design the downloader as a static method because its behavior is logically - related to this class but don't need to interact with class or instance data. - """ - ctx.file_access.get_data(remote_path, local_path, is_multipart=False) - def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]: if ( literal_type.blob is not None diff --git a/tests/flytekit/unit/core/test_flyte_directory.py b/tests/flytekit/unit/core/test_flyte_directory.py index be61388fa5..fdb12e1dae 100644 --- a/tests/flytekit/unit/core/test_flyte_directory.py +++ b/tests/flytekit/unit/core/test_flyte_directory.py @@ -1,5 +1,6 @@ import os import pathlib +import pickle import shutil import tempfile import typing @@ -20,7 +21,7 @@ from flytekit.core.workflow import workflow from flytekit.exceptions.user import FlyteAssertion from flytekit.models.core.types import BlobType -from flytekit.models.literals import LiteralMap +from flytekit.models.literals import LiteralMap, Blob, BlobMetadata from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -407,8 +408,7 @@ def my_wf(path: SvgDirectory) -> DC: assert dc1 == dc2 -def test_input_from_flyte_console_attribute_access_flytefile( - local_dummy_directory): +def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_directory): # Flyte Console will send the input data as protobuf Struct dict_obj = {"path": local_dummy_directory} @@ -422,3 +422,27 @@ def test_input_from_flyte_console_attribute_access_flytefile( FlyteContextManager.current_context(), upstream_output, FlyteDirectory) assert isinstance(downstream_input, FlyteDirectory) assert downstream_input == FlyteDirectory(local_dummy_directory) + + +def test_flyte_directory_is_pickleable(): + upstream_output = Literal( + scalar=Scalar( + blob=Blob( + uri="s3://sample-path/directory", + metadata=BlobMetadata( + type=BlobType( + dimensionality=BlobType.BlobDimensionality.MULTIPART, + format="" + ) + ) + ) + ) + ) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, FlyteDirectory + ) + + # test round trip pickling + pickled_input = pickle.dumps(downstream_input) + unpickled_input = pickle.loads(pickled_input) + assert downstream_input == unpickled_input diff --git a/tests/flytekit/unit/core/test_flyte_file.py b/tests/flytekit/unit/core/test_flyte_file.py index 1b5d32b1f2..fb0903c567 100644 --- a/tests/flytekit/unit/core/test_flyte_file.py +++ b/tests/flytekit/unit/core/test_flyte_file.py @@ -1,6 +1,7 @@ import json import os import pathlib +import pickle import tempfile import typing from unittest.mock import MagicMock, patch @@ -19,7 +20,7 @@ from flytekit.core.type_engine import TypeEngine from flytekit.core.workflow import workflow from flytekit.models.core.types import BlobType -from flytekit.models.literals import LiteralMap +from flytekit.models.literals import LiteralMap, Blob, BlobMetadata from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer from google.protobuf import json_format as _json_format from google.protobuf import struct_pb2 as _struct @@ -782,3 +783,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file): downstream_input = TypeEngine.to_python_value( FlyteContextManager.current_context(), upstream_output, FlyteFile) assert downstream_input == FlyteFile(local_dummy_file) + + +def test_flyte_file_is_pickleable(): + upstream_output = Literal( + scalar=Scalar( + blob=Blob( + uri="s3://sample-path/file", + metadata=BlobMetadata( + type=BlobType( + dimensionality=BlobType.BlobDimensionality.SINGLE, + format="txt" + ) + ) + ) + ) + ) + downstream_input = TypeEngine.to_python_value( + FlyteContextManager.current_context(), upstream_output, FlyteFile + ) + + # test round trip pickling + pickled_input = pickle.dumps(downstream_input) + unpickled_input = pickle.loads(pickled_input) + assert downstream_input == unpickled_input