Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make FlyteFile and FlyteDirectory pickleable #3030

Merged
merged 11 commits into from
Jan 8, 2025
11 changes: 4 additions & 7 deletions flytekit/types/directory/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import random
import typing
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Any, Dict, Generator, Tuple
from uuid import UUID
Expand Down Expand Up @@ -374,23 +375,20 @@ def listdir(cls, directory: FlyteDirectory) -> typing.List[typing.Union[FlyteDir
paths.append(FlyteDirectory(joined_path))
return paths

def create_downloader(_remote_path: str, _local_path: str, is_multipart: bool):
return lambda: file_access.get_data(_remote_path, _local_path, is_multipart=is_multipart)

fs = file_access.get_filesystem_for_path(final_path)
for key in fs.listdir(final_path):
remote_path = os.path.join(final_path, key["name"].split(os.sep)[-1])
if key["type"] == "file":
local_path = file_access.get_random_local_path()
os.makedirs(pathlib.Path(local_path).parent, exist_ok=True)
downloader = create_downloader(remote_path, local_path, is_multipart=False)
downloader = partial(file_access.get_data, remote_path, local_path, is_multipart=False)

flyte_file: FlyteFile = FlyteFile(local_path, downloader=downloader)
flyte_file._remote_source = remote_path
paths.append(flyte_file)
else:
local_folder = file_access.get_random_local_directory()
downloader = create_downloader(remote_path, local_folder, is_multipart=True)
downloader = partial(file_access.get_data, remote_path, local_folder, is_multipart=True)

flyte_directory: FlyteDirectory = FlyteDirectory(path=local_folder, downloader=downloader)
flyte_directory._remote_source = remote_path
Expand Down Expand Up @@ -665,8 +663,7 @@ async def async_to_python_value(

batch_size = get_batch_size(expected_python_type)

def _downloader():
return ctx.file_access.get_data(uri, local_folder, is_multipart=True, batch_size=batch_size)
_downloader = partial(ctx.file_access.get_data, uri, local_folder, is_multipart=True, batch_size=batch_size)

expected_format = self.get_format(expected_python_type)

Expand Down
16 changes: 11 additions & 5 deletions flytekit/types/file/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import typing
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from typing import Dict, cast
from urllib.parse import unquote

Expand All @@ -20,6 +21,7 @@

from flytekit.core.constants import MESSAGEPACK
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import (
AsyncTypeTransformer,
TypeEngine,
Expand Down Expand Up @@ -307,7 +309,8 @@ def __init__(
if ctx.file_access.is_remote(self.path):
self._remote_source = self.path
self._local_path = ctx.file_access.get_random_local_path(self._remote_source)
self._downloader = lambda: FlyteFilePathTransformer.downloader(
self._downloader = partial(
FlyteFilePathTransformer.downloader,
ctx=ctx,
remote_path=self._remote_source, # type: ignore
local_path=self._local_path,
Expand Down Expand Up @@ -732,25 +735,28 @@ async def async_to_python_value(

# For the remote case, return a FlyteFile object that can download
local_path = ctx.file_access.get_random_local_path(uri)

expected_format = FlyteFilePathTransformer.get_format(expected_python_type)
ff = FlyteFile.__class_getitem__(expected_format)(
path=local_path, downloader=lambda: self.downloader(ctx=ctx, remote_path=uri, local_path=local_path)
path=local_path,
downloader=partial(self.downloader, ctx.file_access, remote_path=uri, local_path=local_path),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding type hints to downloader

Consider updating the downloader partial function call to include type hints for better code maintainability and IDE support. The ctx.file_access parameter type could be explicitly specified as FileAccessProvider.

Code suggestion
Check the AI-generated fix before applying
Suggested change
downloader=partial(self.downloader, ctx.file_access, remote_path=uri, local_path=local_path),
downloader=partial(self.downloader, typing.cast(FileAccessProvider, ctx.file_access), remote_path=uri, local_path=local_path),

Code Review Run #d95336


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

)
ff._remote_source = uri

return ff

@staticmethod
def downloader(
ctx: FlyteContext, remote_path: typing.Union[str, os.PathLike], local_path: typing.Union[str, os.PathLike]
file_access_provider: FileAccessProvider,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This breaks backward compatibility. Is it safe to change the signature here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I can tell this staticmethod was added in a PR 4 days ago: https://github.com/flyteorg/flytekit/pull/2991/files

And the only call site was in this module itself

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not have this be part of the public API. Can this be renamed to _remote_downloader?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can also revert this to the older implementation with the pure partial.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That works for me.

remote_path: typing.Union[str, os.PathLike],
local_path: typing.Union[str, os.PathLike],
) -> None:
"""
Download data from remote_path to local_path.

We design the downloader as a static method because its behavior is logically
related to this class but doesn't need to interact with class or instance data.
"""
ctx.file_access.get_data(remote_path, local_path, is_multipart=False)
file_access_provider.get_data(remote_path, local_path, is_multipart=False)

def guess_python_type(self, literal_type: LiteralType) -> typing.Type[FlyteFile[typing.Any]]:
if (
Expand Down
30 changes: 27 additions & 3 deletions tests/flytekit/unit/core/test_flyte_directory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pathlib
import pickle
import shutil
import tempfile
import typing
Expand All @@ -20,7 +21,7 @@
from flytekit.core.workflow import workflow
from flytekit.exceptions.user import FlyteAssertion
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.directory.types import FlyteDirectory, FlyteDirToMultipartBlobTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -407,8 +408,7 @@ def my_wf(path: SvgDirectory) -> DC:
assert dc1 == dc2


def test_input_from_flyte_console_attribute_access_flytefile(
local_dummy_directory):
def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_directory):
# Flyte Console will send the input data as protobuf Struct

dict_obj = {"path": local_dummy_directory}
Expand All @@ -422,3 +422,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(
FlyteContextManager.current_context(), upstream_output, FlyteDirectory)
assert isinstance(downstream_input, FlyteDirectory)
assert downstream_input == FlyteDirectory(local_dummy_directory)


def test_flyte_directory_is_pickleable():
    """A FlyteDirectory resolved from a remote multipart blob survives a pickle round trip."""
    blob_type = BlobType(
        dimensionality=BlobType.BlobDimensionality.MULTIPART,
        format="",
    )
    remote_blob = Blob(
        uri="s3://sample-path/directory",
        metadata=BlobMetadata(type=blob_type),
    )
    upstream_output = Literal(scalar=Scalar(blob=remote_blob))

    ctx = FlyteContextManager.current_context()
    downstream_input = TypeEngine.to_python_value(ctx, upstream_output, FlyteDirectory)

    # Serialize and immediately restore; the restored object must compare equal.
    restored_input = pickle.loads(pickle.dumps(downstream_input))
    assert downstream_input == restored_input
Comment on lines +445 to +448
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding more pickle test assertions

Consider adding more assertions to verify that the unpickled FlyteDirectory object's properties, such as uri and other attributes, are preserved correctly after deserialization.

Code suggestion
Check the AI-generated fix before applying
Suggested change
# test round trip pickling
pickled_input = pickle.dumps(downstream_input)
unpickled_input = pickle.loads(pickled_input)
assert downstream_input == unpickled_input
# test round trip pickling
pickled_input = pickle.dumps(downstream_input)
unpickled_input = pickle.loads(pickled_input)
assert downstream_input == unpickled_input
assert unpickled_input.uri == "s3://sample-path/directory"

Code Review Run #639caa


Is this a valid issue, or was it incorrectly flagged by the Agent?

  • it was incorrectly flagged

27 changes: 26 additions & 1 deletion tests/flytekit/unit/core/test_flyte_file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import os
import pathlib
import pickle
import tempfile
import typing
from unittest.mock import MagicMock, patch
Expand All @@ -19,7 +20,7 @@
from flytekit.core.type_engine import TypeEngine
from flytekit.core.workflow import workflow
from flytekit.models.core.types import BlobType
from flytekit.models.literals import LiteralMap
from flytekit.models.literals import LiteralMap, Blob, BlobMetadata
from flytekit.types.file.file import FlyteFile, FlyteFilePathTransformer
from google.protobuf import json_format as _json_format
from google.protobuf import struct_pb2 as _struct
Expand Down Expand Up @@ -782,3 +783,27 @@ def test_input_from_flyte_console_attribute_access_flytefile(local_dummy_file):
downstream_input = TypeEngine.to_python_value(
FlyteContextManager.current_context(), upstream_output, FlyteFile)
assert downstream_input == FlyteFile(local_dummy_file)


def test_flyte_file_is_pickleable():
    """A FlyteFile resolved from a remote single-file blob survives a pickle round trip."""
    blob_type = BlobType(
        dimensionality=BlobType.BlobDimensionality.SINGLE,
        format="txt",
    )
    remote_blob = Blob(
        uri="s3://sample-path/file",
        metadata=BlobMetadata(type=blob_type),
    )
    upstream_output = Literal(scalar=Scalar(blob=remote_blob))

    ctx = FlyteContextManager.current_context()
    downstream_input = TypeEngine.to_python_value(ctx, upstream_output, FlyteFile)

    # Serialize and immediately restore; the restored object must compare equal.
    restored_input = pickle.loads(pickle.dumps(downstream_input))
    assert downstream_input == restored_input
Loading