diff --git a/traceml/tests/test_tracking/test_run_tracking.py b/traceml/tests/test_tracking/test_run_tracking.py index 1bb7f8c24..b150400d6 100644 --- a/traceml/tests/test_tracking/test_run_tracking.py +++ b/traceml/tests/test_tracking/test_run_tracking.py @@ -821,7 +821,9 @@ def test_log_mpl_plotly(self): self.run_path, kind=V1ArtifactKind.CHART, name="figure" ) assert os.path.exists(events_file) is True - results = V1Events.read(kind="image", name="figure", data=events_file) + results = V1Events.read( + kind=V1ArtifactKind.CHART, name="figure", data=events_file + ) assert len(results.df.values) == 1 with patch("traceml.tracking.run.Run._log_has_events") as log_mpl_plotly_chart: @@ -842,7 +844,9 @@ def test_log_mpl_plotly(self): self.run_path, kind=V1ArtifactKind.CHART, name="figure" ) assert os.path.exists(events_file) is True - results = V1Events.read(kind="image", name="figure", data=events_file) + results = V1Events.read( + kind=V1ArtifactKind.CHART, name="figure", data=events_file + ) assert len(results.df.values) == 2 def test_log_video_from_path(self): diff --git a/traceml/traceml/artifacts/enums.py b/traceml/traceml/artifacts/enums.py index 76f37f91f..3d84ea3a1 100644 --- a/traceml/traceml/artifacts/enums.py +++ b/traceml/traceml/artifacts/enums.py @@ -36,6 +36,14 @@ class V1ArtifactKind(str, PEnum): SYSTEM = "system" ARTIFACT = "artifact" + @classmethod + def is_jsonl_file_event(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool: + return kind in { + V1ArtifactKind.HTML, + V1ArtifactKind.TEXT, + V1ArtifactKind.CHART, + } + @classmethod def is_single_file_event(cls, kind: Optional[Union["V1ArtifactKind", str]]) -> bool: return kind in { diff --git a/traceml/traceml/events/schemas.py b/traceml/traceml/events/schemas.py index 5360dc5c0..1500c93f1 100644 --- a/traceml/traceml/events/schemas.py +++ b/traceml/traceml/events/schemas.py @@ -367,6 +367,58 @@ def __init__(self, kind, name, df): self.name = name self.df = df + @staticmethod + def _read_csv( + data: str, + parse_dates: Optional[bool] = True, + engine: Optional[str] = None, + ) -> "pandas.DataFrame": + import pandas as pd + + if parse_dates: + return pd.read_csv( + data, + sep=V1Event._SEPARATOR, + parse_dates=["timestamp"], + engine=engine, + ) + else: + df = pd.read_csv( + data, + sep=V1Event._SEPARATOR, + engine=engine, + ) + # Pyarrow automatically converts timestamp fields + if "timestamp" in df.columns: + df["timestamp"] = df["timestamp"].astype(str) + return df + + @staticmethod + def _read_jsonl( + data: str, + parse_dates: Optional[bool] = True, + engine: Optional[str] = None, + ) -> "pandas.DataFrame": + import pandas as pd + + engine = engine or "ujson" + if parse_dates: + return pd.read_json( + data, + lines=True, + engine=engine, + ) + else: + df = pd.read_json( + data, + lines=True, + engine=engine, + ) + # Pyarrow automatically converts timestamp fields + if "timestamp" in df.columns: + df["timestamp"] = df["timestamp"].astype(str) + return df + @classmethod def read( cls, @@ -379,23 +431,17 @@ def read( import pandas as pd if isinstance(data, str): - csv = validate_csv(data) - if parse_dates: - df = pd.read_csv( - csv, - sep=V1Event._SEPARATOR, - parse_dates=["timestamp"], - engine=engine, - ) - else: - df = pd.read_csv( - csv, - sep=V1Event._SEPARATOR, - engine=engine, - ) - # Pyarrow automatically converts timestamp fields - if "timestamp" in df.columns: - df["timestamp"] = df["timestamp"].astype(str) + data = validate_csv(data) + error = None + if V1ArtifactKind.is_jsonl_file_event(kind): + try: + df = cls._read_jsonl( + data=data, parse_dates=parse_dates, engine=engine + ) + except Exception as e: + error = e + if not V1ArtifactKind.is_jsonl_file_event(kind) or error: + df = cls._read_csv(data=data, parse_dates=parse_dates, engine=engine) elif isinstance(data, dict): df = pd.DataFrame.from_dict(data) else: @@ -469,6 +515,10 @@ def get_csv_events(self) -> str: events = ["\n{}".format(e.to_csv()) for e in self.events] return "".join(events) + def get_jsonl_events(self) -> str: + events = ["\n{}".format(e.to_json()) for e in self.events] + return "".join(events) + def empty_events(self): self.events[:] = [] diff --git a/traceml/traceml/serialization/base.py b/traceml/traceml/serialization/base.py index b127ca52b..220ae8165 100644 --- a/traceml/traceml/serialization/base.py +++ b/traceml/traceml/serialization/base.py @@ -5,8 +5,13 @@ from clipped.utils.enums import get_enum_value from clipped.utils.paths import check_or_create_path -from traceml.events import LoggedEventSpec -from traceml.events.schemas import LoggedEventListSpec +from traceml.artifacts import V1ArtifactKind +from traceml.events import ( + LoggedEventListSpec, + LoggedEventSpec, + get_event_path, + get_resource_path, +) class EventWriter: @@ -21,18 +26,16 @@ def __init__(self, run_path: str, backend: str): def _get_event_path(self, kind: str, name: str) -> str: if self._events_backend == self.EVENTS_BACKEND: - return os.path.join( - self._run_path, - get_enum_value(self._events_backend), - kind, - "{}.plx".format(name), + return get_event_path( + run_path=self._run_path, + kind=kind, + name=name, ) if self._events_backend == self.RESOURCES_BACKEND: - return os.path.join( - self._run_path, - get_enum_value(self._events_backend), - kind, - "{}.plx".format(name), + return get_resource_path( + run_path=self._run_path, + kind=kind, + name=name, ) raise ValueError( "Unrecognized backend {}".format(get_enum_value(self._events_backend)) @@ -44,12 +47,18 @@ def _init_events(self, events_spec: LoggedEventListSpec): if not os.path.exists(event_path): check_or_create_path(event_path, is_dir=False) with open(event_path, "w") as event_file: - event_file.write(events_spec.get_csv_header()) + if V1ArtifactKind.is_jsonl_file_event(events_spec.kind): + event_file.write("") + else: + event_file.write(events_spec.get_csv_header()) def _append_events(self, events_spec: LoggedEventListSpec): event_path = self._get_event_path(kind=events_spec.kind, name=events_spec.name) with open(event_path, "a") as event_file: - event_file.write(events_spec.get_csv_events()) + if V1ArtifactKind.is_jsonl_file_event(events_spec.kind): + event_file.write(events_spec.get_jsonl_events()) + else: + event_file.write(events_spec.get_csv_events()) def _events_to_files(self, events: List[LoggedEventSpec]): for event in events: