Skip to content

Commit

Permalink
Add load function for flow run and node run (#1939)
Browse files Browse the repository at this point in the history
# Description

1. In _local_storage_operations.py, maintain dicts of node run and flow
run info to quickly get run info for a certain line number. Split
load_detail into two functions to avoid repeated calls.
2. In _run_storage, add class AbstractBatchRunStorage, which inherits from
AbstractRunStorage and defines load functions for retrieving run info by
line number.
3. Add test cases for the load functions.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Co-authored-by: Min Shi <minshi@microsoft.com>
  • Loading branch information
Jasmin3q and Min Shi authored Feb 7, 2024
1 parent d701cb0 commit 54e5438
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from promptflow._utils.dataclass_serializer import serialize
from promptflow._utils.exception_utils import PromptflowExceptionPresenter
from promptflow._utils.logger_utils import LogContext, get_cli_sdk_logger
from promptflow._utils.multimedia_utils import get_file_reference_encoder
from promptflow._utils.multimedia_utils import get_file_reference_encoder, resolve_multimedia_data_recursively
from promptflow._utils.yaml_utils import load_yaml
from promptflow.batch._result import BatchResult
from promptflow.contracts.multimedia import Image
Expand All @@ -49,7 +49,7 @@
from promptflow.contracts.run_info import Status
from promptflow.contracts.run_mode import RunMode
from promptflow.exceptions import UserErrorException
from promptflow.storage import AbstractRunStorage
from promptflow.storage import AbstractBatchRunStorage

logger = get_cli_sdk_logger()

Expand Down Expand Up @@ -180,7 +180,7 @@ def dump(self, path: Path) -> None:
json_dump(asdict(self), path)


class LocalStorageOperations(AbstractRunStorage):
class LocalStorageOperations(AbstractBatchRunStorage):
"""LocalStorageOperations."""

LINE_NUMBER_WIDTH = 9
Expand Down Expand Up @@ -222,6 +222,9 @@ def __init__(self, run: Run, stream=False, run_mode=RunMode.Test):
self._dump_meta_file()
self._eager_mode = self._calculate_eager_mode(run)

self._loaded_flow_run_info = {} # {line_number: flow_run_info}
self._loaded_node_run_info = {} # {line_number: [node_run_info]}

@property
def eager_mode(self) -> bool:
return self._eager_mode
Expand Down Expand Up @@ -366,24 +369,8 @@ def load_detail(self, parse_const_as_str: bool = False) -> Dict[str, list]:
# legacy run with local file detail.json, then directly load from the file
return json_load(self._detail_path)
else:
json_loads = json.loads if not parse_const_as_str else json_loads_parse_const_as_str
# collect from local files and concat in the memory
flow_runs, node_runs = [], []
for line_run_record_file in sorted(self._run_infos_folder.iterdir()):
# In addition to the output jsonl files, there may be multimedia files in the output folder,
# so we should skip them.
if line_run_record_file.suffix.lower() != ".jsonl":
continue
with read_open(line_run_record_file) as f:
new_runs = [json_loads(line)["run_info"] for line in list(f)]
flow_runs += new_runs
for node_folder in sorted(self._node_infos_folder.iterdir()):
for node_run_record_file in sorted(node_folder.iterdir()):
if node_run_record_file.suffix.lower() != ".jsonl":
continue
with read_open(node_run_record_file) as f:
new_runs = [json_loads(line)["run_info"] for line in list(f)]
node_runs += new_runs
flow_runs = self._load_all_flow_run_info(parse_const_as_str=parse_const_as_str)
node_runs = self._load_all_node_run_info(parse_const_as_str=parse_const_as_str)
return {"flow_runs": flow_runs, "node_runs": node_runs}

def load_metrics(self, *, parse_const_as_str: bool = False) -> Dict[str, Union[int, float, str]]:
Expand All @@ -400,6 +387,33 @@ def persist_node_run(self, run_info: NodeRunInfo) -> None:
filename = f"{str(line_number).zfill(self.LINE_NUMBER_WIDTH)}.jsonl"
node_run_record.dump(node_folder / filename, run_name=self._run.name)

def _load_info_from_file(self, file_path, parse_const_as_str: bool = False):
    """Load serialized run info dicts from a single ``.jsonl`` record file.

    Each line of the file is a JSON object carrying a ``run_info`` key.
    Non-jsonl files (e.g. multimedia files stored alongside the records)
    are skipped by returning an empty list.

    :param file_path: Path to the candidate record file.
    :param parse_const_as_str: When True, use the module's
        ``json_loads_parse_const_as_str`` loader instead of ``json.loads``.
    :return: List of raw ``run_info`` dicts (possibly empty).
    """
    if file_path.suffix.lower() != ".jsonl":
        return []
    json_loads = json_loads_parse_const_as_str if parse_const_as_str else json.loads
    # Iterate the file handle directly; no need to materialize all lines
    # into a list first.
    with read_open(file_path) as f:
        return [json_loads(line)["run_info"] for line in f]

def _load_all_node_run_info(self, parse_const_as_str: bool = False) -> List[Dict]:
    """Load all node run info from the node artifacts folder.

    Also (re)builds the ``_loaded_node_run_info`` cache, mapping line
    number -> list of deserialized ``NodeRunInfo`` for that line.

    :param parse_const_as_str: Forwarded to ``_load_info_from_file``.
    :return: All raw node run info dicts, in sorted file order.
    """
    # Rebuild the cache from scratch so repeated calls do not accumulate
    # duplicate NodeRunInfo entries for the same line number.
    self._loaded_node_run_info = {}
    node_run_infos = []
    for node_folder in sorted(self._node_infos_folder.iterdir()):
        for node_run_record_file in sorted(node_folder.iterdir()):
            new_runs = self._load_info_from_file(node_run_record_file, parse_const_as_str)
            node_run_infos.extend(new_runs)
            for new_run in new_runs:
                # NOTE(review): resolution appears to update the dict in
                # place (callers also see resolved paths in the raw dicts
                # returned above) — confirm against the helper.
                new_run = resolve_multimedia_data_recursively(node_run_record_file, new_run)
                run_info = NodeRunInfo.deserialize(new_run)
                self._loaded_node_run_info.setdefault(run_info.index, []).append(run_info)
    return node_run_infos

def load_node_run_info_for_line(self, line_number: int = None) -> List[NodeRunInfo]:
    """Return the cached node run info list for *line_number*.

    Lazily populates the cache on first use; returns None when no node
    run info exists for the requested line.
    """
    if not self._loaded_node_run_info:
        self._load_all_node_run_info()
    return self._loaded_node_run_info.get(line_number)

def persist_flow_run(self, run_info: FlowRunInfo) -> None:
"""Persist line run record to local storage."""
if not Status.is_terminated(run_info.status):
Expand All @@ -417,6 +431,23 @@ def persist_flow_run(self, run_info: FlowRunInfo) -> None:
)
line_run_record.dump(self._run_infos_folder / filename)

def _load_all_flow_run_info(self, parse_const_as_str: bool = False) -> List[Dict]:
    """Load all flow run info from the flow artifacts folder.

    Also fills the ``_loaded_flow_run_info`` cache, mapping line number
    to its deserialized ``FlowRunInfo``.

    :param parse_const_as_str: Forwarded to ``_load_info_from_file``.
    :return: All raw flow run info dicts, in sorted file order.
    """
    loaded = []
    for record_file in sorted(self._run_infos_folder.iterdir()):
        records = self._load_info_from_file(record_file, parse_const_as_str)
        loaded.extend(records)
        for record in records:
            record = resolve_multimedia_data_recursively(record_file, record)
            deserialized = FlowRunInfo.deserialize(record)
            self._loaded_flow_run_info[deserialized.index] = deserialized
    return loaded

def load_flow_run_info(self, line_number: int = None) -> FlowRunInfo:
    """Return the cached flow run info for *line_number*.

    Lazily populates the cache on first use; returns None when the line
    has no persisted flow run info.
    """
    if not self._loaded_flow_run_info:
        self._load_all_flow_run_info()
    return self._loaded_flow_run_info.get(line_number)

def persist_result(self, result: Optional[BatchResult]) -> None:
"""Persist metrics from return of executor."""
if result is None:
Expand Down
4 changes: 2 additions & 2 deletions src/promptflow/promptflow/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
# ---------------------------------------------------------

from ._cache_storage import AbstractCacheStorage # noqa: F401
from ._run_storage import AbstractRunStorage # noqa: F401
from ._run_storage import AbstractBatchRunStorage, AbstractRunStorage # noqa: F401

__all__ = ["AbstractCacheStorage", "AbstractRunStorage"]
__all__ = ["AbstractCacheStorage", "AbstractRunStorage", "AbstractBatchRunStorage"]
12 changes: 12 additions & 0 deletions src/promptflow/promptflow/storage/_run_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,18 @@ def persist_flow_run(self, run_info: FlowRunInfo):
raise NotImplementedError("AbstractRunStorage is an abstract class, no implementation for persist_flow_run.")


class AbstractBatchRunStorage(AbstractRunStorage):
    """Abstract storage interface for batch runs.

    Extends AbstractRunStorage with per-line load operations so persisted
    flow/node run info can be read back by line number. Subclasses must
    override both load methods.
    """

    def load_node_run_info_for_line(self, line_number: int):
        """Load the node run info persisted for ``line_number``.

        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "AbstractBatchRunStorage is an abstract class, no implementation for load_node_run_info_for_line."
        )

    def load_flow_run_info(self, line_number: int):
        """Load the flow run info persisted for ``line_number``.

        :raises NotImplementedError: always; subclasses must override.
        """
        raise NotImplementedError(
            "AbstractBatchRunStorage is an abstract class, no implementation for load_flow_run_info."
        )


class DummyRunStorage(AbstractRunStorage):
def persist_node_run(self, run_info: NodeRunInfo):
"""Dummy implementation for persist_node_run
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import datetime
import json
from pathlib import Path

import pytest

from promptflow._sdk.entities._run import Run
from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations
from promptflow.contracts.run_info import FlowRunInfo, RunInfo, Status


@pytest.fixture
def run_instance():
    """A minimal Run entity to back LocalStorageOperations in these tests."""
    run = Run(flow="flow", name="run_name")
    return run


@pytest.fixture
def local_storage(run_instance):
    """LocalStorageOperations rooted at the fixture run."""
    storage = LocalStorageOperations(run_instance)
    return storage


@pytest.fixture
def node_run_info():
    """A completed node RunInfo whose input and output carry image references."""
    image_ref = {"data:image/png;path": "test.png"}
    return RunInfo(
        node="node1",
        flow_run_id="flow_run_id",
        run_id="run_id",
        status=Status.Completed,
        inputs={"image1": dict(image_ref)},
        output={"output1": dict(image_ref)},
        metrics={},
        error={},
        parent_run_id="parent_run_id",
        start_time=datetime.datetime.now(),
        end_time=datetime.datetime.now() + datetime.timedelta(seconds=5),
        index=1,
    )


@pytest.fixture
def flow_run_info():
    """A completed FlowRunInfo whose input and output carry image references."""
    image_ref = {"data:image/png;path": "test.png"}
    return FlowRunInfo(
        run_id="run_id",
        status=Status.Completed,
        error=None,
        inputs={"image1": dict(image_ref)},
        output={"output1": dict(image_ref)},
        metrics={},
        request="request",
        parent_run_id="parent_run_id",
        root_run_id="root_run_id",
        source_run_id="source_run_id",
        flow_id="flow_id",
        start_time=datetime.datetime.now(),
        end_time=datetime.datetime.now() + datetime.timedelta(seconds=5),
        index=1,
    )


@pytest.mark.unittest
class TestLocalStorageOperations:
    """Unit tests for persisting and re-loading flow/node run records."""

    def test_persist_node_run(self, local_storage, node_run_info):
        # Persisting writes a zero-padded per-line jsonl under node_artifacts/<node>.
        local_storage.persist_node_run(node_run_info)
        record_path = local_storage.path / "node_artifacts" / node_run_info.node / "000000001.jsonl"
        assert record_path.exists()
        record = json.loads(record_path.read_text())
        assert record["NodeName"] == node_run_info.node
        assert record["line_number"] == node_run_info.index

    def test_persist_flow_run(self, local_storage, flow_run_info):
        # Persisting writes a <start>_<end> zero-padded jsonl under flow_artifacts.
        local_storage.persist_flow_run(flow_run_info)
        record_path = local_storage.path / "flow_artifacts" / "000000001_000000001.jsonl"
        assert record_path.exists()
        record = json.loads(record_path.read_text())
        assert record["run_info"]["run_id"] == flow_run_info.run_id
        assert record["line_number"] == flow_run_info.index

    def test_load_node_run_info(self, local_storage, node_run_info):
        local_storage.persist_node_run(node_run_info)

        all_node_runs = local_storage._load_all_node_run_info()
        # Multimedia references must be resolved against the node folder.
        expected_image = str(Path(local_storage._node_infos_folder, node_run_info.node, "test.png"))
        assert len(all_node_runs) == 1
        loaded = all_node_runs[0]
        assert loaded["node"] == node_run_info.node
        assert loaded["index"] == node_run_info.index
        assert loaded["inputs"]["image1"]["data:image/png;path"] == expected_image
        assert loaded["output"]["output1"]["data:image/png;path"] == expected_image

        per_line = local_storage.load_node_run_info_for_line(1)
        assert isinstance(per_line, list)
        assert isinstance(per_line[0], RunInfo)
        assert per_line[0].node == node_run_info.node

    def test_load_flow_run_info(self, local_storage, flow_run_info):
        local_storage.persist_flow_run(flow_run_info)

        all_flow_runs = local_storage._load_all_flow_run_info()
        # Multimedia references must be resolved against the flow folder.
        expected_image = str(Path(local_storage._run_infos_folder, "test.png"))
        assert len(all_flow_runs) == 1
        loaded = all_flow_runs[0]
        assert loaded["run_id"] == flow_run_info.run_id
        assert loaded["status"] == flow_run_info.status.value
        assert loaded["inputs"]["image1"]["data:image/png;path"] == expected_image
        assert loaded["output"]["output1"]["data:image/png;path"] == expected_image

        per_line = local_storage.load_flow_run_info(1)
        assert isinstance(per_line, FlowRunInfo)
        assert per_line.run_id == flow_run_info.run_id

0 comments on commit 54e5438

Please sign in to comment.