diff --git a/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py b/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py index cf2eb843209..e543486072e 100644 --- a/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py +++ b/src/promptflow/promptflow/_sdk/operations/_local_storage_operations.py @@ -40,7 +40,7 @@ from promptflow._utils.dataclass_serializer import serialize from promptflow._utils.exception_utils import PromptflowExceptionPresenter from promptflow._utils.logger_utils import LogContext, get_cli_sdk_logger -from promptflow._utils.multimedia_utils import get_file_reference_encoder +from promptflow._utils.multimedia_utils import get_file_reference_encoder, resolve_multimedia_data_recursively from promptflow._utils.yaml_utils import load_yaml from promptflow.batch._result import BatchResult from promptflow.contracts.multimedia import Image @@ -49,7 +49,7 @@ from promptflow.contracts.run_info import Status from promptflow.contracts.run_mode import RunMode from promptflow.exceptions import UserErrorException -from promptflow.storage import AbstractRunStorage +from promptflow.storage import AbstractBatchRunStorage logger = get_cli_sdk_logger() @@ -180,7 +180,7 @@ def dump(self, path: Path) -> None: json_dump(asdict(self), path) -class LocalStorageOperations(AbstractRunStorage): +class LocalStorageOperations(AbstractBatchRunStorage): """LocalStorageOperations.""" LINE_NUMBER_WIDTH = 9 @@ -222,6 +222,9 @@ def __init__(self, run: Run, stream=False, run_mode=RunMode.Test): self._dump_meta_file() self._eager_mode = self._calculate_eager_mode(run) + self._loaded_flow_run_info = {} # {line_number: flow_run_info} + self._loaded_node_run_info = {} # {line_number: [node_run_info]} + @property def eager_mode(self) -> bool: return self._eager_mode @@ -366,24 +369,8 @@ def load_detail(self, parse_const_as_str: bool = False) -> Dict[str, list]: # legacy run with local file detail.json, then directly load from the file return json_load(self._detail_path) else: - json_loads = json.loads if not parse_const_as_str else json_loads_parse_const_as_str - # collect from local files and concat in the memory - flow_runs, node_runs = [], [] - for line_run_record_file in sorted(self._run_infos_folder.iterdir()): - # In addition to the output jsonl files, there may be multimedia files in the output folder, - # so we should skip them. - if line_run_record_file.suffix.lower() != ".jsonl": - continue - with read_open(line_run_record_file) as f: - new_runs = [json_loads(line)["run_info"] for line in list(f)] - flow_runs += new_runs - for node_folder in sorted(self._node_infos_folder.iterdir()): - for node_run_record_file in sorted(node_folder.iterdir()): - if node_run_record_file.suffix.lower() != ".jsonl": - continue - with read_open(node_run_record_file) as f: - new_runs = [json_loads(line)["run_info"] for line in list(f)] - node_runs += new_runs + flow_runs = self._load_all_flow_run_info(parse_const_as_str=parse_const_as_str) + node_runs = self._load_all_node_run_info(parse_const_as_str=parse_const_as_str) return {"flow_runs": flow_runs, "node_runs": node_runs} def load_metrics(self, *, parse_const_as_str: bool = False) -> Dict[str, Union[int, float, str]]: @@ -400,6 +387,33 @@ def persist_node_run(self, run_info: NodeRunInfo) -> None: filename = f"{str(line_number).zfill(self.LINE_NUMBER_WIDTH)}.jsonl" node_run_record.dump(node_folder / filename, run_name=self._run.name) + def _load_info_from_file(self, file_path, parse_const_as_str: bool = False): + json_loads = json.loads if not parse_const_as_str else json_loads_parse_const_as_str + run_infos = [] + if file_path.suffix.lower() == ".jsonl": + with read_open(file_path) as f: + run_infos = [json_loads(line)["run_info"] for line in list(f)] + return run_infos + + def _load_all_node_run_info(self, parse_const_as_str: bool = False) -> List[Dict]: + node_run_infos = [] + for node_folder in sorted(self._node_infos_folder.iterdir()): + for node_run_record_file in sorted(node_folder.iterdir()): + new_runs = self._load_info_from_file(node_run_record_file, parse_const_as_str) + node_run_infos.extend(new_runs) + for new_run in new_runs: + new_run = resolve_multimedia_data_recursively(node_run_record_file, new_run) + run_info = NodeRunInfo.deserialize(new_run) + line_number = run_info.index + self._loaded_node_run_info[line_number] = self._loaded_node_run_info.get(line_number, []) + self._loaded_node_run_info[line_number].append(run_info) + return node_run_infos + + def load_node_run_info_for_line(self, line_number: int = None) -> List[NodeRunInfo]: + if not self._loaded_node_run_info: + self._load_all_node_run_info() + return self._loaded_node_run_info.get(line_number) + def persist_flow_run(self, run_info: FlowRunInfo) -> None: """Persist line run record to local storage.""" if not Status.is_terminated(run_info.status): @@ -417,6 +431,23 @@ def persist_flow_run(self, run_info: FlowRunInfo) -> None: ) line_run_record.dump(self._run_infos_folder / filename) + def _load_all_flow_run_info(self, parse_const_as_str: bool = False) -> List[Dict]: + flow_run_infos = [] + for line_run_record_file in sorted(self._run_infos_folder.iterdir()): + new_runs = self._load_info_from_file(line_run_record_file, parse_const_as_str) + flow_run_infos.extend(new_runs) + for new_run in new_runs: + new_run = resolve_multimedia_data_recursively(line_run_record_file, new_run) + run_info = FlowRunInfo.deserialize(new_run) + line_number = run_info.index + self._loaded_flow_run_info[line_number] = run_info + return flow_run_infos + + def load_flow_run_info(self, line_number: int = None) -> FlowRunInfo: + if not self._loaded_flow_run_info: + self._load_all_flow_run_info() + return self._loaded_flow_run_info.get(line_number) + def persist_result(self, result: Optional[BatchResult]) -> None: """Persist metrics from return of executor.""" if result is None: diff --git a/src/promptflow/promptflow/storage/__init__.py b/src/promptflow/promptflow/storage/__init__.py index 0c551043c95..02aadbfcb0f 100644 --- a/src/promptflow/promptflow/storage/__init__.py +++ b/src/promptflow/promptflow/storage/__init__.py @@ -3,6 +3,6 @@ # --------------------------------------------------------- from ._cache_storage import AbstractCacheStorage # noqa: F401 -from ._run_storage import AbstractRunStorage # noqa: F401 +from ._run_storage import AbstractBatchRunStorage, AbstractRunStorage # noqa: F401 -__all__ = ["AbstractCacheStorage", "AbstractRunStorage"] +__all__ = ["AbstractCacheStorage", "AbstractRunStorage", "AbstractBatchRunStorage"] diff --git a/src/promptflow/promptflow/storage/_run_storage.py b/src/promptflow/promptflow/storage/_run_storage.py index 08aea38efce..5a1dcf85cca 100644 --- a/src/promptflow/promptflow/storage/_run_storage.py +++ b/src/promptflow/promptflow/storage/_run_storage.py @@ -30,6 +30,18 @@ def persist_flow_run(self, run_info: FlowRunInfo): raise NotImplementedError("AbstractRunStorage is an abstract class, no implementation for persist_flow_run.") +class AbstractBatchRunStorage(AbstractRunStorage): + def load_node_run_info_for_line(self, line_number: int): + raise NotImplementedError( + "AbstractBatchRunStorage is an abstract class, no implementation for load_node_run_info_for_line." + ) + + def load_flow_run_info(self, line_number: int): + raise NotImplementedError( + "AbstractBatchRunStorage is an abstract class, no implementation for load_flow_run_info." + ) + + class DummyRunStorage(AbstractRunStorage): def persist_node_run(self, run_info: NodeRunInfo): """Dummy implementation for persist_node_run diff --git a/src/promptflow/tests/executor/unittests/storage/test_local_storage_operations.py b/src/promptflow/tests/executor/unittests/storage/test_local_storage_operations.py new file mode 100644 index 00000000000..d004052ec7f --- /dev/null +++ b/src/promptflow/tests/executor/unittests/storage/test_local_storage_operations.py @@ -0,0 +1,116 @@ +import datetime +import json +from pathlib import Path + +import pytest + +from promptflow._sdk.entities._run import Run +from promptflow._sdk.operations._local_storage_operations import LocalStorageOperations +from promptflow.contracts.run_info import FlowRunInfo, RunInfo, Status + + +@pytest.fixture +def run_instance(): + return Run(flow="flow", name="run_name") + + +@pytest.fixture +def local_storage(run_instance): + return LocalStorageOperations(run_instance) + + +@pytest.fixture +def node_run_info(): + return RunInfo( + node="node1", + flow_run_id="flow_run_id", + run_id="run_id", + status=Status.Completed, + inputs={"image1": {"data:image/png;path": "test.png"}}, + output={"output1": {"data:image/png;path": "test.png"}}, + metrics={}, + error={}, + parent_run_id="parent_run_id", + start_time=datetime.datetime.now(), + end_time=datetime.datetime.now() + datetime.timedelta(seconds=5), + index=1, + ) + + +@pytest.fixture +def flow_run_info(): + return FlowRunInfo( + run_id="run_id", + status=Status.Completed, + error=None, + inputs={"image1": {"data:image/png;path": "test.png"}}, + output={"output1": {"data:image/png;path": "test.png"}}, + metrics={}, + request="request", + parent_run_id="parent_run_id", + root_run_id="root_run_id", + source_run_id="source_run_id", + flow_id="flow_id", + start_time=datetime.datetime.now(), + end_time=datetime.datetime.now() + datetime.timedelta(seconds=5), + index=1, + ) + + +@pytest.mark.unittest +class TestLocalStorageOperations: + def test_persist_node_run(self, local_storage, node_run_info): + local_storage.persist_node_run(node_run_info) + expected_file_path = local_storage.path / "node_artifacts" / node_run_info.node / "000000001.jsonl" + assert expected_file_path.exists() + with open(expected_file_path, "r") as file: + content = file.read() + node_run_info_dict = json.loads(content) + assert node_run_info_dict["NodeName"] == node_run_info.node + assert node_run_info_dict["line_number"] == node_run_info.index + + def test_persist_flow_run(self, local_storage, flow_run_info): + local_storage.persist_flow_run(flow_run_info) + expected_file_path = local_storage.path / "flow_artifacts" / "000000001_000000001.jsonl" + assert expected_file_path.exists() + with open(expected_file_path, "r") as file: + content = file.read() + flow_run_info_dict = json.loads(content) + assert flow_run_info_dict["run_info"]["run_id"] == flow_run_info.run_id + assert flow_run_info_dict["line_number"] == flow_run_info.index + + def test_load_node_run_info(self, local_storage, node_run_info): + local_storage.persist_node_run(node_run_info) + loaded_node_run_info = local_storage._load_all_node_run_info() + assert len(loaded_node_run_info) == 1 + assert loaded_node_run_info[0]["node"] == node_run_info.node + assert loaded_node_run_info[0]["index"] == node_run_info.index + assert loaded_node_run_info[0]["inputs"]["image1"]["data:image/png;path"] == str( + Path(local_storage._node_infos_folder, node_run_info.node, "test.png") + ) + assert loaded_node_run_info[0]["output"]["output1"]["data:image/png;path"] == str( + Path(local_storage._node_infos_folder, node_run_info.node, "test.png") + ) + + res = local_storage.load_node_run_info_for_line(1) + assert isinstance(res, list) + assert isinstance(res[0], RunInfo) + assert res[0].node == node_run_info.node + + def test_load_flow_run_info(self, local_storage, flow_run_info): + local_storage.persist_flow_run(flow_run_info) + + loaded_flow_run_info = local_storage._load_all_flow_run_info() + assert len(loaded_flow_run_info) == 1 + assert loaded_flow_run_info[0]["run_id"] == flow_run_info.run_id + assert loaded_flow_run_info[0]["status"] == flow_run_info.status.value + assert loaded_flow_run_info[0]["inputs"]["image1"]["data:image/png;path"] == str( + Path(local_storage._run_infos_folder, "test.png") + ) + assert loaded_flow_run_info[0]["output"]["output1"]["data:image/png;path"] == str( + Path(local_storage._run_infos_folder, "test.png") + ) + + res = local_storage.load_flow_run_info(1) + assert isinstance(res, FlowRunInfo) + assert res.run_id == flow_run_info.run_id